Inliner.cpp
1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a basic inlining algorithm that operates bottom up over
10 // the Strongly Connected Components (SCCs) of the CallGraph. This enables a
11 // more incremental propagation of inlining decisions from the leaves to the
12 // roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
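//
// Illustrative example (hypothetical callgraph): if `a` calls `b` and `b`
// calls `c`, the SCCs are visited in the order {c}, {b}, {a}, so decisions
// made while inlining `c` into `b` are already visible when `b` is considered
// for inlining into `a`. Mutually recursive functions form a single SCC and
// are processed together.
//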
15 
16 #include "mlir/Transforms/Passes.h"
17 
18 #include "mlir/Analysis/CallGraph.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Interfaces/CallInterfaces.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Support/DebugStringHelper.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/SCCIterator.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_INLINER
30 #include "mlir/Transforms/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "inlining"
34 
35 using namespace mlir;
36 
37 /// This function implements the default inliner optimization pipeline.
38 static void defaultInlinerOptPipeline(OpPassManager &pm) {
39   pm.addPass(createCanonicalizerPass());
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Symbol Use Tracking
44 //===----------------------------------------------------------------------===//
45 
46 /// Walk all of the symbol callgraph nodes referenced by the given op.
47 static void walkReferencedSymbolNodes(
48     Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
49     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
50     function_ref<void(CallGraphNode *, Operation *)> callback) {
51  auto symbolUses = SymbolTable::getSymbolUses(op);
52  assert(symbolUses && "expected uses to be valid");
53 
54  Operation *symbolTableOp = op->getParentOp();
55  for (const SymbolTable::SymbolUse &use : *symbolUses) {
56  auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
57  CallGraphNode *&node = refIt.first->second;
58 
59  // If this is the first instance of this reference, try to resolve a
60  // callgraph node for it.
61  if (refIt.second) {
62  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
63  use.getSymbolRef());
64  auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
65  if (!callableOp)
66  continue;
67  node = cg.lookupNode(callableOp.getCallableRegion());
68  }
69  if (node)
70  callback(node, use.getUser());
71  }
72 }
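
// Illustrative note (hypothetical IR): a symbol use is any attribute reference
// to a symbol, e.g. `func.call @foo() : () -> ()` references `@foo` through a
// FlatSymbolRefAttr. The walk above maps each such reference back to the
// callgraph node of the callable it resolves to, if one exists.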
73 
74 //===----------------------------------------------------------------------===//
75 // CGUseList
76 
77 namespace {
78 /// This struct tracks the uses of callgraph nodes that can be dropped when
79 /// they become use-empty. It directly tracks and manages a use-list for all of
80 /// the callgraph nodes. This is necessary because many callgraph nodes are
81 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
82 /// class.
83 struct CGUseList {
84  /// This struct tracks the uses of callgraph nodes within a specific
85  /// operation.
86  struct CGUser {
87  /// Any nodes referenced in the top-level attribute list of this user. We
88  /// use a set here because the number of references does not matter.
89  DenseSet<CallGraphNode *> topLevelUses;
90 
91  /// Uses of nodes referenced by nested operations.
92  DenseMap<CallGraphNode *, int> innerUses;
93  };
94 
95  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
96 
97  /// Drop uses of nodes referred to by the given call operation that resides
98  /// within 'userNode'.
99  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
100 
101  /// Remove the given node from the use list.
102  void eraseNode(CallGraphNode *node);
103 
104  /// Returns true if the given callgraph node has no uses and can be pruned.
105  bool isDead(CallGraphNode *node) const;
106 
107  /// Returns true if the given callgraph node has a single use and can be
108  /// discarded.
109  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
110 
111  /// Recompute the uses held by the given callgraph node.
112  void recomputeUses(CallGraphNode *node, CallGraph &cg);
113 
114  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
115  /// of 'lhs' into 'rhs'.
116  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
117 
118 private:
119  /// Decrement the uses of discardable nodes referenced by the given user.
120  void decrementDiscardableUses(CGUser &uses);
121 
122  /// A mapping between a discardable callgraph node (that is a symbol) and the
123  /// number of uses for this node.
124  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
125 
126  /// A mapping between a callgraph node and the symbol callgraph nodes that it
127  /// uses.
128  DenseMap<CallGraphNode *, CGUser> nodeUses;
129 
130  /// A symbol table to use when resolving call lookups.
131  SymbolTableCollection &symbolTable;
132 };
133 } // namespace
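
// Illustrative note (hypothetical scenario): a private, discardable callee
// (e.g. a private `func.func`) referenced by exactly one call has a tracked
// use count of 1, so hasOneUseAndDiscardable() permits inlining it in place;
// once that use disappears the count drops to 0, isDead() returns true, and
// the unused symbol can be erased.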
134 
135 CGUseList::CGUseList(Operation *op, CallGraph &cg,
136  SymbolTableCollection &symbolTable)
137  : symbolTable(symbolTable) {
138  /// A set of callgraph nodes that are always known to be live during inlining.
139  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
140 
141  // Walk each of the symbol tables looking for discardable callgraph nodes.
142  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
143  for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
144  // If this is a callgraph operation, check to see if it is discardable.
145  if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
146  if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
147  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
148  if (symbol && (allUsesVisible || symbol.isPrivate()) &&
149  symbol.canDiscardOnUseEmpty()) {
150  discardableSymNodeUses.try_emplace(node, 0);
151  }
152  continue;
153  }
154  }
155  // Otherwise, check for any referenced nodes. These will be always-live.
156  walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
157  [](CallGraphNode *, Operation *) {});
158  }
159  };
160  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
161  walkFn);
162 
163  // Drop the use information for any discardable nodes that are always live.
164  for (auto &it : alwaysLiveNodes)
165  discardableSymNodeUses.erase(it.second);
166 
167  // Compute the uses for each of the callable nodes in the graph.
168  for (CallGraphNode *node : cg)
169  recomputeUses(node, cg);
170 }
171 
172 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
173  CallGraph &cg) {
174  auto &userRefs = nodeUses[userNode].innerUses;
175  auto walkFn = [&](CallGraphNode *node, Operation *user) {
176  auto parentIt = userRefs.find(node);
177  if (parentIt == userRefs.end())
178  return;
179  --parentIt->second;
180  --discardableSymNodeUses[node];
181  };
182  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
183  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
184 }
185 
186 void CGUseList::eraseNode(CallGraphNode *node) {
187  // Drop all child nodes.
188  for (auto &edge : *node)
189  if (edge.isChild())
190  eraseNode(edge.getTarget());
191 
192  // Drop the uses held by this node and erase it.
193  auto useIt = nodeUses.find(node);
194  assert(useIt != nodeUses.end() && "expected node to be valid");
195  decrementDiscardableUses(useIt->getSecond());
196  nodeUses.erase(useIt);
197  discardableSymNodeUses.erase(node);
198 }
199 
200 bool CGUseList::isDead(CallGraphNode *node) const {
201  // If the parent operation isn't a symbol, simply check normal SSA deadness.
202  Operation *nodeOp = node->getCallableRegion()->getParentOp();
203  if (!isa<SymbolOpInterface>(nodeOp))
204  return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
205 
206  // Otherwise, check the number of symbol uses.
207  auto symbolIt = discardableSymNodeUses.find(node);
208  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
209 }
210 
211 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
212  // If this isn't a symbol node, check for side-effects and SSA use count.
213  Operation *nodeOp = node->getCallableRegion()->getParentOp();
214  if (!isa<SymbolOpInterface>(nodeOp))
215  return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
216 
217  // Otherwise, check the number of symbol uses.
218  auto symbolIt = discardableSymNodeUses.find(node);
219  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
220 }
221 
222 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
223  Operation *parentOp = node->getCallableRegion()->getParentOp();
224  CGUser &uses = nodeUses[node];
225  decrementDiscardableUses(uses);
226 
227  // Collect the new discardable uses within this node.
228  uses = CGUser();
229  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
230  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
231  auto discardSymIt = discardableSymNodeUses.find(refNode);
232  if (discardSymIt == discardableSymNodeUses.end())
233  return;
234 
235  if (user != parentOp)
236  ++uses.innerUses[refNode];
237  else if (!uses.topLevelUses.insert(refNode).second)
238  return;
239  ++discardSymIt->second;
240  };
241  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
242 }
243 
244 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
245  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
246  for (auto &useIt : lhsUses.innerUses) {
247  rhsUses.innerUses[useIt.first] += useIt.second;
248  discardableSymNodeUses[useIt.first] += useIt.second;
249  }
250 }
251 
252 void CGUseList::decrementDiscardableUses(CGUser &uses) {
253  for (CallGraphNode *node : uses.topLevelUses)
254  --discardableSymNodeUses[node];
255  for (auto &it : uses.innerUses)
256  discardableSymNodeUses[it.first] -= it.second;
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // CallGraph traversal
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 /// This class represents a specific callgraph SCC.
265 class CallGraphSCC {
266 public:
267  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
268  : parentIterator(parentIterator) {}
269  /// Return a range over the nodes within this SCC.
270  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
271  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
272 
273  /// Reset the nodes of this SCC with those provided.
274  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
275 
276  /// Remove the given node from this SCC.
277  void remove(CallGraphNode *node) {
278  auto it = llvm::find(nodes, node);
279  if (it != nodes.end()) {
280  nodes.erase(it);
281  parentIterator.ReplaceNode(node, nullptr);
282  }
283  }
284 
285 private:
286  std::vector<CallGraphNode *> nodes;
287  llvm::scc_iterator<const CallGraph *> &parentIterator;
288 };
289 } // namespace
290 
291 /// Run a given transformation over the SCCs of the callgraph in a bottom up
292 /// traversal.
293 static LogicalResult runTransformOnCGSCCs(
294     const CallGraph &cg,
295  function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
296  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
297  CallGraphSCC currentSCC(cgi);
298  while (!cgi.isAtEnd()) {
299  // Copy the current SCC and increment so that the transformer can modify the
300  // SCC without invalidating our iterator.
301  currentSCC.reset(*cgi);
302  ++cgi;
303  if (failed(sccTransformer(currentSCC)))
304  return failure();
305  }
306  return success();
307 }
308 
309 namespace {
310 /// This struct represents a resolved call to a given callgraph node. Given that
311 /// the call does not actually contain a direct reference to the
312 /// Region (CallGraphNode) that it is dispatching to, we need to resolve it
313 /// explicitly.
314 struct ResolvedCall {
315  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
316  CallGraphNode *targetNode)
317  : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
318  CallOpInterface call;
319  CallGraphNode *sourceNode, *targetNode;
320 };
321 } // namespace
322 
323 /// Collect all of the call operations within the given range of blocks. If
324 /// `traverseNestedCGNodes` is true, this will also collect call operations
325 /// inside of nested callgraph nodes.
326 static void collectCallOps(iterator_range<Region::iterator> blocks,
327                            CallGraphNode *sourceNode, CallGraph &cg,
328                            SymbolTableCollection &symbolTable,
329                            SmallVectorImpl<ResolvedCall> &calls,
330                            bool traverseNestedCGNodes) {
331   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
332   auto addToWorklist = [&](CallGraphNode *node,
333                            iterator_range<Region::iterator> blocks) {
334  for (Block &block : blocks)
335  worklist.emplace_back(&block, node);
336  };
337 
338  addToWorklist(sourceNode, blocks);
339  while (!worklist.empty()) {
340  Block *block;
341  std::tie(block, sourceNode) = worklist.pop_back_val();
342 
343  for (Operation &op : *block) {
344  if (auto call = dyn_cast<CallOpInterface>(op)) {
345  // TODO: Support inlining nested call references.
346  CallInterfaceCallable callable = call.getCallableForCallee();
347  if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
348  if (!isa<FlatSymbolRefAttr>(symRef))
349  continue;
350  }
351 
352  CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
353  if (!targetNode->isExternal())
354  calls.emplace_back(call, sourceNode, targetNode);
355  continue;
356  }
357 
358  // If this is not a call, traverse the nested regions. If
359  // `traverseNestedCGNodes` is false, then don't traverse nested call graph
360  // regions.
361  for (auto &nestedRegion : op.getRegions()) {
362  CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
363  if (traverseNestedCGNodes || !nestedNode)
364  addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
365  }
366  }
367  }
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // Inliner
372 //===----------------------------------------------------------------------===//
373 
374 #ifndef NDEBUG
375 static std::string getNodeName(CallOpInterface op) {
376  if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
377  return debugString(op);
378  return "_unnamed_callee_";
379 }
380 #endif
381 
382 /// Return true if the specified `inlineHistoryID` indicates an inline history
383 /// that already includes `node`.
384 static bool inlineHistoryIncludes(
385     CallGraphNode *node, std::optional<size_t> inlineHistoryID,
386  MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
387  inlineHistory) {
388  while (inlineHistoryID.has_value()) {
389  assert(*inlineHistoryID < inlineHistory.size() &&
390  "Invalid inline history ID");
391  if (inlineHistory[*inlineHistoryID].first == node)
392  return true;
393  inlineHistoryID = inlineHistory[*inlineHistoryID].second;
394  }
395  return false;
396 }
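
// Illustrative note: the inline history is a chain of (node, parent entry)
// pairs. If inlining a call exposes a new call whose target already appears in
// that chain (e.g. a directly or mutually recursive callee), the check above
// returns true and the new call is not expanded further.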
397 
398 namespace {
399 /// This class provides a specialization of the main inlining interface.
400 struct Inliner : public InlinerInterface {
401  Inliner(MLIRContext *context, CallGraph &cg,
402  SymbolTableCollection &symbolTable)
403  : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
404 
405  /// Process a set of blocks that have been inlined. This callback is invoked
406  /// *before* inlined terminator operations have been processed.
407  void
408  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
409  // Find the closest callgraph node from the first block.
410  CallGraphNode *node;
411  Region *region = inlinedBlocks.begin()->getParent();
412  while (!(node = cg.lookupNode(region))) {
413  region = region->getParentRegion();
414  assert(region && "expected valid parent node");
415  }
416 
417  collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
418  /*traverseNestedCGNodes=*/true);
419  }
420 
421  /// Mark the given callgraph node for deletion.
422  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
423 
424  /// This method properly disposes of callables that became dead during
425  /// inlining. This should not be called while iterating over the SCCs.
426  void eraseDeadCallables() {
427  for (CallGraphNode *node : deadNodes)
428  node->getCallableRegion()->getParentOp()->erase();
429  }
430 
431  /// The set of callables known to be dead.
432  SmallPtrSet<CallGraphNode *, 8> deadNodes;
433 
434  /// The current set of call instructions to consider for inlining.
435  SmallVector<ResolvedCall, 8> calls;
436 
437  /// The callgraph being operated on.
438  CallGraph &cg;
439 
440  /// A symbol table to use when resolving call lookups.
441  SymbolTableCollection &symbolTable;
442 };
443 } // namespace
444 
445 /// Returns true if the given call should be inlined.
446 static bool shouldInline(ResolvedCall &resolvedCall) {
447  // Don't allow inlining terminator calls. We currently don't support this
448  // case.
449  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
450  return false;
451 
452  // Don't allow inlining if the target is an ancestor of the call. This
453  // prevents inlining recursively.
454  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
455  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
456  return false;
457 
458  // Don't allow inlining if the callee has multiple blocks (unstructured
459  // control flow) but we cannot be sure that the caller region supports that.
460  bool calleeHasMultipleBlocks =
461  llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
462  // If both parent ops have the same type, it is safe to inline. Otherwise,
463  // decide based on whether the op has the SingleBlock trait or not.
464  // Note: This check does not currently account for SizedRegion/MaxSizedRegion.
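  // Illustrative example (hypothetical ops): a callee region containing several
  // blocks joined by `cf.br`/`cf.cond_br` is not inlined into a call site whose
  // parent op is of a different kind and might only permit a single block in
  // its regions.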
465  auto callerRegionSupportsMultipleBlocks = [&]() {
466  return callableRegion->getParentOp()->getName() ==
467  resolvedCall.call->getParentOp()->getName() ||
468  !resolvedCall.call->getParentOp()
469  ->mightHaveTrait<OpTrait::SingleBlock>();
470  };
471  if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
472  return false;
473 
474  // Otherwise, inline.
475  return true;
476 }
477 
478 /// Attempt to inline calls within the given SCC. This function returns
479 /// success if any calls were inlined, failure otherwise.
480 static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
481  CallGraphSCC &currentSCC) {
482  CallGraph &cg = inliner.cg;
483  auto &calls = inliner.calls;
484 
485  // A set of dead nodes to remove after inlining.
486  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
487 
488  // Collect all of the direct calls within the nodes of the current SCC. We
489  // don't traverse nested callgraph nodes because they are handled separately,
490  // likely within a different SCC.
491  for (CallGraphNode *node : currentSCC) {
492  if (node->isExternal())
493  continue;
494 
495  // Don't collect calls if the node is already dead.
496  if (useList.isDead(node)) {
497  deadNodes.insert(node);
498  } else {
499  collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
500  calls, /*traverseNestedCGNodes=*/false);
501  }
502  }
503 
504  // When inlining a callee produces new call sites, we want to keep track of
505  // the fact that they were inlined from the callee. This allows us to avoid
506  // infinite inlining.
507  using InlineHistoryT = std::optional<size_t>;
508  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
509  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
510 
511  LLVM_DEBUG({
512  llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
513  for (unsigned i = 0, e = calls.size(); i < e; ++i)
514  llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
515  llvm::dbgs() << "}\n";
516  });
517 
518  // Try to inline each of the call operations. Don't cache the end iterator
519  // here as more calls may be added during inlining.
520  bool inlinedAnyCalls = false;
521  for (unsigned i = 0; i < calls.size(); ++i) {
522  if (deadNodes.contains(calls[i].sourceNode))
523  continue;
524  ResolvedCall it = calls[i];
525 
526  InlineHistoryT inlineHistoryID = callHistory[i];
527  bool inHistory =
528  inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
529  bool doInline = !inHistory && shouldInline(it);
530  CallOpInterface call = it.call;
531  LLVM_DEBUG({
532  if (doInline)
533  llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
534  else
535  llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
536  });
537  if (!doInline)
538  continue;
539 
540  unsigned prevSize = calls.size();
541  Region *targetRegion = it.targetNode->getCallableRegion();
542 
543  // If this is the last call to the target node and the node is discardable,
544  // then inline it in-place and delete the node if successful.
545  bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
546 
547  LogicalResult inlineResult = inlineCall(
548  inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
549  targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
550  if (failed(inlineResult)) {
551  LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
552  continue;
553  }
554  inlinedAnyCalls = true;
555 
556  // Create an inline history entry for this inlined call, so that we remember
557  // that new call sites came about due to inlining the callee.
558  InlineHistoryT newInlineHistoryID{inlineHistory.size()};
559  inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
560 
561  auto historyToString = [](InlineHistoryT h) {
562  return h.has_value() ? std::to_string(*h) : "root";
563  };
564  (void)historyToString;
565  LLVM_DEBUG(llvm::dbgs()
566  << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
567  << getNodeName(call) << ", " << historyToString(inlineHistoryID)
568  << "]\n");
569 
570  for (unsigned k = prevSize; k != calls.size(); ++k) {
571  callHistory.push_back(newInlineHistoryID);
572  LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
573  << "}\n with historyID = " << newInlineHistoryID
574  << ", added due to inlining of\n call {" << call
575  << "}\n with historyID = "
576  << historyToString(inlineHistoryID) << "\n");
577  }
578 
579  // If the inlining was successful, merge the new uses into the source node.
580  useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
581  useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
582 
583  // Then erase the call.
584  call.erase();
585 
586  // If we inlined in place, mark the node for deletion.
587  if (inlineInPlace) {
588  useList.eraseNode(it.targetNode);
589  deadNodes.insert(it.targetNode);
590  }
591  }
592 
593  for (CallGraphNode *node : deadNodes) {
594  currentSCC.remove(node);
595  inliner.markForDeletion(node);
596  }
597  calls.clear();
598  return success(inlinedAnyCalls);
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // InlinerPass
603 //===----------------------------------------------------------------------===//
604 
605 namespace {
606 class InlinerPass : public impl::InlinerBase<InlinerPass> {
607 public:
608  InlinerPass();
609  InlinerPass(const InlinerPass &) = default;
610  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
611  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
612  llvm::StringMap<OpPassManager> opPipelines);
613  void runOnOperation() override;
614 
615 private:
616  /// Attempt to inline calls within the given SCC, and run simplifications,
617  /// until a fixed point is reached. This allows for the inlining of newly
618  /// devirtualized calls. Returns failure if there was a fatal error during
619  /// inlining.
620  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
621  CallGraphSCC &currentSCC, MLIRContext *context);
622 
623  /// Optimize the nodes within the given SCC with one of the held optimization
624  /// pass pipelines. Returns failure if an error occurred during the
625  /// optimization of the SCC, success otherwise.
626  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
627  CallGraphSCC &currentSCC, MLIRContext *context);
628 
629  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
630  /// error occurred during the optimization of the SCC, success otherwise.
631  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
632  MLIRContext *context);
633 
634  /// Optimize the given callable node with one of the pass managers provided
635  /// with `pipelines`, or the default pipeline. Returns failure if an error
636  /// occurred during the optimization of the callable, success otherwise.
637  LogicalResult optimizeCallable(CallGraphNode *node,
638  llvm::StringMap<OpPassManager> &pipelines);
639 
640  /// Attempt to initialize the options of this pass from the given string.
641  /// Derived classes may override this method to hook into the point at which
642  /// options are initialized, but should generally always invoke this base
643  /// class variant.
644  LogicalResult initializeOptions(StringRef options) override;
645 
646  /// An optional function that constructs a default optimization pipeline for
647  /// a given operation.
648  std::function<void(OpPassManager &)> defaultPipeline;
649  /// A map of operation names to pass pipelines to use when optimizing
650  /// callable operations of these types. This provides a specialized pipeline
651  /// instead of the default. The vector size is the number of threads used
652  /// during optimization.
653  SmallVector<llvm::StringMap<OpPassManager>, 1> opPipelines;
654 };
655 } // namespace
656 
657 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
658 InlinerPass::InlinerPass(
659  std::function<void(OpPassManager &)> defaultPipelineArg)
660  : defaultPipeline(std::move(defaultPipelineArg)) {
661  opPipelines.push_back({});
662 }
663 
664 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
665  llvm::StringMap<OpPassManager> opPipelines)
666  : InlinerPass(std::move(defaultPipeline)) {
667  if (opPipelines.empty())
668  return;
669 
670  // Update the option for the op specific optimization pipelines.
671  for (auto &it : opPipelines)
672  opPipelineList.addValue(it.second);
673  this->opPipelines.emplace_back(std::move(opPipelines));
674 }
675 
676 void InlinerPass::runOnOperation() {
677  CallGraph &cg = getAnalysis<CallGraph>();
678  auto *context = &getContext();
679 
680  // The inliner should only be run on operations that define a symbol table,
681  // as the callgraph will need to resolve references.
682  Operation *op = getOperation();
683  if (!op->hasTrait<OpTrait::SymbolTable>()) {
684  op->emitOpError() << " was scheduled to run under the inliner, but does "
685  "not define a symbol table";
686  return signalPassFailure();
687  }
688 
689  // Run the inline transform in post-order over the SCCs in the callgraph.
690  SymbolTableCollection symbolTable;
691  Inliner inliner(context, cg, symbolTable);
692  CGUseList useList(getOperation(), cg, symbolTable);
693  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
694  return inlineSCC(inliner, useList, scc, context);
695  });
696  if (failed(result))
697  return signalPassFailure();
698 
699  // After inlining, make sure to erase any callables proven to be dead.
700  inliner.eraseDeadCallables();
701 }
702 
703 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
704  CallGraphSCC &currentSCC,
705  MLIRContext *context) {
706  // Continuously simplify and inline until we either reach a fixed point, or
707  // hit the maximum iteration count. Simplifying early helps to refine the cost
708  // model, and in future iterations may devirtualize new calls.
709  unsigned iterationCount = 0;
710  do {
711  if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
712  return failure();
713  if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
714  break;
715  } while (++iterationCount < maxInliningIterations);
716  return success();
717 }
718 
719 LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
720  CallGraphSCC &currentSCC,
721  MLIRContext *context) {
722  // Collect the sets of nodes to simplify.
723  SmallVector<CallGraphNode *, 4> nodesToVisit;
724  for (auto *node : currentSCC) {
725  if (node->isExternal())
726  continue;
727 
728  // Don't simplify nodes with children. Nodes with children require special
729  // handling as we may remove the node during simplification. In the future,
730  // we should be able to handle this case with proper node deletion tracking.
731  if (node->hasChildren())
732  continue;
733 
734  // We also won't apply simplifications to nodes that can't have passes
735  // scheduled on them.
736  auto *region = node->getCallableRegion();
737  if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
738  continue;
739  nodesToVisit.push_back(node);
740  }
741  if (nodesToVisit.empty())
742  return success();
743 
744  // Optimize each of the nodes within the SCC in parallel.
745  if (failed(optimizeSCCAsync(nodesToVisit, context)))
746  return failure();
747 
748  // Recompute the uses held by each of the nodes.
749  for (CallGraphNode *node : nodesToVisit)
750  useList.recomputeUses(node, cg);
751  return success();
752 }
753 
754 LogicalResult
755 InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
756  MLIRContext *ctx) {
757  // We must maintain a fixed pool of pass managers which is at least as large
758  // as the maximum parallelism of the failableParallelForEach below.
759  // Note: The number of pass managers here needs to remain constant
760  // to prevent issues with pass instrumentations that rely on having the same
761  // pass manager for the main thread.
762  size_t numThreads = ctx->getNumThreads();
763  if (opPipelines.size() < numThreads) {
764  // Reserve before resizing so that we can use a reference to the first
765  // element.
766  opPipelines.reserve(numThreads);
767  opPipelines.resize(numThreads, opPipelines.front());
768  }
769 
770  // Ensure an analysis manager has been constructed for each of the nodes.
771  // This prevents thread races when running the nested pipelines.
772  for (CallGraphNode *node : nodesToVisit)
773  getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
774 
775  // An atomic failure variable for the async executors.
776  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
777  std::fill(activePMs.begin(), activePMs.end(), false);
778  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
779  // Find a pass manager for this operation.
780  auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
781  bool expectedInactive = false;
782  return isActive.compare_exchange_strong(expectedInactive, true);
783  });
784  assert(it != activePMs.end() &&
785  "could not find inactive pass manager for thread");
786  unsigned pmIndex = it - activePMs.begin();
787 
788  // Optimize this callable node.
789  LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
790 
791  // Reset the active bit for this pass manager.
792  activePMs[pmIndex].store(false);
793  return result;
794  });
795 }
796 
797 LogicalResult
798 InlinerPass::optimizeCallable(CallGraphNode *node,
799  llvm::StringMap<OpPassManager> &pipelines) {
800  Operation *callable = node->getCallableRegion()->getParentOp();
801  StringRef opName = callable->getName().getStringRef();
802  auto pipelineIt = pipelines.find(opName);
803  if (pipelineIt == pipelines.end()) {
804  // If a pipeline didn't exist, use the default if possible.
805  if (!defaultPipeline)
806  return success();
807 
808  OpPassManager defaultPM(opName);
809  defaultPipeline(defaultPM);
810  pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
811  }
812  return runPipeline(pipelineIt->second, callable);
813 }
814 
815 LogicalResult InlinerPass::initializeOptions(StringRef options) {
816  if (failed(Pass::initializeOptions(options)))
817  return failure();
818 
819  // Initialize the default pipeline builder to use the option string.
820  // TODO: Use a generic pass manager for default pipelines, and remove this.
821  if (!defaultPipelineStr.empty()) {
822  std::string defaultPipelineCopy = defaultPipelineStr;
823  defaultPipeline = [=](OpPassManager &pm) {
824  (void)parsePassPipeline(defaultPipelineCopy, pm);
825  };
826  } else if (defaultPipelineStr.getNumOccurrences()) {
827  defaultPipeline = nullptr;
828  }
829 
830  // Initialize the op specific pass pipelines.
831  llvm::StringMap<OpPassManager> pipelines;
832  for (OpPassManager pipeline : opPipelineList)
833  if (!pipeline.empty())
834  pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
835  opPipelines.assign({std::move(pipelines)});
836 
837  return success();
838 }
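
// Illustrative note (option spellings assumed from the generated pass options):
// on the command line this pass is typically configured with something like
//   --pass-pipeline="builtin.module(inline{default-pipeline=canonicalize})"
// which is what populates `defaultPipelineStr` and `opPipelineList` above.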
839 
840 std::unique_ptr<Pass> mlir::createInlinerPass() {
841  return std::make_unique<InlinerPass>();
842 }
843 std::unique_ptr<Pass>
844 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
845  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
846  std::move(opPipelines));
847 }
848 std::unique_ptr<Pass> mlir::createInlinerPass(
849  llvm::StringMap<OpPassManager> opPipelines,
850  std::function<void(OpPassManager &)> defaultPipelineBuilder) {
851  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
852  std::move(opPipelines));
853 }
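
// Illustrative usage sketch (assumes a typical MLIR setup; not taken from this
// file): the inliner is normally scheduled on an operation that defines a
// symbol table, such as a module:
//
//   PassManager pm(&context);
//   pm.addPass(createInlinerPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "inlining failed\n";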