MLIR  16.0.0git
Inliner.cpp
1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a basic inlining algorithm that operates bottom up over
10 // the Strongly Connected Components (SCCs) of the CallGraph. This enables a more
11 // incremental propagation of inlining decisions from the leaves to the roots of
12 // the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
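// For example, given a callgraph where @a calls @b and @b calls @c (with no
// cycles), the SCCs are visited as {@c}, then {@b}, then {@a}: inlining
// decisions inside @c are made before @b is processed, and so on up to the
// roots of the callgraph.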
15 
16 #include "mlir/Transforms/Passes.h"
17 
18 #include "mlir/Analysis/CallGraph.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Interfaces/CallInterfaces.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Support/DebugStringHelper.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/SCCIterator.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_INLINER
30 #include "mlir/Transforms/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "inlining"
34 
35 using namespace mlir;
36 
37 /// This function implements the default inliner optimization pipeline.
38 static void defaultInlinerOptPipeline(OpPassManager &pm) {
39   pm.addPass(createCanonicalizerPass());
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Symbol Use Tracking
44 //===----------------------------------------------------------------------===//
45 
46 /// Walk all of the symbol callgraph nodes referenced by the given op.
47 static void walkReferencedSymbolNodes(
48     Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
49     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
50     function_ref<void(CallGraphNode *, Operation *)> callback) {
51  auto symbolUses = SymbolTable::getSymbolUses(op);
52  assert(symbolUses && "expected uses to be valid");
53 
54  Operation *symbolTableOp = op->getParentOp();
55  for (const SymbolTable::SymbolUse &use : *symbolUses) {
56  auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
57  CallGraphNode *&node = refIt.first->second;
58 
59  // If this is the first instance of this reference, try to resolve a
60  // callgraph node for it.
61  if (refIt.second) {
62  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
63  use.getSymbolRef());
64  auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
65  if (!callableOp)
66  continue;
67  node = cg.lookupNode(callableOp.getCallableRegion());
68  }
69  if (node)
70  callback(node, use.getUser());
71  }
72 }
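// For instance, a `func.call @foo() : () -> ()` operation carries `@foo` as a
// SymbolRefAttr in its attribute list; SymbolTable::getSymbolUses reports that
// use above, the reference is resolved at most once per attribute (cached in
// `resolvedRefs`), and the callback then receives @foo's callgraph node along
// with the using operation.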
73 
74 //===----------------------------------------------------------------------===//
75 // CGUseList
76 //===----------------------------------------------------------------------===//
77 namespace {
78 /// This struct tracks the uses of callgraph nodes that can be dropped when
79 /// use_empty. It directly tracks and manages a use-list for all of the
80 /// call-graph nodes. This is necessary because many callgraph nodes are
81 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
82 /// class.
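/// For example, a private `func.func @helper` is referenced only through the
/// SymbolRefAttrs held by its call sites; once every such call has been
/// inlined and erased, the use count tracked here drops to zero and the
/// callable can be discarded.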
83 struct CGUseList {
84  /// This struct tracks the uses of callgraph nodes within a specific
85  /// operation.
86  struct CGUser {
87  /// Any nodes referenced in the top-level attribute list of this user. We
88  /// use a set here because the number of references does not matter.
89  DenseSet<CallGraphNode *> topLevelUses;
90 
91  /// Uses of nodes referenced by nested operations.
92     DenseMap<CallGraphNode *, int> innerUses;
93   };
94 
95  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
96 
97  /// Drop uses of nodes referred to by the given call operation that resides
98  /// within 'userNode'.
99  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
100 
101  /// Remove the given node from the use list.
102  void eraseNode(CallGraphNode *node);
103 
104  /// Returns true if the given callgraph node has no uses and can be pruned.
105  bool isDead(CallGraphNode *node) const;
106 
107  /// Returns true if the given callgraph node has a single use and can be
108  /// discarded.
109  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
110 
111  /// Recompute the uses held by the given callgraph node.
112  void recomputeUses(CallGraphNode *node, CallGraph &cg);
113 
114  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
115  /// of 'lhs' into 'rhs'.
116  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
117 
118 private:
119  /// Decrement the uses of discardable nodes referenced by the given user.
120  void decrementDiscardableUses(CGUser &uses);
121 
122  /// A mapping between a discardable callgraph node (that is a symbol) and the
123  /// number of uses for this node.
124  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
125 
126  /// A mapping between a callgraph node and the symbol callgraph nodes that it
127  /// uses.
128   DenseMap<CallGraphNode *, CGUser> nodeUses;
129 
130  /// A symbol table to use when resolving call lookups.
131  SymbolTableCollection &symbolTable;
132 };
133 } // namespace
134 
135 CGUseList::CGUseList(Operation *op, CallGraph &cg,
136  SymbolTableCollection &symbolTable)
137  : symbolTable(symbolTable) {
138  /// A set of callgraph nodes that are always known to be live during inlining.
139  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
140 
141  // Walk each of the symbol tables looking for discardable callgraph nodes.
142  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
143  for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
144  // If this is a callgraph operation, check to see if it is discardable.
145  if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
146  if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
147  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
148  if (symbol && (allUsesVisible || symbol.isPrivate()) &&
149  symbol.canDiscardOnUseEmpty()) {
150  discardableSymNodeUses.try_emplace(node, 0);
151  }
152  continue;
153  }
154  }
155  // Otherwise, check for any referenced nodes. These will be always-live.
156  walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
157  [](CallGraphNode *, Operation *) {});
158  }
159  };
160  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
161  walkFn);
162 
163  // Drop the use information for any discardable nodes that are always live.
164  for (auto &it : alwaysLiveNodes)
165  discardableSymNodeUses.erase(it.second);
166 
167  // Compute the uses for each of the callable nodes in the graph.
168  for (CallGraphNode *node : cg)
169  recomputeUses(node, cg);
170 }
171 
172 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
173  CallGraph &cg) {
174  auto &userRefs = nodeUses[userNode].innerUses;
175  auto walkFn = [&](CallGraphNode *node, Operation *user) {
176  auto parentIt = userRefs.find(node);
177  if (parentIt == userRefs.end())
178  return;
179  --parentIt->second;
180  --discardableSymNodeUses[node];
181  };
182   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
183   walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
184 }
185 
186 void CGUseList::eraseNode(CallGraphNode *node) {
187  // Drop all child nodes.
188  for (auto &edge : *node)
189  if (edge.isChild())
190  eraseNode(edge.getTarget());
191 
192  // Drop the uses held by this node and erase it.
193  auto useIt = nodeUses.find(node);
194  assert(useIt != nodeUses.end() && "expected node to be valid");
195  decrementDiscardableUses(useIt->getSecond());
196  nodeUses.erase(useIt);
197  discardableSymNodeUses.erase(node);
198 }
199 
200 bool CGUseList::isDead(CallGraphNode *node) const {
201  // If the parent operation isn't a symbol, simply check normal SSA deadness.
202  Operation *nodeOp = node->getCallableRegion()->getParentOp();
203  if (!isa<SymbolOpInterface>(nodeOp))
204  return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
205 
206  // Otherwise, check the number of symbol uses.
207  auto symbolIt = discardableSymNodeUses.find(node);
208  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
209 }
210 
211 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
212  // If this isn't a symbol node, check for side-effects and SSA use count.
213  Operation *nodeOp = node->getCallableRegion()->getParentOp();
214  if (!isa<SymbolOpInterface>(nodeOp))
215  return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
216 
217  // Otherwise, check the number of symbol uses.
218  auto symbolIt = discardableSymNodeUses.find(node);
219  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
220 }
221 
222 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
223  Operation *parentOp = node->getCallableRegion()->getParentOp();
224  CGUser &uses = nodeUses[node];
225  decrementDiscardableUses(uses);
226 
227  // Collect the new discardable uses within this node.
228  uses = CGUser();
229   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
230   auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
231  auto discardSymIt = discardableSymNodeUses.find(refNode);
232  if (discardSymIt == discardableSymNodeUses.end())
233  return;
234 
235  if (user != parentOp)
236  ++uses.innerUses[refNode];
237  else if (!uses.topLevelUses.insert(refNode).second)
238  return;
239  ++discardSymIt->second;
240  };
241  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
242 }
243 
244 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
245  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
246  for (auto &useIt : lhsUses.innerUses) {
247  rhsUses.innerUses[useIt.first] += useIt.second;
248  discardableSymNodeUses[useIt.first] += useIt.second;
249  }
250 }
251 
252 void CGUseList::decrementDiscardableUses(CGUser &uses) {
253  for (CallGraphNode *node : uses.topLevelUses)
254  --discardableSymNodeUses[node];
255  for (auto &it : uses.innerUses)
256  discardableSymNodeUses[it.first] -= it.second;
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // CallGraph traversal
261 //===----------------------------------------------------------------------===//
262 
263 namespace {
264 /// This class represents a specific callgraph SCC.
265 class CallGraphSCC {
266 public:
267  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
268  : parentIterator(parentIterator) {}
269  /// Return a range over the nodes within this SCC.
270  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
271  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
272 
273  /// Reset the nodes of this SCC with those provided.
274  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
275 
276  /// Remove the given node from this SCC.
277  void remove(CallGraphNode *node) {
278  auto it = llvm::find(nodes, node);
279  if (it != nodes.end()) {
280  nodes.erase(it);
281  parentIterator.ReplaceNode(node, nullptr);
282  }
283  }
284 
285 private:
286  std::vector<CallGraphNode *> nodes;
287  llvm::scc_iterator<const CallGraph *> &parentIterator;
288 };
289 } // namespace
290 
291 /// Run a given transformation over the SCCs of the callgraph in a bottom up
292 /// traversal.
293 static LogicalResult runTransformOnCGSCCs(
294     const CallGraph &cg,
295  function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
296  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
297  CallGraphSCC currentSCC(cgi);
298  while (!cgi.isAtEnd()) {
299  // Copy the current SCC and increment so that the transformer can modify the
300  // SCC without invalidating our iterator.
301  currentSCC.reset(*cgi);
302  ++cgi;
303  if (failed(sccTransformer(currentSCC)))
304  return failure();
305  }
306  return success();
307 }
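// Note that llvm::scc_begin enumerates SCCs in reverse topological order of
// the SCC DAG. Because callgraph edges point from callers to callees, callee
// SCCs are therefore produced before the SCCs of their callers, which is what
// makes this a bottom-up traversal.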
308 
309 namespace {
310 /// This struct represents a resolved call to a given callgraph node. Given that
311 /// the call does not actually contain a direct reference to the
312 /// Region (CallGraphNode) that it is dispatching to, we need to resolve the
313 /// callee explicitly.
314 struct ResolvedCall {
315  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
316  CallGraphNode *targetNode)
317  : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
318  CallOpInterface call;
319  CallGraphNode *sourceNode, *targetNode;
320 };
321 } // namespace
322 
323 /// Collect all of the callable operations within the given range of blocks. If
324 /// `traverseNestedCGNodes` is true, this will also collect call operations
325 /// inside of nested callgraph nodes.
326 static void collectCallOps(iterator_range<Region::iterator> blocks,
327                            CallGraphNode *sourceNode, CallGraph &cg,
328                            SymbolTableCollection &symbolTable,
329                            SmallVectorImpl<ResolvedCall> &calls,
330                            bool traverseNestedCGNodes) {
331   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
332   auto addToWorklist = [&](CallGraphNode *node,
333                            iterator_range<Region::iterator> blocks) {
334     for (Block &block : blocks)
335  worklist.emplace_back(&block, node);
336  };
337 
338  addToWorklist(sourceNode, blocks);
339  while (!worklist.empty()) {
340  Block *block;
341  std::tie(block, sourceNode) = worklist.pop_back_val();
342 
343  for (Operation &op : *block) {
344  if (auto call = dyn_cast<CallOpInterface>(op)) {
345  // TODO: Support inlining nested call references.
346  CallInterfaceCallable callable = call.getCallableForCallee();
347  if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
348  if (!symRef.isa<FlatSymbolRefAttr>())
349  continue;
350  }
351 
352  CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
353  if (!targetNode->isExternal())
354  calls.emplace_back(call, sourceNode, targetNode);
355  continue;
356  }
357 
358  // If this is not a call, traverse the nested regions. If
359  // `traverseNestedCGNodes` is false, then don't traverse nested call graph
360  // regions.
361  for (auto &nestedRegion : op.getRegions()) {
362  CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
363  if (traverseNestedCGNodes || !nestedNode)
364  addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
365  }
366  }
367  }
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // Inliner
372 //===----------------------------------------------------------------------===//
373 
374 #ifndef NDEBUG
375 static std::string getNodeName(CallOpInterface op) {
376  if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
377  return debugString(op);
378  return "_unnamed_callee_";
379 }
380 #endif
381 
382 /// Return true if the specified `inlineHistoryID` indicates an inline history
383 /// that already includes `node`.
384 static bool inlineHistoryIncludes(
385     CallGraphNode *node, Optional<size_t> inlineHistoryID,
386     MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
387         inlineHistory) {
388  while (inlineHistoryID.has_value()) {
389  assert(inlineHistoryID.value() < inlineHistory.size() &&
390  "Invalid inline history ID");
391  if (inlineHistory[inlineHistoryID.value()].first == node)
392  return true;
393  inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
394  }
395  return false;
396 }
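// Example: when inlining a call to @b exposes new call sites, each of them is
// recorded below with a history entry (node of @b, id of the parent entry).
// Walking that chain here answers "was `node` already inlined along the path
// that produced this call?", which is how direct and mutual recursion
// (e.g. @a -> @b -> @a) is prevented from inlining forever.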
397 
398 namespace {
399 /// This class provides a specialization of the main inlining interface.
400 struct Inliner : public InlinerInterface {
401  Inliner(MLIRContext *context, CallGraph &cg,
402  SymbolTableCollection &symbolTable)
403  : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
404 
405  /// Process a set of blocks that have been inlined. This callback is invoked
406  /// *before* inlined terminator operations have been processed.
407  void
408   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
409     // Find the closest callgraph node from the first block.
410  CallGraphNode *node;
411  Region *region = inlinedBlocks.begin()->getParent();
412  while (!(node = cg.lookupNode(region))) {
413  region = region->getParentRegion();
414  assert(region && "expected valid parent node");
415  }
416 
417  collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
418  /*traverseNestedCGNodes=*/true);
419  }
420 
421  /// Mark the given callgraph node for deletion.
422  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
423 
424  /// This method properly disposes of callables that became dead during
425  /// inlining. This should not be called while iterating over the SCCs.
426  void eraseDeadCallables() {
427  for (CallGraphNode *node : deadNodes)
428  node->getCallableRegion()->getParentOp()->erase();
429  }
430 
431  /// The set of callables known to be dead.
432   SmallPtrSet<CallGraphNode *, 8> deadNodes;
433 
434  /// The current set of call instructions to consider for inlining.
435   SmallVector<ResolvedCall, 8> calls;
436 
437  /// The callgraph being operated on.
438  CallGraph &cg;
439 
440  /// A symbol table to use when resolving call lookups.
441  SymbolTableCollection &symbolTable;
442 };
443 } // namespace
444 
445 /// Returns true if the given call should be inlined.
446 static bool shouldInline(ResolvedCall &resolvedCall) {
447  // Don't allow inlining terminator calls. We currently don't support this
448  // case.
449  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
450  return false;
451 
452  // Don't allow inlining if the target is an ancestor of the call. This
453  // prevents inlining recursively.
454  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
455  resolvedCall.call->getParentRegion()))
456  return false;
457 
458  // Otherwise, inline.
459  return true;
460 }
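// The ancestor check catches direct recursion: if the callee's region
// (transitively) contains the call, e.g. `@fact` calling itself, inlining
// would just materialize the same call again inside the inlined body.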
461 
462 /// Attempt to inline calls within the given SCC. This function returns
463 /// success if any calls were inlined, failure otherwise.
464 static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
465  CallGraphSCC &currentSCC) {
466  CallGraph &cg = inliner.cg;
467  auto &calls = inliner.calls;
468 
469  // A set of dead nodes to remove after inlining.
470  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
471 
472  // Collect all of the direct calls within the nodes of the current SCC. We
473  // don't traverse nested callgraph nodes, because they are handled separately,
474  // likely within a different SCC.
475  for (CallGraphNode *node : currentSCC) {
476  if (node->isExternal())
477  continue;
478 
479  // Don't collect calls if the node is already dead.
480  if (useList.isDead(node)) {
481  deadNodes.insert(node);
482  } else {
483  collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
484  calls, /*traverseNestedCGNodes=*/false);
485  }
486  }
487 
488  // When inlining a callee produces new call sites, we want to keep track of
489  // the fact that they were inlined from the callee. This allows us to avoid
490  // infinite inlining.
491  using InlineHistoryT = Optional<size_t>;
492   SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
493   std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
494 
495  LLVM_DEBUG({
496  llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
497  for (unsigned i = 0, e = calls.size(); i < e; ++i)
498  llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
499  llvm::dbgs() << "}\n";
500  });
501 
502  // Try to inline each of the call operations. Don't cache the end iterator
503  // here as more calls may be added during inlining.
504  bool inlinedAnyCalls = false;
505  for (unsigned i = 0; i < calls.size(); ++i) {
506  if (deadNodes.contains(calls[i].sourceNode))
507  continue;
508  ResolvedCall it = calls[i];
509 
510  InlineHistoryT inlineHistoryID = callHistory[i];
511  bool inHistory =
512  inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
513  bool doInline = !inHistory && shouldInline(it);
514  CallOpInterface call = it.call;
515  LLVM_DEBUG({
516  if (doInline)
517  llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
518  else
519  llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
520  });
521  if (!doInline)
522  continue;
523 
524  unsigned prevSize = calls.size();
525  Region *targetRegion = it.targetNode->getCallableRegion();
526 
527  // If this is the last call to the target node and the node is discardable,
528  // then inline it in-place and delete the node if successful.
529  bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
530 
531  LogicalResult inlineResult = inlineCall(
532  inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
533  targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
534  if (failed(inlineResult)) {
535  LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
536  continue;
537  }
538  inlinedAnyCalls = true;
539 
540     // Create an inline history entry for this inlined call, so that we remember
541     // that new callsites came about due to inlining the callee.
542  InlineHistoryT newInlineHistoryID{inlineHistory.size()};
543  inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
544 
545  auto historyToString = [](InlineHistoryT h) {
546  return h.has_value() ? std::to_string(h.value()) : "root";
547  };
548  (void)historyToString;
549  LLVM_DEBUG(llvm::dbgs()
550  << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
551  << getNodeName(call) << ", " << historyToString(inlineHistoryID)
552  << "]\n");
553 
554  for (unsigned k = prevSize; k != calls.size(); ++k) {
555  callHistory.push_back(newInlineHistoryID);
556       LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
557  << "}\n with historyID = " << newInlineHistoryID
558  << ", added due to inlining of\n call {" << call
559  << "}\n with historyID = "
560  << historyToString(inlineHistoryID) << "\n");
561  }
562 
563     // If the inlining was successful, merge the new uses into the source node.
564  useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
565  useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
566 
567     // Then erase the call.
568  call.erase();
569 
570  // If we inlined in place, mark the node for deletion.
571  if (inlineInPlace) {
572  useList.eraseNode(it.targetNode);
573  deadNodes.insert(it.targetNode);
574  }
575  }
576 
577  for (CallGraphNode *node : deadNodes) {
578  currentSCC.remove(node);
579  inliner.markForDeletion(node);
580  }
581  calls.clear();
582  return success(inlinedAnyCalls);
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // InlinerPass
587 //===----------------------------------------------------------------------===//
588 
589 namespace {
590 class InlinerPass : public impl::InlinerBase<InlinerPass> {
591 public:
592  InlinerPass();
593  InlinerPass(const InlinerPass &) = default;
594  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
595  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
596  llvm::StringMap<OpPassManager> opPipelines);
597  void runOnOperation() override;
598 
599 private:
600   /// Attempt to inline calls within the given SCC, and run simplifications,
601  /// until a fixed point is reached. This allows for the inlining of newly
602  /// devirtualized calls. Returns failure if there was a fatal error during
603  /// inlining.
604  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
605  CallGraphSCC &currentSCC, MLIRContext *context);
606 
607  /// Optimize the nodes within the given SCC with one of the held optimization
608  /// pass pipelines. Returns failure if an error occurred during the
609  /// optimization of the SCC, success otherwise.
610  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
611  CallGraphSCC &currentSCC, MLIRContext *context);
612 
613  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
614  /// error occurred during the optimization of the SCC, success otherwise.
615  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
616  MLIRContext *context);
617 
618  /// Optimize the given callable node with one of the pass managers provided
619  /// with `pipelines`, or the default pipeline. Returns failure if an error
620  /// occurred during the optimization of the callable, success otherwise.
621  LogicalResult optimizeCallable(CallGraphNode *node,
622  llvm::StringMap<OpPassManager> &pipelines);
623 
624  /// Attempt to initialize the options of this pass from the given string.
625  /// Derived classes may override this method to hook into the point at which
626  /// options are initialized, but should generally always invoke this base
627  /// class variant.
628  LogicalResult initializeOptions(StringRef options) override;
629 
630  /// An optional function that constructs a default optimization pipeline for
631  /// a given operation.
632  std::function<void(OpPassManager &)> defaultPipeline;
633  /// A map of operation names to pass pipelines to use when optimizing
634  /// callable operations of these types. This provides a specialized pipeline
635  /// instead of the default. The vector size is the number of threads used
636  /// during optimization.
637   std::vector<llvm::StringMap<OpPassManager>> opPipelines;
638 };
639 } // namespace
640 
641 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
642 InlinerPass::InlinerPass(
643  std::function<void(OpPassManager &)> defaultPipelineArg)
644  : defaultPipeline(std::move(defaultPipelineArg)) {
645  opPipelines.push_back({});
646 }
647 
648 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
649  llvm::StringMap<OpPassManager> opPipelines)
650  : InlinerPass(std::move(defaultPipeline)) {
651  if (opPipelines.empty())
652  return;
653 
654  // Update the option for the op specific optimization pipelines.
655  for (auto &it : opPipelines)
656  opPipelineList.addValue(it.second);
657  this->opPipelines.emplace_back(std::move(opPipelines));
658 }
659 
660 void InlinerPass::runOnOperation() {
661  CallGraph &cg = getAnalysis<CallGraph>();
662  auto *context = &getContext();
663 
664  // The inliner should only be run on operations that define a symbol table,
665  // as the callgraph will need to resolve references.
666  Operation *op = getOperation();
667  if (!op->hasTrait<OpTrait::SymbolTable>()) {
668  op->emitOpError() << " was scheduled to run under the inliner, but does "
669  "not define a symbol table";
670  return signalPassFailure();
671  }
672 
673  // Run the inline transform in post-order over the SCCs in the callgraph.
674  SymbolTableCollection symbolTable;
675  Inliner inliner(context, cg, symbolTable);
676  CGUseList useList(getOperation(), cg, symbolTable);
677  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
678  return inlineSCC(inliner, useList, scc, context);
679  });
680  if (failed(result))
681  return signalPassFailure();
682 
683  // After inlining, make sure to erase any callables proven to be dead.
684  inliner.eraseDeadCallables();
685 }
686 
687 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
688  CallGraphSCC &currentSCC,
689  MLIRContext *context) {
690  // Continuously simplify and inline until we either reach a fixed point, or
691  // hit the maximum iteration count. Simplifying early helps to refine the cost
692  // model, and in future iterations may devirtualize new calls.
693  unsigned iterationCount = 0;
694  do {
695  if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
696  return failure();
697  if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
698  break;
699  } while (++iterationCount < maxInliningIterations);
700  return success();
701 }
702 
703 LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
704  CallGraphSCC &currentSCC,
705  MLIRContext *context) {
706  // Collect the sets of nodes to simplify.
707  SmallVector<CallGraphNode *, 4> nodesToVisit;
708  for (auto *node : currentSCC) {
709  if (node->isExternal())
710  continue;
711 
712  // Don't simplify nodes with children. Nodes with children require special
713  // handling as we may remove the node during simplification. In the future,
714  // we should be able to handle this case with proper node deletion tracking.
715  if (node->hasChildren())
716  continue;
717 
718  // We also won't apply simplifications to nodes that can't have passes
719  // scheduled on them.
720  auto *region = node->getCallableRegion();
721     if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
722       continue;
723  nodesToVisit.push_back(node);
724  }
725  if (nodesToVisit.empty())
726  return success();
727 
728  // Optimize each of the nodes within the SCC in parallel.
729  if (failed(optimizeSCCAsync(nodesToVisit, context)))
730  return failure();
731 
732  // Recompute the uses held by each of the nodes.
733  for (CallGraphNode *node : nodesToVisit)
734  useList.recomputeUses(node, cg);
735  return success();
736 }
737 
738 LogicalResult
739 InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
740  MLIRContext *ctx) {
741  // We must maintain a fixed pool of pass managers which is at least as large
742  // as the maximum parallelism of the failableParallelForEach below.
743  // Note: The number of pass managers here needs to remain constant
744  // to prevent issues with pass instrumentations that rely on having the same
745  // pass manager for the main thread.
746  size_t numThreads = ctx->getNumThreads();
747  if (opPipelines.size() < numThreads) {
748  // Reserve before resizing so that we can use a reference to the first
749  // element.
750  opPipelines.reserve(numThreads);
751  opPipelines.resize(numThreads, opPipelines.front());
752  }
753 
754  // Ensure an analysis manager has been constructed for each of the nodes.
755  // This prevents thread races when running the nested pipelines.
756  for (CallGraphNode *node : nodesToVisit)
757  getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
758 
759  // An atomic failure variable for the async executors.
760  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
761  std::fill(activePMs.begin(), activePMs.end(), false);
762  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
763  // Find a pass manager for this operation.
764  auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
765  bool expectedInactive = false;
766  return isActive.compare_exchange_strong(expectedInactive, true);
767  });
768  assert(it != activePMs.end() &&
769  "could not find inactive pass manager for thread");
770  unsigned pmIndex = it - activePMs.begin();
771 
772  // Optimize this callable node.
773  LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
774 
775  // Reset the active bit for this pass manager.
776  activePMs[pmIndex].store(false);
777  return result;
778  });
779 }
780 
781 LogicalResult
782 InlinerPass::optimizeCallable(CallGraphNode *node,
783  llvm::StringMap<OpPassManager> &pipelines) {
784  Operation *callable = node->getCallableRegion()->getParentOp();
785  StringRef opName = callable->getName().getStringRef();
786  auto pipelineIt = pipelines.find(opName);
787  if (pipelineIt == pipelines.end()) {
788  // If a pipeline didn't exist, use the default if possible.
789  if (!defaultPipeline)
790  return success();
791 
792  OpPassManager defaultPM(opName);
793  defaultPipeline(defaultPM);
794  pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
795  }
796  return runPipeline(pipelineIt->second, callable);
797 }
798 
799 LogicalResult InlinerPass::initializeOptions(StringRef options) {
800  if (failed(Pass::initializeOptions(options)))
801  return failure();
802 
803  // Initialize the default pipeline builder to use the option string.
804  // TODO: Use a generic pass manager for default pipelines, and remove this.
805  if (!defaultPipelineStr.empty()) {
806  std::string defaultPipelineCopy = defaultPipelineStr;
807  defaultPipeline = [=](OpPassManager &pm) {
808  (void)parsePassPipeline(defaultPipelineCopy, pm);
809  };
810  } else if (defaultPipelineStr.getNumOccurrences()) {
811  defaultPipeline = nullptr;
812  }
813 
814  // Initialize the op specific pass pipelines.
815  llvm::StringMap<OpPassManager> pipelines;
816  for (OpPassManager pipeline : opPipelineList)
817  if (!pipeline.empty())
818  pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
819  opPipelines.assign({std::move(pipelines)});
820 
821  return success();
822 }
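// Illustrative only: the textual option names live in the pass's TableGen
// definition rather than in this file, but assuming the conventional
// `default-pipeline` and `max-iterations` spellings, the builder above is
// exercised by an invocation such as
//   mlir-opt in.mlir --inline='default-pipeline=canonicalize max-iterations=4'
// which parses "canonicalize" into the default pipeline applied to every
// callable that has no op-specific pipeline.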
823 
824 std::unique_ptr<Pass> mlir::createInlinerPass() {
825  return std::make_unique<InlinerPass>();
826 }
827 std::unique_ptr<Pass>
828 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
829  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
830  std::move(opPipelines));
831 }
832 std::unique_ptr<Pass> mlir::createInlinerPass(
833  llvm::StringMap<OpPassManager> opPipelines,
834  std::function<void(OpPassManager &)> defaultPipelineBuilder) {
835  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
836  std::move(opPipelines));
837 }
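A minimal usage sketch (not part of this file) showing how a client could configure the pass through the factory functions above; the "func.func" anchor, the CSE/canonicalizer pipelines, and the `context`/`module` variables are illustrative assumptions rather than requirements of the API:

  // Assumes mlir/Pass/PassManager.h and mlir/Transforms/Passes.h are included,
  // `context` is an MLIRContext, and `module` is a parsed ModuleOp.
  llvm::StringMap<OpPassManager> opPipelines;
  OpPassManager funcPipeline("func.func"); // specialized pipeline for func.func
  funcPipeline.addPass(mlir::createCSEPass());
  opPipelines.try_emplace("func.func", std::move(funcPipeline));

  mlir::PassManager pm(&context);
  pm.addPass(mlir::createInlinerPass(
      std::move(opPipelines),
      [](OpPassManager &p) { p.addPass(mlir::createCanonicalizerPass()); }));
  if (mlir::failed(pm.run(module)))
    llvm::errs() << "inlining pipeline failed\n";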