MLIR  22.0.0git
Inliner.cpp
Go to the documentation of this file.
1 //===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements an Inliner that uses a basic inlining algorithm that
// operates bottom up over the Strongly Connected Components (SCCs) of the
// CallGraph. This enables a more incremental propagation of inlining decisions
// from the leaves to the roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
#include "mlir/IR/Threading.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

using ResolvedCall = Inliner::ResolvedCall;
32 //===----------------------------------------------------------------------===//
33 // Symbol Use Tracking
34 //===----------------------------------------------------------------------===//
35 
36 /// Walk all of the used symbol callgraph nodes referenced with the given op.
38  Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
40  function_ref<void(CallGraphNode *, Operation *)> callback) {
41  auto symbolUses = SymbolTable::getSymbolUses(op);
42  assert(symbolUses && "expected uses to be valid");
43 
44  Operation *symbolTableOp = op->getParentOp();
45  for (const SymbolTable::SymbolUse &use : *symbolUses) {
46  auto refIt = resolvedRefs.try_emplace(use.getSymbolRef());
47  CallGraphNode *&node = refIt.first->second;
48 
49  // If this is the first instance of this reference, try to resolve a
50  // callgraph node for it.
51  if (refIt.second) {
52  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
53  use.getSymbolRef());
54  auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
55  if (!callableOp)
56  continue;
57  node = cg.lookupNode(callableOp.getCallableRegion());
58  }
59  if (node)
60  callback(node, use.getUser());
61  }
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // CGUseList
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 /// This struct tracks the uses of callgraph nodes that can be dropped when
70 /// use_empty. It directly tracks and manages a use-list for all of the
71 /// call-graph nodes. This is necessary because many callgraph nodes are
72 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
73 /// class.
74 struct CGUseList {
75  /// This struct tracks the uses of callgraph nodes within a specific
76  /// operation.
77  struct CGUser {
78  /// Any nodes referenced in the top-level attribute list of this user. We
79  /// use a set here because the number of references does not matter.
80  DenseSet<CallGraphNode *> topLevelUses;
81 
82  /// Uses of nodes referenced by nested operations.
84  };
85 
86  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
87 
88  /// Drop uses of nodes referred to by the given call operation that resides
89  /// within 'userNode'.
90  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
91 
92  /// Remove the given node from the use list.
93  void eraseNode(CallGraphNode *node);
94 
95  /// Returns true if the given callgraph node has no uses and can be pruned.
96  bool isDead(CallGraphNode *node) const;
97 
98  /// Returns true if the given callgraph node has a single use and can be
99  /// discarded.
100  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
101 
102  /// Recompute the uses held by the given callgraph node.
103  void recomputeUses(CallGraphNode *node, CallGraph &cg);
104 
105  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
106  /// of 'lhs' into 'rhs'.
107  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
108 
109 private:
110  /// Decrement the uses of discardable nodes referenced by the given user.
111  void decrementDiscardableUses(CGUser &uses);
112 
113  /// A mapping between a discardable callgraph node (that is a symbol) and the
114  /// number of uses for this node.
115  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
116 
117  /// A mapping between a callgraph node and the symbol callgraph nodes that it
118  /// uses.
120 
121  /// A symbol table to use when resolving call lookups.
122  SymbolTableCollection &symbolTable;
123 };
124 } // namespace
125 
126 CGUseList::CGUseList(Operation *op, CallGraph &cg,
127  SymbolTableCollection &symbolTable)
128  : symbolTable(symbolTable) {
129  /// A set of callgraph nodes that are always known to be live during inlining.
130  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
131 
132  // Walk each of the symbol tables looking for discardable callgraph nodes.
133  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
134  for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
135  // If this is a callgraph operation, check to see if it is discardable.
136  if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
137  if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
138  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
139  if (symbol && (allUsesVisible || symbol.isPrivate()) &&
140  symbol.canDiscardOnUseEmpty()) {
141  discardableSymNodeUses.try_emplace(node, 0);
142  }
143  continue;
144  }
145  }
146  // Otherwise, check for any referenced nodes. These will be always-live.
147  walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
148  [](CallGraphNode *, Operation *) {});
149  }
150  };
151  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
152  walkFn);
153 
154  // Drop the use information for any discardable nodes that are always live.
155  for (auto &it : alwaysLiveNodes)
156  discardableSymNodeUses.erase(it.second);
157 
158  // Compute the uses for each of the callable nodes in the graph.
159  for (CallGraphNode *node : cg)
160  recomputeUses(node, cg);
161 }
162 
163 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
164  CallGraph &cg) {
165  auto &userRefs = nodeUses[userNode].innerUses;
166  auto walkFn = [&](CallGraphNode *node, Operation *user) {
167  auto parentIt = userRefs.find(node);
168  if (parentIt == userRefs.end())
169  return;
170  --parentIt->second;
171  --discardableSymNodeUses[node];
172  };
174  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
175 }
176 
177 void CGUseList::eraseNode(CallGraphNode *node) {
178  // Drop all child nodes.
179  for (auto &edge : *node)
180  if (edge.isChild())
181  eraseNode(edge.getTarget());
182 
183  // Drop the uses held by this node and erase it.
184  auto useIt = nodeUses.find(node);
185  assert(useIt != nodeUses.end() && "expected node to be valid");
186  decrementDiscardableUses(useIt->getSecond());
187  nodeUses.erase(useIt);
188  discardableSymNodeUses.erase(node);
189 }
190 
191 bool CGUseList::isDead(CallGraphNode *node) const {
192  // If the parent operation isn't a symbol, simply check normal SSA deadness.
193  Operation *nodeOp = node->getCallableRegion()->getParentOp();
194  if (!isa<SymbolOpInterface>(nodeOp))
195  return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
196 
197  // Otherwise, check the number of symbol uses.
198  auto symbolIt = discardableSymNodeUses.find(node);
199  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
200 }
201 
202 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
203  // If this isn't a symbol node, check for side-effects and SSA use count.
204  Operation *nodeOp = node->getCallableRegion()->getParentOp();
205  if (!isa<SymbolOpInterface>(nodeOp))
206  return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
207 
208  // Otherwise, check the number of symbol uses.
209  auto symbolIt = discardableSymNodeUses.find(node);
210  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
211 }
212 
213 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
214  Operation *parentOp = node->getCallableRegion()->getParentOp();
215  CGUser &uses = nodeUses[node];
216  decrementDiscardableUses(uses);
217 
218  // Collect the new discardable uses within this node.
219  uses = CGUser();
221  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
222  auto discardSymIt = discardableSymNodeUses.find(refNode);
223  if (discardSymIt == discardableSymNodeUses.end())
224  return;
225 
226  if (user != parentOp)
227  ++uses.innerUses[refNode];
228  else if (!uses.topLevelUses.insert(refNode).second)
229  return;
230  ++discardSymIt->second;
231  };
232  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
233 }
234 
235 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
236  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
237  for (auto &useIt : lhsUses.innerUses) {
238  rhsUses.innerUses[useIt.first] += useIt.second;
239  discardableSymNodeUses[useIt.first] += useIt.second;
240  }
241 }
242 
243 void CGUseList::decrementDiscardableUses(CGUser &uses) {
244  for (CallGraphNode *node : uses.topLevelUses)
245  --discardableSymNodeUses[node];
246  for (auto &it : uses.innerUses)
247  discardableSymNodeUses[it.first] -= it.second;
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // CallGraph traversal
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 /// This class represents a specific callgraph SCC.
256 class CallGraphSCC {
257 public:
258  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
259  : parentIterator(parentIterator) {}
260  /// Return a range over the nodes within this SCC.
261  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
262  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
263 
264  /// Reset the nodes of this SCC with those provided.
265  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
266 
267  /// Remove the given node from this SCC.
268  void remove(CallGraphNode *node) {
269  auto it = llvm::find(nodes, node);
270  if (it != nodes.end()) {
271  nodes.erase(it);
272  parentIterator.ReplaceNode(node, nullptr);
273  }
274  }
275 
276 private:
277  std::vector<CallGraphNode *> nodes;
278  llvm::scc_iterator<const CallGraph *> &parentIterator;
279 };
280 } // namespace
281 
282 /// Run a given transformation over the SCCs of the callgraph in a bottom up
283 /// traversal.
284 static LogicalResult runTransformOnCGSCCs(
285  const CallGraph &cg,
286  function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
287  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
288  CallGraphSCC currentSCC(cgi);
289  while (!cgi.isAtEnd()) {
290  // Copy the current SCC and increment so that the transformer can modify the
291  // SCC without invalidating our iterator.
292  currentSCC.reset(*cgi);
293  ++cgi;
294  if (failed(sccTransformer(currentSCC)))
295  return failure();
296  }
297  return success();
298 }
299 
300 /// Collect all of the callable operations within the given range of blocks. If
301 /// `traverseNestedCGNodes` is true, this will also collect call operations
302 /// inside of nested callgraph nodes.
304  CallGraphNode *sourceNode, CallGraph &cg,
305  SymbolTableCollection &symbolTable,
307  bool traverseNestedCGNodes) {
309  auto addToWorklist = [&](CallGraphNode *node,
311  for (Block &block : blocks)
312  worklist.emplace_back(&block, node);
313  };
314 
315  addToWorklist(sourceNode, blocks);
316  while (!worklist.empty()) {
317  Block *block;
318  std::tie(block, sourceNode) = worklist.pop_back_val();
319 
320  for (Operation &op : *block) {
321  if (auto call = dyn_cast<CallOpInterface>(op)) {
322  // TODO: Support inlining nested call references.
323  CallInterfaceCallable callable = call.getCallableForCallee();
324  if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
325  if (!isa<FlatSymbolRefAttr>(symRef))
326  continue;
327  }
328 
329  CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
330  if (!targetNode->isExternal())
331  calls.emplace_back(call, sourceNode, targetNode);
332  continue;
333  }
334 
335  // If this is not a call, traverse the nested regions. If
336  // `traverseNestedCGNodes` is false, then don't traverse nested call graph
337  // regions.
338  for (auto &nestedRegion : op.getRegions()) {
339  CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
340  if (traverseNestedCGNodes || !nestedNode)
341  addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
342  }
343  }
344  }
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // InlinerInterfaceImpl
349 //===----------------------------------------------------------------------===//
350 
351 static std::string getNodeName(CallOpInterface op) {
352  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
353  return debugString(op);
354  return "_unnamed_callee_";
355 }
356 
357 /// Return true if the specified `inlineHistoryID` indicates an inline history
358 /// that already includes `node`.
360  CallGraphNode *node, std::optional<size_t> inlineHistoryID,
361  MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
362  inlineHistory) {
363  while (inlineHistoryID.has_value()) {
364  assert(*inlineHistoryID < inlineHistory.size() &&
365  "Invalid inline history ID");
366  if (inlineHistory[*inlineHistoryID].first == node)
367  return true;
368  inlineHistoryID = inlineHistory[*inlineHistoryID].second;
369  }
370  return false;
371 }
372 
373 namespace {
374 /// This class provides a specialization of the main inlining interface.
375 struct InlinerInterfaceImpl : public InlinerInterface {
376  InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
377  SymbolTableCollection &symbolTable)
378  : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
379 
380  /// Process a set of blocks that have been inlined. This callback is invoked
381  /// *before* inlined terminator operations have been processed.
382  void
384  // Find the closest callgraph node from the first block.
385  CallGraphNode *node;
386  Region *region = inlinedBlocks.begin()->getParent();
387  while (!(node = cg.lookupNode(region))) {
388  region = region->getParentRegion();
389  assert(region && "expected valid parent node");
390  }
391 
392  collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
393  /*traverseNestedCGNodes=*/true);
394  }
395 
396  /// Mark the given callgraph node for deletion.
397  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
398 
399  /// This method properly disposes of callables that became dead during
400  /// inlining. This should not be called while iterating over the SCCs.
401  void eraseDeadCallables() {
402  for (CallGraphNode *node : deadNodes)
403  node->getCallableRegion()->getParentOp()->erase();
404  }
405 
406  /// The set of callables known to be dead.
408 
409  /// The current set of call instructions to consider for inlining.
411 
412  /// The callgraph being operated on.
413  CallGraph &cg;
414 
415  /// A symbol table to use when resolving call lookups.
416  SymbolTableCollection &symbolTable;
417 };
418 } // namespace
419 
420 namespace mlir {
421 
423 public:
424  Impl(Inliner &inliner) : inliner(inliner) {}
425 
426  /// Attempt to inline calls within the given scc, and run simplifications,
427  /// until a fixed point is reached. This allows for the inlining of newly
428  /// devirtualized calls. Returns failure if there was a fatal error during
429  /// inlining.
430  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
431  CGUseList &useList, CallGraphSCC &currentSCC,
432  MLIRContext *context);
433 
434 private:
435  /// Optimize the nodes within the given SCC with one of the held optimization
436  /// pass pipelines. Returns failure if an error occurred during the
437  /// optimization of the SCC, success otherwise.
438  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
439  CallGraphSCC &currentSCC, MLIRContext *context);
440 
441  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
442  /// error occurred during the optimization of the SCC, success otherwise.
443  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
444  MLIRContext *context);
445 
446  /// Optimize the given callable node with one of the pass managers provided
447  /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
448  /// an error occurred during the optimization of the callable, success
449  /// otherwise.
450  LogicalResult optimizeCallable(CallGraphNode *node,
451  llvm::StringMap<OpPassManager> &pipelines);
452 
453  /// Attempt to inline calls within the given scc. This function returns
454  /// success if any calls were inlined, failure otherwise.
455  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
456  CGUseList &useList, CallGraphSCC &currentSCC);
457 
458  /// Returns true if the given call should be inlined.
459  bool shouldInline(ResolvedCall &resolvedCall);
460 
461 private:
462  Inliner &inliner;
464 };
465 
466 LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
467  CGUseList &useList,
468  CallGraphSCC &currentSCC,
469  MLIRContext *context) {
470  // Continuously simplify and inline until we either reach a fixed point, or
471  // hit the maximum iteration count. Simplifying early helps to refine the cost
472  // model, and in future iterations may devirtualize new calls.
473  unsigned iterationCount = 0;
474  do {
475  if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
476  return failure();
477  if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
478  break;
479  } while (++iterationCount < inliner.config.getMaxInliningIterations());
480  return success();
481 }
482 
483 LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
484  CallGraphSCC &currentSCC,
485  MLIRContext *context) {
486  // Collect the sets of nodes to simplify.
487  SmallVector<CallGraphNode *, 4> nodesToVisit;
488  for (auto *node : currentSCC) {
489  if (node->isExternal())
490  continue;
491 
492  // Don't simplify nodes with children. Nodes with children require special
493  // handling as we may remove the node during simplification. In the future,
494  // we should be able to handle this case with proper node deletion tracking.
495  if (node->hasChildren())
496  continue;
497 
498  // We also won't apply simplifications to nodes that can't have passes
499  // scheduled on them.
500  auto *region = node->getCallableRegion();
502  continue;
503  nodesToVisit.push_back(node);
504  }
505  if (nodesToVisit.empty())
506  return success();
507 
508  // Optimize each of the nodes within the SCC in parallel.
509  if (failed(optimizeSCCAsync(nodesToVisit, context)))
510  return failure();
511 
512  // Recompute the uses held by each of the nodes.
513  for (CallGraphNode *node : nodesToVisit)
514  useList.recomputeUses(node, cg);
515  return success();
516 }
517 
518 LogicalResult
519 Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
520  MLIRContext *ctx) {
521  // We must maintain a fixed pool of pass managers which is at least as large
522  // as the maximum parallelism of the failableParallelForEach below.
523  // Note: The number of pass managers here needs to remain constant
524  // to prevent issues with pass instrumentations that rely on having the same
525  // pass manager for the main thread.
526  size_t numThreads = ctx->getNumThreads();
527  const auto &opPipelines = inliner.config.getOpPipelines();
528  if (pipelines.size() < numThreads) {
529  pipelines.reserve(numThreads);
530  pipelines.resize(numThreads, opPipelines);
531  }
532 
533  // Ensure an analysis manager has been constructed for each of the nodes.
534  // This prevents thread races when running the nested pipelines.
535  for (CallGraphNode *node : nodesToVisit)
536  inliner.am.nest(node->getCallableRegion()->getParentOp());
537 
538  // An atomic failure variable for the async executors.
539  std::vector<std::atomic<bool>> activePMs(pipelines.size());
540  llvm::fill(activePMs, false);
541  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
542  // Find a pass manager for this operation.
543  auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
544  bool expectedInactive = false;
545  return isActive.compare_exchange_strong(expectedInactive, true);
546  });
547  assert(it != activePMs.end() &&
548  "could not find inactive pass manager for thread");
549  unsigned pmIndex = it - activePMs.begin();
550 
551  // Optimize this callable node.
552  LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
553 
554  // Reset the active bit for this pass manager.
555  activePMs[pmIndex].store(false);
556  return result;
557  });
558 }
559 
560 LogicalResult
561 Inliner::Impl::optimizeCallable(CallGraphNode *node,
562  llvm::StringMap<OpPassManager> &pipelines) {
563  Operation *callable = node->getCallableRegion()->getParentOp();
564  StringRef opName = callable->getName().getStringRef();
565  auto pipelineIt = pipelines.find(opName);
566  const auto &defaultPipeline = inliner.config.getDefaultPipeline();
567  if (pipelineIt == pipelines.end()) {
568  // If a pipeline didn't exist, use the generic pipeline if possible.
569  if (!defaultPipeline)
570  return success();
571 
572  OpPassManager defaultPM(opName);
573  defaultPipeline(defaultPM);
574  pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
575  }
576  return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
577 }
578 
579 /// Attempt to inline calls within the given scc. This function returns
580 /// success if any calls were inlined, failure otherwise.
581 LogicalResult
582 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
583  CGUseList &useList, CallGraphSCC &currentSCC) {
584  CallGraph &cg = inlinerIface.cg;
585  auto &calls = inlinerIface.calls;
586 
587  // A set of dead nodes to remove after inlining.
588  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
589 
590  // Collect all of the direct calls within the nodes of the current SCC. We
591  // don't traverse nested callgraph nodes, because they are handled separately
592  // likely within a different SCC.
593  for (CallGraphNode *node : currentSCC) {
594  if (node->isExternal())
595  continue;
596 
597  // Don't collect calls if the node is already dead.
598  if (useList.isDead(node)) {
599  deadNodes.insert(node);
600  } else {
601  collectCallOps(*node->getCallableRegion(), node, cg,
602  inlinerIface.symbolTable, calls,
603  /*traverseNestedCGNodes=*/false);
604  }
605  }
606 
607  // When inlining a callee produces new call sites, we want to keep track of
608  // the fact that they were inlined from the callee. This allows us to avoid
609  // infinite inlining.
610  using InlineHistoryT = std::optional<size_t>;
612  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
613 
614  LLVM_DEBUG({
615  LDBG() << "* Inliner: Initial calls in SCC are: {";
616  for (unsigned i = 0, e = calls.size(); i < e; ++i)
617  LDBG() << " " << i << ". " << calls[i].call << ",";
618  LDBG() << "}";
619  });
620 
621  // Try to inline each of the call operations. Don't cache the end iterator
622  // here as more calls may be added during inlining.
623  bool inlinedAnyCalls = false;
624  for (unsigned i = 0; i < calls.size(); ++i) {
625  if (deadNodes.contains(calls[i].sourceNode))
626  continue;
627  ResolvedCall it = calls[i];
628 
629  InlineHistoryT inlineHistoryID = callHistory[i];
630  bool inHistory =
631  inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
632  bool doInline = !inHistory && shouldInline(it);
633  CallOpInterface call = it.call;
634  LLVM_DEBUG({
635  if (doInline)
636  LDBG() << "* Inlining call: " << i << ". " << call;
637  else
638  LDBG() << "* Not inlining call: " << i << ". " << call;
639  });
640  if (!doInline)
641  continue;
642 
643  unsigned prevSize = calls.size();
644  Region *targetRegion = it.targetNode->getCallableRegion();
645 
646  // If this is the last call to the target node and the node is discardable,
647  // then inline it in-place and delete the node if successful.
648  bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
649 
650  LogicalResult inlineResult =
651  inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
652  cast<CallableOpInterface>(targetRegion->getParentOp()),
653  targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
654  if (failed(inlineResult)) {
655  LDBG() << "** Failed to inline";
656  continue;
657  }
658  inlinedAnyCalls = true;
659 
660  // Create a inline history entry for this inlined call, so that we remember
661  // that new callsites came about due to inlining Callee.
662  InlineHistoryT newInlineHistoryID{inlineHistory.size()};
663  inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
664 
665  auto historyToString = [](InlineHistoryT h) {
666  return h.has_value() ? std::to_string(*h) : "root";
667  };
668  LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
669  << getNodeName(call) << ", " << historyToString(inlineHistoryID)
670  << "]";
671 
672  for (unsigned k = prevSize; k != calls.size(); ++k) {
673  callHistory.push_back(newInlineHistoryID);
674  LDBG() << "* new call " << k << " {" << calls[k].call
675  << "}\n with historyID = " << newInlineHistoryID
676  << ", added due to inlining of\n call {" << call
677  << "}\n with historyID = " << historyToString(inlineHistoryID);
678  }
679 
680  // If the inlining was successful, Merge the new uses into the source node.
681  useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
682  useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
683 
684  // then erase the call.
685  call.erase();
686 
687  // If we inlined in place, mark the node for deletion.
688  if (inlineInPlace) {
689  useList.eraseNode(it.targetNode);
690  deadNodes.insert(it.targetNode);
691  }
692  }
693 
694  for (CallGraphNode *node : deadNodes) {
695  currentSCC.remove(node);
696  inlinerIface.markForDeletion(node);
697  }
698  calls.clear();
699  return success(inlinedAnyCalls);
700 }
701 
702 /// Returns true if the given call should be inlined.
703 bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
704  // Don't allow inlining terminator calls. We currently don't support this
705  // case.
706  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
707  return false;
708 
709  // Don't allow inlining if the target is a self-recursive function.
710  // Don't allow inlining if the call graph is like A->B->A.
711  if (llvm::count_if(*resolvedCall.targetNode,
712  [&](CallGraphNode::Edge const &edge) -> bool {
713  return edge.getTarget() == resolvedCall.targetNode ||
714  edge.getTarget() == resolvedCall.sourceNode;
715  }) > 0)
716  return false;
717 
718  // Don't allow inlining if the target is an ancestor of the call. This
719  // prevents inlining recursively.
720  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
721  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
722  return false;
723 
724  // Don't allow inlining if the callee has multiple blocks (unstructured
725  // control flow) but we cannot be sure that the caller region supports that.
726  if (!inliner.config.getCanHandleMultipleBlocks()) {
727  bool calleeHasMultipleBlocks =
728  llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
729  // If both parent ops have the same type, it is safe to inline. Otherwise,
730  // decide based on whether the op has the SingleBlock trait or not.
731  // Note: This check does currently not account for
732  // SizedRegion/MaxSizedRegion.
733  auto callerRegionSupportsMultipleBlocks = [&]() {
734  return callableRegion->getParentOp()->getName() ==
735  resolvedCall.call->getParentOp()->getName() ||
736  !resolvedCall.call->getParentOp()
737  ->mightHaveTrait<OpTrait::SingleBlock>();
738  };
739  if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
740  return false;
741  }
742 
743  if (!inliner.isProfitableToInline(resolvedCall))
744  return false;
745 
746  // Otherwise, inline.
747  return true;
748 }
749 
750 LogicalResult Inliner::doInlining() {
751  Impl impl(*this);
752  auto *context = op->getContext();
753  // Run the inline transform in post-order over the SCCs in the callgraph.
754  SymbolTableCollection symbolTable;
755  // FIXME: some clean-up can be done for the arguments
756  // of the Impl's methods, if the inlinerIface and useList
757  // become the states of the Impl.
758  InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
759  CGUseList useList(op, cg, symbolTable);
760  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
761  return impl.inlineSCC(inlinerIface, useList, scc, context);
762  });
763  if (failed(result))
764  return result;
765 
766  // After inlining, make sure to erase any callables proven to be dead.
767  inlinerIface.eraseDeadCallables();
768  return success();
769 }
770 } // namespace mlir
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Definition: Inliner.cpp:303
Inliner::ResolvedCall ResolvedCall
Definition: Inliner.cpp:30
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
Definition: Inliner.cpp:37
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
Definition: Inliner.cpp:359
static std::string getNodeName(CallOpInterface op)
Definition: Inliner.cpp:351
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Definition: Inliner.cpp:284
Block represents an ordered list of Operations.
Definition: Block.h:33
This class represents a directed edge between two nodes in the callgraph.
Definition: CallGraph.h:43
This class represents a single callable in the callgraph.
Definition: CallGraph.h:40
bool isExternal() const
Returns true if this node is an external node.
Definition: CallGraph.cpp:32
bool hasChildren() const
Returns true if this node has any child edges.
Definition: CallGraph.cpp:59
Region * getCallableRegion() const
Returns the callable region this node represents.
Definition: CallGraph.cpp:36
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
Definition: CallGraph.cpp:147
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
Definition: CallGraph.cpp:139
unsigned getMaxInliningIterations() const
Definition: Inliner.h:42
This interface provides the hooks into the inlining interface.
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context)
Attempt to inline calls within the given SCC, and run simplifications, until a fixed point is reached.
Definition: Inliner.cpp:466
Impl(Inliner &inliner)
Definition: Inliner.cpp:424
This is an implementation of the inliner that operates bottom up over the Strongly Connected Components (SCCs) of the CallGraph.
Definition: Inliner.h:103
LogicalResult doInlining()
Perform inlining on an OpTrait::SymbolTable operation.
Definition: Inliner.cpp:750
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition: PassManager.h:46
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OpTrait::IsTerminator>()`.
Definition: Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
iterator begin()
Definition: Region.h:55
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
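Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' that has the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was found.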
This class represents a specific symbol use.
Definition: SymbolTable.h:183
static std::optional< UseRange > getSymbolUses(Operation *from)
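Get an iterator range for all of the uses, for any symbol, that are nested within the given operation 'from'. Returns std::nullopt if the uses could not be computed (for example, when unknown operations that may be symbol tables are encountered).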
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:36
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
Definition: Inliner.h:109
CallGraphNode * sourceNode
Definition: Inliner.h:114
CallOpInterface call
Definition: Inliner.h:113
CallGraphNode * targetNode
Definition: Inliner.h:114
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:881