MLIR  20.0.0git
Inliner.cpp
Go to the documentation of this file.
1 //===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Inliner, which uses a basic inlining
10 // algorithm that operates bottom up over the Strongly Connected Components
11 // (SCCs) of the CallGraph. This enables a more incremental propagation of
12 // inlining decisions from the leaves to the roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
17 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/SCCIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "inlining"
29 
30 using namespace mlir;
31 
33 
34 //===----------------------------------------------------------------------===//
35 // Symbol Use Tracking
36 //===----------------------------------------------------------------------===//
37 
38 /// Walk all of the used symbol callgraph nodes referenced by the given op.
/// Symbol references are resolved relative to the parent operation of `op`.
/// Every resolution result, including a failed one (cached as nullptr), is
/// memoized in `resolvedRefs`, so each distinct reference is looked up once.
/// `callback` is invoked with the resolved node and the using operation for
/// every use that resolves to a callgraph node.
40  Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
42  function_ref<void(CallGraphNode *, Operation *)> callback) {
43  auto symbolUses = SymbolTable::getSymbolUses(op);
44  assert(symbolUses && "expected uses to be valid");
45 
46  Operation *symbolTableOp = op->getParentOp();
47  for (const SymbolTable::SymbolUse &use : *symbolUses) {
48  auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
 // `node` aliases the cache entry, so assigning it below also fills the cache.
49  CallGraphNode *&node = refIt.first->second;
50 
51  // If this is the first instance of this reference, try to resolve a
52  // callgraph node for it.
53  if (refIt.second) {
54  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
55  use.getSymbolRef());
56  auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
 // A non-callable symbol stays cached as nullptr, so later uses of the
 // same reference are skipped by the `if (node)` check below.
57  if (!callableOp)
58  continue;
59  node = cg.lookupNode(callableOp.getCallableRegion());
60  }
61  if (node)
62  callback(node, use.getUser());
63  }
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // CGUseList
68 
69 namespace {
70 /// This struct tracks the uses of callgraph nodes that can be dropped once
71 /// they have no uses. It directly tracks and manages a use-list for all of the
72 /// call-graph nodes. This is necessary because many callgraph nodes are
73 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
74 /// class.
/// Symbol use counts are only maintained for discardable symbol nodes;
/// callables that are not symbols are judged by their SSA uses instead (see
/// isDead and hasOneUseAndDiscardable).
75 struct CGUseList {
76  /// This struct tracks the uses of callgraph nodes within a specific
77  /// operation.
78  struct CGUser {
79  /// Any nodes referenced in the top-level attribute list of this user. We
80  /// use a set here because the number of references does not matter.
81  DenseSet<CallGraphNode *> topLevelUses;
82 
83  /// Uses of nodes referenced by operations nested within this user, counted
 /// once per use.
85  };
86 
 /// Compute the initial use lists for the nodes of `cg`, scanning the symbol
 /// tables nested under `op`. `symbolTable` is retained for later lookups.
87  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
88 
89  /// Drop uses of nodes referred to by the given call operation that resides
90  /// within 'userNode'.
91  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
92 
93  /// Remove the given node, and recursively its child nodes, from the use
 /// list, releasing the uses they held.
94  void eraseNode(CallGraphNode *node);
95 
96  /// Returns true if the given callgraph node has no uses and can be pruned.
 /// Symbol callables must be tracked as discardable with a zero use count;
 /// other callables must be memory-effect free and have no SSA uses.
97  bool isDead(CallGraphNode *node) const;
98 
99  /// Returns true if the given callgraph node has a single use and can be
100  /// discarded. Uses the same symbol/SSA distinction as isDead.
101  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
102 
103  /// Recompute the uses held by the given callgraph node.
104  void recomputeUses(CallGraphNode *node, CallGraph &cg);
105 
106  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy
107  /// of 'lhs' into 'rhs'.
108  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
109 
110 private:
111  /// Decrement the uses of discardable nodes referenced by the given user.
112  void decrementDiscardableUses(CGUser &uses);
113 
114  /// A mapping between a discardable callgraph node (that is a symbol) and the
115  /// number of uses for this node.
116  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
117 
118  /// A mapping between a callgraph node and the symbol callgraph nodes that it
119  /// uses.
121 
122  /// A symbol table to use when resolving call lookups.
123  SymbolTableCollection &symbolTable;
124 };
125 } // namespace
126 
127 CGUseList::CGUseList(Operation *op, CallGraph &cg,
128  SymbolTableCollection &symbolTable)
129  : symbolTable(symbolTable) {
130  /// A set of callgraph nodes that are always known to be live during inlining.
131  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
132 
133  // Walk each of the symbol tables looking for discardable callgraph nodes.
134  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
135  for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
136  // If this is a callgraph operation, check to see if it is discardable.
137  if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
138  if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
139  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
140  if (symbol && (allUsesVisible || symbol.isPrivate()) &&
141  symbol.canDiscardOnUseEmpty()) {
142  discardableSymNodeUses.try_emplace(node, 0);
143  }
144  continue;
145  }
146  }
147  // Otherwise, check for any referenced nodes. These will be always-live.
148  walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
149  [](CallGraphNode *, Operation *) {});
150  }
151  };
152  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
153  walkFn);
154 
155  // Drop the use information for any discardable nodes that are always live.
156  for (auto &it : alwaysLiveNodes)
157  discardableSymNodeUses.erase(it.second);
158 
159  // Compute the uses for each of the callable nodes in the graph.
160  for (CallGraphNode *node : cg)
161  recomputeUses(node, cg);
162 }
163 
/// Drop the uses of callgraph nodes referenced by `callOp`, which lives inside
/// `userNode`. For every use of a node that is recorded among `userNode`'s
/// inner uses, both that per-user count and the node's discardable use count
/// are decremented once. References not recorded for `userNode` are ignored.
164 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
165  CallGraph &cg) {
166  auto &userRefs = nodeUses[userNode].innerUses;
167  auto walkFn = [&](CallGraphNode *node, Operation *user) {
168  auto parentIt = userRefs.find(node);
 // Only nodes whose uses were recorded for this user are tracked here.
169  if (parentIt == userRefs.end())
170  return;
171  --parentIt->second;
172  --discardableSymNodeUses[node];
173  };
175  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
176 }
177 
178 void CGUseList::eraseNode(CallGraphNode *node) {
179  // Drop all child nodes.
180  for (auto &edge : *node)
181  if (edge.isChild())
182  eraseNode(edge.getTarget());
183 
184  // Drop the uses held by this node and erase it.
185  auto useIt = nodeUses.find(node);
186  assert(useIt != nodeUses.end() && "expected node to be valid");
187  decrementDiscardableUses(useIt->getSecond());
188  nodeUses.erase(useIt);
189  discardableSymNodeUses.erase(node);
190 }
191 
192 bool CGUseList::isDead(CallGraphNode *node) const {
193  // If the parent operation isn't a symbol, simply check normal SSA deadness.
194  Operation *nodeOp = node->getCallableRegion()->getParentOp();
195  if (!isa<SymbolOpInterface>(nodeOp))
196  return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
197 
198  // Otherwise, check the number of symbol uses.
199  auto symbolIt = discardableSymNodeUses.find(node);
200  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
201 }
202 
203 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
204  // If this isn't a symbol node, check for side-effects and SSA use count.
205  Operation *nodeOp = node->getCallableRegion()->getParentOp();
206  if (!isa<SymbolOpInterface>(nodeOp))
207  return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
208 
209  // Otherwise, check the number of symbol uses.
210  auto symbolIt = discardableSymNodeUses.find(node);
211  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
212 }
213 
/// Recompute the uses held by `node`. The previously recorded uses are first
/// released from the discardable counts and cleared. The symbol references
/// of the callable's parent operation are then walked, and each reference to
/// a tracked discardable node is recorded. References from nested operations
/// count once per use (`innerUses`); references on the parent operation
/// itself count at most once per node (`topLevelUses`).
214 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
215  Operation *parentOp = node->getCallableRegion()->getParentOp();
216  CGUser &uses = nodeUses[node];
217  decrementDiscardableUses(uses);
218 
219  // Collect the new discardable uses within this node.
220  uses = CGUser();
222  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
223  auto discardSymIt = discardableSymNodeUses.find(refNode);
 // Only discardable nodes have their uses tracked.
224  if (discardSymIt == discardableSymNodeUses.end())
225  return;
226 
227  if (user != parentOp)
228  ++uses.innerUses[refNode];
 // A repeated top-level reference does not add another use.
229  else if (!uses.topLevelUses.insert(refNode).second)
230  return;
231  ++discardSymIt->second;
232  };
233  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
234 }
235 
236 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
237  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
238  for (auto &useIt : lhsUses.innerUses) {
239  rhsUses.innerUses[useIt.first] += useIt.second;
240  discardableSymNodeUses[useIt.first] += useIt.second;
241  }
242 }
243 
244 void CGUseList::decrementDiscardableUses(CGUser &uses) {
245  for (CallGraphNode *node : uses.topLevelUses)
246  --discardableSymNodeUses[node];
247  for (auto &it : uses.innerUses)
248  discardableSymNodeUses[it.first] -= it.second;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // CallGraph traversal
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 /// This class represents a specific callgraph SCC.
257 class CallGraphSCC {
258 public:
259  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
260  : parentIterator(parentIterator) {}
261  /// Return a range over the nodes within this SCC.
262  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
263  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
264 
265  /// Reset the nodes of this SCC with those provided.
266  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
267 
268  /// Remove the given node from this SCC.
269  void remove(CallGraphNode *node) {
270  auto it = llvm::find(nodes, node);
271  if (it != nodes.end()) {
272  nodes.erase(it);
273  parentIterator.ReplaceNode(node, nullptr);
274  }
275  }
276 
277 private:
278  std::vector<CallGraphNode *> nodes;
279  llvm::scc_iterator<const CallGraph *> &parentIterator;
280 };
281 } // namespace
282 
283 /// Run a given transformation over the SCCs of the callgraph in a bottom up
284 /// traversal.
285 static LogicalResult runTransformOnCGSCCs(
286  const CallGraph &cg,
287  function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
288  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
289  CallGraphSCC currentSCC(cgi);
290  while (!cgi.isAtEnd()) {
291  // Copy the current SCC and increment so that the transformer can modify the
292  // SCC without invalidating our iterator.
293  currentSCC.reset(*cgi);
294  ++cgi;
295  if (failed(sccTransformer(currentSCC)))
296  return failure();
297  }
298  return success();
299 }
300 
301 /// Collect all of the call operations within the given range of blocks. If
302 /// `traverseNestedCGNodes` is true, this will also collect call operations
303 /// inside of nested callgraph nodes.
/// Each collected call is appended to `calls` with the callgraph node that
/// contains it and its resolved target node. For calls directly in `blocks`,
/// that is `sourceNode`; for calls in a traversed nested region, it is the
/// nested region's callgraph node, if it has one. Calls through nested
/// (non-flat) symbol references, and calls whose target resolves to the
/// external node, are skipped.
305  CallGraphNode *sourceNode, CallGraph &cg,
306  SymbolTableCollection &symbolTable,
308  bool traverseNestedCGNodes) {
310  auto addToWorklist = [&](CallGraphNode *node,
312  for (Block &block : blocks)
313  worklist.emplace_back(&block, node);
314  };
315 
316  addToWorklist(sourceNode, blocks);
317  while (!worklist.empty()) {
318  Block *block;
 // `sourceNode` is reused to hold the node owning the popped block.
319  std::tie(block, sourceNode) = worklist.pop_back_val();
320 
321  for (Operation &op : *block) {
322  if (auto call = dyn_cast<CallOpInterface>(op)) {
323  // TODO: Support inlining nested call references.
324  CallInterfaceCallable callable = call.getCallableForCallee();
325  if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
326  if (!isa<FlatSymbolRefAttr>(symRef))
327  continue;
328  }
329 
330  CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
331  if (!targetNode->isExternal())
332  calls.emplace_back(call, sourceNode, targetNode);
 // Regions of a call operation are not searched for further calls.
333  continue;
334  }
335 
336  // If this is not a call, traverse the nested regions. If
337  // `traverseNestedCGNodes` is false, then don't traverse nested call graph
338  // regions.
339  for (auto &nestedRegion : op.getRegions()) {
340  CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
341  if (traverseNestedCGNodes || !nestedNode)
342  addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
343  }
344  }
345  }
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // InlinerInterfaceImpl
350 //===----------------------------------------------------------------------===//
351 
352 #ifndef NDEBUG
353 static std::string getNodeName(CallOpInterface op) {
354  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
355  return debugString(op);
356  return "_unnamed_callee_";
357 }
358 #endif
359 
360 /// Return true if the specified `inlineHistoryID` indicates an inline history
361 /// that already includes `node`.
/// Each `inlineHistory` entry pairs an inlined callee with the ID of the entry
/// it was inlined under (std::nullopt at the root). This follows that chain of
/// IDs, starting at `inlineHistoryID`, until the root is reached.
363  CallGraphNode *node, std::optional<size_t> inlineHistoryID,
364  MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
365  inlineHistory) {
366  while (inlineHistoryID.has_value()) {
367  assert(*inlineHistoryID < inlineHistory.size() &&
368  "Invalid inline history ID");
369  if (inlineHistory[*inlineHistoryID].first == node)
370  return true;
371  inlineHistoryID = inlineHistory[*inlineHistoryID].second;
372  }
373  return false;
374 }
375 
376 namespace {
377 /// This class provides a specialization of the main inlining interface.
/// It gathers calls exposed by inlining into `calls` and records callables
/// that became dead so they can be erased once SCC traversal has finished.
378 struct InlinerInterfaceImpl : public InlinerInterface {
379  InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
380  SymbolTableCollection &symbolTable)
381  : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
382 
383  /// Process a set of blocks that have been inlined. This callback is invoked
384  /// *before* inlined terminator operations have been processed.
 /// Calls found in the inlined blocks, including those inside nested
 /// callgraph nodes, are appended to `calls`. The closest callgraph node
 /// enclosing the first inlined block is used as their starting source node.
385  void
387  // Find the closest callgraph node from the first block.
388  CallGraphNode *node;
389  Region *region = inlinedBlocks.begin()->getParent();
390  while (!(node = cg.lookupNode(region))) {
391  region = region->getParentRegion();
392  assert(region && "expected valid parent node");
393  }
394 
395  collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
396  /*traverseNestedCGNodes=*/true);
397  }
398 
399  /// Mark the given callgraph node for deletion.
400  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
401 
402  /// This method properly disposes of callables that became dead during
403  /// inlining. This should not be called while iterating over the SCCs.
 /// Each dead node is disposed of by erasing the operation that owns its
 /// callable region.
404  void eraseDeadCallables() {
405  for (CallGraphNode *node : deadNodes)
406  node->getCallableRegion()->getParentOp()->erase();
407  }
408 
409  /// The set of callables known to be dead.
411 
412  /// The current set of call operations to consider for inlining.
414 
415  /// The callgraph being operated on.
416  CallGraph &cg;
417 
418  /// A symbol table to use when resolving call lookups.
419  SymbolTableCollection &symbolTable;
420 };
421 } // namespace
422 
423 namespace mlir {
424 
426 public:
 /// Create an implementation that uses the configuration, pass, and analysis
 /// manager of `inliner`.
427  Impl(Inliner &inliner) : inliner(inliner) {}
428 
429  /// Attempt to inline calls within the given SCC, and run simplifications,
430  /// until a fixed point is reached. This allows for the inlining of newly
431  /// devirtualized calls. Returns failure if there was a fatal error during
432  /// inlining. At most `getMaxInliningIterations()` rounds are performed.
433  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
434  CGUseList &useList, CallGraphSCC &currentSCC,
435  MLIRContext *context);
436 
437 private:
438  /// Optimize the nodes within the given SCC with one of the held optimization
439  /// pass pipelines. Returns failure if an error occurred during the
440  /// optimization of the SCC, success otherwise.
441  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
442  CallGraphSCC &currentSCC, MLIRContext *context);
443 
444  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
445  /// error occurred during the optimization of the SCC, success otherwise.
446  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
447  MLIRContext *context);
448 
449  /// Optimize the given callable node with one of the pass managers provided
450  /// with `pipelines`, or the configured default pipeline. Returns failure if
451  /// an error occurred during the optimization of the callable, success
452  /// otherwise.
453  LogicalResult optimizeCallable(CallGraphNode *node,
454  llvm::StringMap<OpPassManager> &pipelines);
455 
456  /// Attempt to inline calls within the given SCC. This function returns
457  /// success if any calls were inlined, failure otherwise.
458  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
459  CGUseList &useList, CallGraphSCC &currentSCC);
460 
461  /// Returns true if the given call should be inlined.
462  bool shouldInline(ResolvedCall &resolvedCall);
463 
464 private:
 /// The inliner whose configuration, pass, pipeline helper, and analysis
 /// manager this implementation uses.
465  Inliner &inliner;
467 };
468 
469 LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
470  CGUseList &useList,
471  CallGraphSCC &currentSCC,
472  MLIRContext *context) {
473  // Continuously simplify and inline until we either reach a fixed point, or
474  // hit the maximum iteration count. Simplifying early helps to refine the cost
475  // model, and in future iterations may devirtualize new calls.
476  unsigned iterationCount = 0;
477  do {
478  if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
479  return failure();
480  if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
481  break;
482  } while (++iterationCount < inliner.config.getMaxInliningIterations());
483  return success();
484 }
485 
/// Simplify the eligible nodes of `currentSCC` in parallel and then refresh
/// their recorded uses. External nodes, nodes with children, and nodes
/// rejected by the region check below are skipped.
486 LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
487  CallGraphSCC &currentSCC,
488  MLIRContext *context) {
489  // Collect the set of nodes to simplify.
490  SmallVector<CallGraphNode *, 4> nodesToVisit;
491  for (auto *node : currentSCC) {
492  if (node->isExternal())
493  continue;
494 
495  // Don't simplify nodes with children. Nodes with children require special
496  // handling as we may remove the node during simplification. In the future,
497  // we should be able to handle this case with proper node deletion tracking.
498  if (node->hasChildren())
499  continue;
500 
501  // We also won't apply simplifications to nodes that can't have passes
502  // scheduled on them.
503  auto *region = node->getCallableRegion();
505  continue;
506  nodesToVisit.push_back(node);
507  }
508  if (nodesToVisit.empty())
509  return success();
510 
511  // Optimize each of the nodes within the SCC in parallel.
512  if (failed(optimizeSCCAsync(nodesToVisit, context)))
513  return failure();
514 
515  // Recompute the uses held by each of the nodes, since simplification may
 // have changed the symbol references they contain.
516  for (CallGraphNode *node : nodesToVisit)
517  useList.recomputeUses(node, cg);
518  return success();
519 }
520 
521 LogicalResult
522 Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
523  MLIRContext *ctx) {
524  // We must maintain a fixed pool of pass managers which is at least as large
525  // as the maximum parallelism of the failableParallelForEach below.
526  // Note: The number of pass managers here needs to remain constant
527  // to prevent issues with pass instrumentations that rely on having the same
528  // pass manager for the main thread.
529  size_t numThreads = ctx->getNumThreads();
530  const auto &opPipelines = inliner.config.getOpPipelines();
531  if (pipelines.size() < numThreads) {
532  pipelines.reserve(numThreads);
533  pipelines.resize(numThreads, opPipelines);
534  }
535 
536  // Ensure an analysis manager has been constructed for each of the nodes.
537  // This prevents thread races when running the nested pipelines.
538  for (CallGraphNode *node : nodesToVisit)
539  inliner.am.nest(node->getCallableRegion()->getParentOp());
540 
541  // An atomic failure variable for the async executors.
542  std::vector<std::atomic<bool>> activePMs(pipelines.size());
543  std::fill(activePMs.begin(), activePMs.end(), false);
544  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
545  // Find a pass manager for this operation.
546  auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
547  bool expectedInactive = false;
548  return isActive.compare_exchange_strong(expectedInactive, true);
549  });
550  assert(it != activePMs.end() &&
551  "could not find inactive pass manager for thread");
552  unsigned pmIndex = it - activePMs.begin();
553 
554  // Optimize this callable node.
555  LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
556 
557  // Reset the active bit for this pass manager.
558  activePMs[pmIndex].store(false);
559  return result;
560  });
561 }
562 
563 LogicalResult
564 Inliner::Impl::optimizeCallable(CallGraphNode *node,
565  llvm::StringMap<OpPassManager> &pipelines) {
566  Operation *callable = node->getCallableRegion()->getParentOp();
567  StringRef opName = callable->getName().getStringRef();
568  auto pipelineIt = pipelines.find(opName);
569  const auto &defaultPipeline = inliner.config.getDefaultPipeline();
570  if (pipelineIt == pipelines.end()) {
571  // If a pipeline didn't exist, use the generic pipeline if possible.
572  if (!defaultPipeline)
573  return success();
574 
575  OpPassManager defaultPM(opName);
576  defaultPipeline(defaultPM);
577  pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
578  }
579  return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
580 }
581 
582 /// Attempt to inline calls within the given scc. This function returns
583 /// success if any calls were inlined, failure otherwise.
584 LogicalResult
585 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
586  CGUseList &useList, CallGraphSCC &currentSCC) {
587  CallGraph &cg = inlinerIface.cg;
588  auto &calls = inlinerIface.calls;
589 
590  // A set of dead nodes to remove after inlining.
591  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
592 
593  // Collect all of the direct calls within the nodes of the current SCC. We
594  // don't traverse nested callgraph nodes, because they are handled separately
595  // likely within a different SCC.
596  for (CallGraphNode *node : currentSCC) {
597  if (node->isExternal())
598  continue;
599 
600  // Don't collect calls if the node is already dead.
601  if (useList.isDead(node)) {
602  deadNodes.insert(node);
603  } else {
604  collectCallOps(*node->getCallableRegion(), node, cg,
605  inlinerIface.symbolTable, calls,
606  /*traverseNestedCGNodes=*/false);
607  }
608  }
609 
610  // When inlining a callee produces new call sites, we want to keep track of
611  // the fact that they were inlined from the callee. This allows us to avoid
612  // infinite inlining.
613  using InlineHistoryT = std::optional<size_t>;
615  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
616 
617  LLVM_DEBUG({
618  llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
619  for (unsigned i = 0, e = calls.size(); i < e; ++i)
620  llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
621  llvm::dbgs() << "}\n";
622  });
623 
624  // Try to inline each of the call operations. Don't cache the end iterator
625  // here as more calls may be added during inlining.
626  bool inlinedAnyCalls = false;
627  for (unsigned i = 0; i < calls.size(); ++i) {
628  if (deadNodes.contains(calls[i].sourceNode))
629  continue;
630  ResolvedCall it = calls[i];
631 
632  InlineHistoryT inlineHistoryID = callHistory[i];
633  bool inHistory =
634  inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
635  bool doInline = !inHistory && shouldInline(it);
636  CallOpInterface call = it.call;
637  LLVM_DEBUG({
638  if (doInline)
639  llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
640  else
641  llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
642  });
643  if (!doInline)
644  continue;
645 
646  unsigned prevSize = calls.size();
647  Region *targetRegion = it.targetNode->getCallableRegion();
648 
649  // If this is the last call to the target node and the node is discardable,
650  // then inline it in-place and delete the node if successful.
651  bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
652 
653  LogicalResult inlineResult =
654  inlineCall(inlinerIface, call,
655  cast<CallableOpInterface>(targetRegion->getParentOp()),
656  targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
657  if (failed(inlineResult)) {
658  LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
659  continue;
660  }
661  inlinedAnyCalls = true;
662 
663  // Create a inline history entry for this inlined call, so that we remember
664  // that new callsites came about due to inlining Callee.
665  InlineHistoryT newInlineHistoryID{inlineHistory.size()};
666  inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
667 
668  auto historyToString = [](InlineHistoryT h) {
669  return h.has_value() ? std::to_string(*h) : "root";
670  };
671  (void)historyToString;
672  LLVM_DEBUG(llvm::dbgs()
673  << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
674  << getNodeName(call) << ", " << historyToString(inlineHistoryID)
675  << "]\n");
676 
677  for (unsigned k = prevSize; k != calls.size(); ++k) {
678  callHistory.push_back(newInlineHistoryID);
679  LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
680  << "}\n with historyID = " << newInlineHistoryID
681  << ", added due to inlining of\n call {" << call
682  << "}\n with historyID = "
683  << historyToString(inlineHistoryID) << "\n");
684  }
685 
686  // If the inlining was successful, Merge the new uses into the source node.
687  useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
688  useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
689 
690  // then erase the call.
691  call.erase();
692 
693  // If we inlined in place, mark the node for deletion.
694  if (inlineInPlace) {
695  useList.eraseNode(it.targetNode);
696  deadNodes.insert(it.targetNode);
697  }
698  }
699 
700  for (CallGraphNode *node : deadNodes) {
701  currentSCC.remove(node);
702  inlinerIface.markForDeletion(node);
703  }
704  calls.clear();
705  return success(inlinedAnyCalls);
706 }
707 
708 /// Returns true if the given call should be inlined.
709 bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
710  // Don't allow inlining terminator calls. We currently don't support this
711  // case.
712  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
713  return false;
714 
715  // Don't allow inlining if the target is a self-recursive function.
716  if (llvm::count_if(*resolvedCall.targetNode,
717  [&](CallGraphNode::Edge const &edge) -> bool {
718  return edge.getTarget() == resolvedCall.targetNode;
719  }) > 0)
720  return false;
721 
722  // Don't allow inlining if the target is an ancestor of the call. This
723  // prevents inlining recursively.
724  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
725  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
726  return false;
727 
728  // Don't allow inlining if the callee has multiple blocks (unstructured
729  // control flow) but we cannot be sure that the caller region supports that.
730  bool calleeHasMultipleBlocks =
731  llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
732  // If both parent ops have the same type, it is safe to inline. Otherwise,
733  // decide based on whether the op has the SingleBlock trait or not.
734  // Note: This check does currently not account for SizedRegion/MaxSizedRegion.
735  auto callerRegionSupportsMultipleBlocks = [&]() {
736  return callableRegion->getParentOp()->getName() ==
737  resolvedCall.call->getParentOp()->getName() ||
738  !resolvedCall.call->getParentOp()
739  ->mightHaveTrait<OpTrait::SingleBlock>();
740  };
741  if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
742  return false;
743 
744  if (!inliner.isProfitableToInline(resolvedCall))
745  return false;
746 
747  // Otherwise, inline.
748  return true;
749 }
750 
751 LogicalResult Inliner::doInlining() {
752  Impl impl(*this);
753  auto *context = op->getContext();
754  // Run the inline transform in post-order over the SCCs in the callgraph.
755  SymbolTableCollection symbolTable;
756  // FIXME: some clean-up can be done for the arguments
757  // of the Impl's methods, if the inlinerIface and useList
758  // become the states of the Impl.
759  InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
760  CGUseList useList(op, cg, symbolTable);
761  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
762  return impl.inlineSCC(inlinerIface, useList, scc, context);
763  });
764  if (failed(result))
765  return result;
766 
767  // After inlining, make sure to erase any callables proven to be dead.
768  inlinerIface.eraseDeadCallables();
769  return success();
770 }
771 } // namespace mlir
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the call operations within the given range of blocks.
Definition: Inliner.cpp:304
Inliner::ResolvedCall ResolvedCall
Definition: Inliner.cpp:32
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
Definition: Inliner.cpp:39
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
Definition: Inliner.cpp:362
static std::string getNodeName(CallOpInterface op)
Definition: Inliner.cpp:353
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Definition: Inliner.cpp:285
Block represents an ordered list of Operations.
Definition: Block.h:33
This class represents a directed edge between two nodes in the callgraph.
Definition: CallGraph.h:43
This class represents a single callable in the callgraph.
Definition: CallGraph.h:40
bool isExternal() const
Returns true if this node is an external node.
Definition: CallGraph.cpp:32
bool hasChildren() const
Returns true if this node has any child edges.
Definition: CallGraph.cpp:59
Region * getCallableRegion() const
Returns the callable region this node represents.
Definition: CallGraph.cpp:36
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
Definition: CallGraph.cpp:147
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
Definition: CallGraph.cpp:139
unsigned getMaxInliningIterations() const
Definition: Inliner.h:41
This interface provides the hooks into the inlining interface.
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context)
Attempt to inline calls within the given SCC, and run simplifications, until a fixed point is reached. This allows for the inlining of newly devirtualized calls.
Definition: Inliner.cpp:469
Impl(Inliner &inliner)
Definition: Inliner.cpp:427
This is an implementation of the inliner that operates bottom up over the Strongly Connected Components (SCCs) of the CallGraph.
Definition: Inliner.h:69
LogicalResult doInlining()
Perform inlining on an OpTrait::SymbolTable operation.
Definition: Inliner.cpp:751
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition: PassManager.h:47
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:745
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
iterator begin()
Definition: Region.h:55
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
This class represents a specific symbol use.
Definition: SymbolTable.h:183
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation 'from'.
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:36
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
Definition: Inliner.h:75
CallGraphNode * sourceNode
Definition: Inliner.h:80
CallOpInterface call
Definition: Inliner.h:79
CallGraphNode * targetNode
Definition: Inliner.h:80
This class provides APIs and verifiers for ops with regions having a single block.
Definition: OpDefinition.h:871