1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a basic inlining algorithm that operates bottom up over
10 // the Strongly Connected Components (SCCs) of the CallGraph. This enables a more
11 // incremental propagation of inlining decisions from the leaves to the roots of
12 // the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/CallGraph.h"
18 #include "mlir/IR/Threading.h"
19 #include "mlir/Interfaces/CallInterfaces.h"
20 #include "mlir/Interfaces/SideEffectInterfaces.h"
21 #include "mlir/Pass/PassManager.h"
22 #include "mlir/Support/DebugStringHelper.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "mlir/Transforms/Passes.h"
25 #include "llvm/ADT/SCCIterator.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "inlining"
29 
30 using namespace mlir;
31 
32 /// This function implements the default inliner optimization pipeline.
33 static void defaultInlinerOptPipeline(OpPassManager &pm) {
34  pm.addPass(createCanonicalizerPass());
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Symbol Use Tracking
39 //===----------------------------------------------------------------------===//
40 
41 /// Walk all of the used symbol callgraph nodes referenced by the given op.
42 static void walkReferencedSymbolNodes(
43  Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
44  DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
45  function_ref<void(CallGraphNode *, Operation *)> callback) {
46  auto symbolUses = SymbolTable::getSymbolUses(op);
47  assert(symbolUses && "expected uses to be valid");
48 
49  Operation *symbolTableOp = op->getParentOp();
50  for (const SymbolTable::SymbolUse &use : *symbolUses) {
51  auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
52  CallGraphNode *&node = refIt.first->second;
53 
54  // If this is the first instance of this reference, try to resolve a
55  // callgraph node for it.
56  if (refIt.second) {
57  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
58  use.getSymbolRef());
59  auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
60  if (!callableOp)
61  continue;
62  node = cg.lookupNode(callableOp.getCallableRegion());
63  }
64  if (node)
65  callback(node, use.getUser());
66  }
67 }
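// Illustrative note: the symbol uses walked here are SymbolRefAttr occurrences
// anywhere in the attribute lists of `op` and the operations nested within it.
// For a hypothetical module such as
//
//   func.func private @callee() { ... }
//   func.func @caller() {
//     func.call @callee() : () -> ()   // `@callee` is a FlatSymbolRefAttr use.
//     return
//   }
//
// walking @caller resolves `@callee` to its callgraph node and invokes the
// callback with that node and the `func.call` operation as the user.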
68 
69 //===----------------------------------------------------------------------===//
70 // CGUseList
71 
72 namespace {
73 /// This struct tracks the uses of callgraph nodes that can be dropped when
74 /// use_empty. It directly tracks and manages a use-list for all of the
75 /// call-graph nodes. This is necessary because many callgraph nodes are
76 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
77 /// class.
78 struct CGUseList {
79  /// This struct tracks the uses of callgraph nodes within a specific
80  /// operation.
81  struct CGUser {
82  /// Any nodes referenced in the top-level attribute list of this user. We
83  /// use a set here because the number of references does not matter.
84  DenseSet<CallGraphNode *> topLevelUses;
85 
86  /// Uses of nodes referenced by nested operations.
87  DenseMap<CallGraphNode *, int> innerUses;
88  };
89 
90  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
91 
92  /// Drop uses of nodes referred to by the given call operation that resides
93  /// within 'userNode'.
94  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
95 
96  /// Remove the given node from the use list.
97  void eraseNode(CallGraphNode *node);
98 
99  /// Returns true if the given callgraph node has no uses and can be pruned.
100  bool isDead(CallGraphNode *node) const;
101 
102  /// Returns true if the given callgraph node has a single use and can be
103  /// discarded.
104  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
105 
106  /// Recompute the uses held by the given callgraph node.
107  void recomputeUses(CallGraphNode *node, CallGraph &cg);
108 
109  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
110  /// of 'lhs' into 'rhs'.
111  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
112 
113 private:
114  /// Decrement the uses of discardable nodes referenced by the given user.
115  void decrementDiscardableUses(CGUser &uses);
116 
117  /// A mapping between a discardable callgraph node (that is a symbol) and the
118  /// number of uses for this node.
119  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
120 
121  /// A mapping between a callgraph node and the symbol callgraph nodes that it
122  /// uses.
123  DenseMap<CallGraphNode *, CGUser> nodeUses;
124 
125  /// A symbol table to use when resolving call lookups.
126  SymbolTableCollection &symbolTable;
127 };
128 } // namespace
129 
130 CGUseList::CGUseList(Operation *op, CallGraph &cg,
131  SymbolTableCollection &symbolTable)
132  : symbolTable(symbolTable) {
133  /// A set of callgraph nodes that are always known to be live during inlining.
134  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
135 
136  // Walk each of the symbol tables looking for discardable callgraph nodes.
137  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
138  for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
139  // If this is a callgraph operation, check to see if it is discardable.
140  if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
141  if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
142  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
143  if (symbol && (allUsesVisible || symbol.isPrivate()) &&
144  symbol.canDiscardOnUseEmpty()) {
145  discardableSymNodeUses.try_emplace(node, 0);
146  }
147  continue;
148  }
149  }
150  // Otherwise, check for any referenced nodes. These will be always-live.
151  walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
152  [](CallGraphNode *, Operation *) {});
153  }
154  };
155  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
156  walkFn);
157 
158  // Drop the use information for any discardable nodes that are always live.
159  for (auto &it : alwaysLiveNodes)
160  discardableSymNodeUses.erase(it.second);
161 
162  // Compute the uses for each of the callable nodes in the graph.
163  for (CallGraphNode *node : cg)
164  recomputeUses(node, cg);
165 }
166 
167 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
168  CallGraph &cg) {
169  auto &userRefs = nodeUses[userNode].innerUses;
170  auto walkFn = [&](CallGraphNode *node, Operation *user) {
171  auto parentIt = userRefs.find(node);
172  if (parentIt == userRefs.end())
173  return;
174  --parentIt->second;
175  --discardableSymNodeUses[node];
176  };
177  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
178  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
179 }
180 
181 void CGUseList::eraseNode(CallGraphNode *node) {
182  // Drop all child nodes.
183  for (auto &edge : *node)
184  if (edge.isChild())
185  eraseNode(edge.getTarget());
186 
187  // Drop the uses held by this node and erase it.
188  auto useIt = nodeUses.find(node);
189  assert(useIt != nodeUses.end() && "expected node to be valid");
190  decrementDiscardableUses(useIt->getSecond());
191  nodeUses.erase(useIt);
192  discardableSymNodeUses.erase(node);
193 }
194 
195 bool CGUseList::isDead(CallGraphNode *node) const {
196  // If the parent operation isn't a symbol, simply check normal SSA deadness.
197  Operation *nodeOp = node->getCallableRegion()->getParentOp();
198  if (!isa<SymbolOpInterface>(nodeOp))
199  return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
200 
201  // Otherwise, check the number of symbol uses.
202  auto symbolIt = discardableSymNodeUses.find(node);
203  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
204 }
205 
206 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
207  // If this isn't a symbol node, check for side-effects and SSA use count.
208  Operation *nodeOp = node->getCallableRegion()->getParentOp();
209  if (!isa<SymbolOpInterface>(nodeOp))
210  return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
211 
212  // Otherwise, check the number of symbol uses.
213  auto symbolIt = discardableSymNodeUses.find(node);
214  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
215 }
216 
217 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
218  Operation *parentOp = node->getCallableRegion()->getParentOp();
219  CGUser &uses = nodeUses[node];
220  decrementDiscardableUses(uses);
221 
222  // Collect the new discardable uses within this node.
223  uses = CGUser();
224  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
225  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
226  auto discardSymIt = discardableSymNodeUses.find(refNode);
227  if (discardSymIt == discardableSymNodeUses.end())
228  return;
229 
230  if (user != parentOp)
231  ++uses.innerUses[refNode];
232  else if (!uses.topLevelUses.insert(refNode).second)
233  return;
234  ++discardSymIt->second;
235  };
236  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
237 }
238 
239 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
240  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
241  for (auto &useIt : lhsUses.innerUses) {
242  rhsUses.innerUses[useIt.first] += useIt.second;
243  discardableSymNodeUses[useIt.first] += useIt.second;
244  }
245 }
246 
247 void CGUseList::decrementDiscardableUses(CGUser &uses) {
248  for (CallGraphNode *node : uses.topLevelUses)
249  --discardableSymNodeUses[node];
250  for (auto &it : uses.innerUses)
251  discardableSymNodeUses[it.first] -= it.second;
252 }
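// Worked example (illustrative, using a hypothetical symbol): for a
// `func.func private @f` referenced from two distinct call sites,
// discardableSymNodeUses[@f] is 2 after the CGUseList constructor runs.
// Inlining the first call site invokes dropCallUses() on the caller and drops
// the count to 1; the second call site then satisfies
// hasOneUseAndDiscardable(@f), so it can be inlined in place without cloning
// the region, after which the node is erased from the use list and @f can be
// deleted as dead.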
253 
254 //===----------------------------------------------------------------------===//
255 // CallGraph traversal
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 /// This class represents a specific callgraph SCC.
260 class CallGraphSCC {
261 public:
262  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
263  : parentIterator(parentIterator) {}
264  /// Return a range over the nodes within this SCC.
265  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
266  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
267 
268  /// Reset the nodes of this SCC with those provided.
269  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
270 
271  /// Remove the given node from this SCC.
272  void remove(CallGraphNode *node) {
273  auto it = llvm::find(nodes, node);
274  if (it != nodes.end()) {
275  nodes.erase(it);
276  parentIterator.ReplaceNode(node, nullptr);
277  }
278  }
279 
280 private:
281  std::vector<CallGraphNode *> nodes;
282  llvm::scc_iterator<const CallGraph *> &parentIterator;
283 };
284 } // namespace
285 
286 /// Run a given transformation over the SCCs of the callgraph in a bottom up
287 /// traversal.
288 static LogicalResult runTransformOnCGSCCs(
289  const CallGraph &cg,
290  function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
291  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
292  CallGraphSCC currentSCC(cgi);
293  while (!cgi.isAtEnd()) {
294  // Copy the current SCC and increment so that the transformer can modify the
295  // SCC without invalidating our iterator.
296  currentSCC.reset(*cgi);
297  ++cgi;
298  if (failed(sccTransformer(currentSCC)))
299  return failure();
300  }
301  return success();
302 }
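// Ordering note (illustrative): llvm::scc_begin enumerates the SCCs of the
// callgraph in post-order, so an SCC is transformed only after every SCC it
// calls into. For a hypothetical chain where @a calls @b and @b calls @c, the
// transformer runs on {@c}, then {@b}, then {@a}; mutually recursive functions
// such as @even and @odd form a single SCC and are handed to the transformer
// together.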
303 
304 namespace {
305 /// This struct represents a resolved call to a given callgraph node. Given that
306 /// the call does not actually contain a direct reference to the
307 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
308 /// explicitly.
309 struct ResolvedCall {
310  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
311  CallGraphNode *targetNode)
312  : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
313  CallOpInterface call;
314  CallGraphNode *sourceNode, *targetNode;
315 };
316 } // namespace
317 
318 /// Collect all of the call operations within the given range of blocks. If
319 /// `traverseNestedCGNodes` is true, this will also collect call operations
320 /// inside of nested callgraph nodes.
321 static void collectCallOps(iterator_range<Region::iterator> blocks,
322  CallGraphNode *sourceNode, CallGraph &cg,
323  SymbolTableCollection &symbolTable,
324  SmallVectorImpl<ResolvedCall> &calls,
325  bool traverseNestedCGNodes) {
326  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
327  auto addToWorklist = [&](CallGraphNode *node,
328  iterator_range<Region::iterator> blocks) {
329  for (Block &block : blocks)
330  worklist.emplace_back(&block, node);
331  };
332 
333  addToWorklist(sourceNode, blocks);
334  while (!worklist.empty()) {
335  Block *block;
336  std::tie(block, sourceNode) = worklist.pop_back_val();
337 
338  for (Operation &op : *block) {
339  if (auto call = dyn_cast<CallOpInterface>(op)) {
340  // TODO: Support inlining nested call references.
341  CallInterfaceCallable callable = call.getCallableForCallee();
342  if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
343  if (!symRef.isa<FlatSymbolRefAttr>())
344  continue;
345  }
346 
347  CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
348  if (!targetNode->isExternal())
349  calls.emplace_back(call, sourceNode, targetNode);
350  continue;
351  }
352 
353  // If this is not a call, traverse the nested regions. If
354  // `traverseNestedCGNodes` is false, then don't traverse nested call graph
355  // regions.
356  for (auto &nestedRegion : op.getRegions()) {
357  CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
358  if (traverseNestedCGNodes || !nestedNode)
359  addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
360  }
361  }
362  }
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // Inliner
367 //===----------------------------------------------------------------------===//
368 
369 #ifndef NDEBUG
370 static std::string getNodeName(CallOpInterface op) {
371  if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
372  return debugString(op);
373  return "_unnamed_callee_";
374 }
375 #endif
376 
377 /// Return true if the specified `inlineHistoryID` indicates an inline history
378 /// that already includes `node`.
379 static bool inlineHistoryIncludes(
380  CallGraphNode *node, Optional<size_t> inlineHistoryID,
381  MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
382  inlineHistory) {
383  while (inlineHistoryID.has_value()) {
384  assert(inlineHistoryID.value() < inlineHistory.size() &&
385  "Invalid inline history ID");
386  if (inlineHistory[inlineHistoryID.value()].first == node)
387  return true;
388  inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
389  }
390  return false;
391 }
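// Worked example (illustrative, hypothetical symbols): suppose inlining a call
// to @b records history entry 0 = (@b, root). A call to @c exposed by that
// inlining is later inlined and records entry 1 = (@c, 0). If entry 1 in turn
// exposes another call to @b, that call carries inlineHistoryID = 1; walking
// the chain 1 -> 0 finds @b, so inlineHistoryIncludes(@b, 1, ...) returns true
// and the call is skipped, bounding repeated inlining through recursive cycles.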
392 
393 namespace {
394 /// This class provides a specialization of the main inlining interface.
395 struct Inliner : public InlinerInterface {
396  Inliner(MLIRContext *context, CallGraph &cg,
397  SymbolTableCollection &symbolTable)
398  : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
399 
400  /// Process a set of blocks that have been inlined. This callback is invoked
401  /// *before* inlined terminator operations have been processed.
402  void
403  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
404  // Find the closest callgraph node from the first block.
405  CallGraphNode *node;
406  Region *region = inlinedBlocks.begin()->getParent();
407  while (!(node = cg.lookupNode(region))) {
408  region = region->getParentRegion();
409  assert(region && "expected valid parent node");
410  }
411 
412  collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
413  /*traverseNestedCGNodes=*/true);
414  }
415 
416  /// Mark the given callgraph node for deletion.
417  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
418 
419  /// This method properly disposes of callables that became dead during
420  /// inlining. This should not be called while iterating over the SCCs.
421  void eraseDeadCallables() {
422  for (CallGraphNode *node : deadNodes)
423  node->getCallableRegion()->getParentOp()->erase();
424  }
425 
426  /// The set of callables known to be dead.
427  SmallPtrSet<CallGraphNode *, 8> deadNodes;
428 
429  /// The current set of call instructions to consider for inlining.
430  SmallVector<ResolvedCall, 8> calls;
431 
432  /// The callgraph being operated on.
433  CallGraph &cg;
434 
435  /// A symbol table to use when resolving call lookups.
436  SymbolTableCollection &symbolTable;
437 };
438 } // namespace
439 
440 /// Returns true if the given call should be inlined.
441 static bool shouldInline(ResolvedCall &resolvedCall) {
442  // Don't allow inlining terminator calls. We currently don't support this
443  // case.
444  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
445  return false;
446 
447  // Don't allow inlining if the target is an ancestor of the call. This
448  // prevents inlining recursively.
449  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
450  resolvedCall.call->getParentRegion()))
451  return false;
452 
453  // Otherwise, inline.
454  return true;
455 }
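// Illustrative consequence of the ancestor check above: a hypothetical
// directly recursive function
//
//   func.func @count(%n: i32) { ... func.call @count(...) ... }
//
// is never inlined into itself, since the callable region of @count is an
// ancestor of (here, identical to) the region containing the recursive call.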
456 
457 /// Attempt to inline calls within the given SCC. This function returns
458 /// success if any calls were inlined, failure otherwise.
459 static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
460  CallGraphSCC &currentSCC) {
461  CallGraph &cg = inliner.cg;
462  auto &calls = inliner.calls;
463 
464  // A set of dead nodes to remove after inlining.
465  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
466 
467  // Collect all of the direct calls within the nodes of the current SCC. We
468  // don't traverse nested callgraph nodes, because they are handled separately
469  // likely within a different SCC.
470  for (CallGraphNode *node : currentSCC) {
471  if (node->isExternal())
472  continue;
473 
474  // Don't collect calls if the node is already dead.
475  if (useList.isDead(node)) {
476  deadNodes.insert(node);
477  } else {
478  collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
479  calls, /*traverseNestedCGNodes=*/false);
480  }
481  }
482 
483  // When inlining a callee produces new call sites, we want to keep track of
484  // the fact that they were inlined from the callee. This allows us to avoid
485  // infinite inlining.
486  using InlineHistoryT = Optional<size_t>;
487  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
488  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
489 
490  LLVM_DEBUG({
491  llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
492  for (unsigned i = 0, e = calls.size(); i < e; ++i)
493  llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
494  llvm::dbgs() << "}\n";
495  });
496 
497  // Try to inline each of the call operations. Don't cache the end iterator
498  // here as more calls may be added during inlining.
499  bool inlinedAnyCalls = false;
500  for (unsigned i = 0; i < calls.size(); ++i) {
501  if (deadNodes.contains(calls[i].sourceNode))
502  continue;
503  ResolvedCall it = calls[i];
504 
505  InlineHistoryT inlineHistoryID = callHistory[i];
506  bool inHistory =
507  inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
508  bool doInline = !inHistory && shouldInline(it);
509  CallOpInterface call = it.call;
510  LLVM_DEBUG({
511  if (doInline)
512  llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
513  else
514  llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
515  });
516  if (!doInline)
517  continue;
518 
519  unsigned prevSize = calls.size();
520  Region *targetRegion = it.targetNode->getCallableRegion();
521 
522  // If this is the last call to the target node and the node is discardable,
523  // then inline it in-place and delete the node if successful.
524  bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
525 
526  LogicalResult inlineResult = inlineCall(
527  inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
528  targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
529  if (failed(inlineResult)) {
530  LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
531  continue;
532  }
533  inlinedAnyCalls = true;
534 
535  // Create an inline history entry for this inlined call, so that we remember
536  // that new callsites came about due to inlining the callee.
537  InlineHistoryT newInlineHistoryID{inlineHistory.size()};
538  inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
539 
540  auto historyToString = [](InlineHistoryT h) {
541  return h.has_value() ? std::to_string(h.value()) : "root";
542  };
543  (void)historyToString;
544  LLVM_DEBUG(llvm::dbgs()
545  << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
546  << getNodeName(call) << ", " << historyToString(inlineHistoryID)
547  << "]\n");
548 
549  for (unsigned k = prevSize; k != calls.size(); ++k) {
550  callHistory.push_back(newInlineHistoryID);
551  LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
552  << "}\n with historyID = " << newInlineHistoryID
553  << ", added due to inlining of\n call {" << call
554  << "}\n with historyID = "
555  << historyToString(inlineHistoryID) << "\n");
556  }
557 
558  // If the inlining was successful, merge the new uses into the source node.
559  useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
560  useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
561 
562  // Then erase the call.
563  call.erase();
564 
565  // If we inlined in place, mark the node for deletion.
566  if (inlineInPlace) {
567  useList.eraseNode(it.targetNode);
568  deadNodes.insert(it.targetNode);
569  }
570  }
571 
572  for (CallGraphNode *node : deadNodes) {
573  currentSCC.remove(node);
574  inliner.markForDeletion(node);
575  }
576  calls.clear();
577  return success(inlinedAnyCalls);
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // InlinerPass
582 //===----------------------------------------------------------------------===//
583 
584 namespace {
585 class InlinerPass : public InlinerBase<InlinerPass> {
586 public:
587  InlinerPass();
588  InlinerPass(const InlinerPass &) = default;
589  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
590  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
591  llvm::StringMap<OpPassManager> opPipelines);
592  void runOnOperation() override;
593 
594 private:
595  /// Attempt to inline calls within the given SCC, and run simplifications,
596  /// until a fixed point is reached. This allows for the inlining of newly
597  /// devirtualized calls. Returns failure if there was a fatal error during
598  /// inlining.
599  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
600  CallGraphSCC &currentSCC, MLIRContext *context);
601 
602  /// Optimize the nodes within the given SCC with one of the held optimization
603  /// pass pipelines. Returns failure if an error occurred during the
604  /// optimization of the SCC, success otherwise.
605  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
606  CallGraphSCC &currentSCC, MLIRContext *context);
607 
608  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
609  /// error occurred during the optimization of the SCC, success otherwise.
610  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
611  MLIRContext *context);
612 
613  /// Optimize the given callable node with one of the pass managers provided
614  /// with `pipelines`, or the default pipeline. Returns failure if an error
615  /// occurred during the optimization of the callable, success otherwise.
616  LogicalResult optimizeCallable(CallGraphNode *node,
617  llvm::StringMap<OpPassManager> &pipelines);
618 
619  /// Attempt to initialize the options of this pass from the given string.
620  /// Derived classes may override this method to hook into the point at which
621  /// options are initialized, but should generally always invoke this base
622  /// class variant.
623  LogicalResult initializeOptions(StringRef options) override;
624 
625  /// An optional function that constructs a default optimization pipeline for
626  /// a given operation.
627  std::function<void(OpPassManager &)> defaultPipeline;
628  /// A map of operation names to pass pipelines to use when optimizing
629  /// callable operations of these types. This provides a specialized pipeline
630  /// instead of the default. The vector size is the number of threads used
631  /// during optimization.
632  std::vector<llvm::StringMap<OpPassManager>> opPipelines;
633 };
634 } // namespace
635 
636 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
637 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
638  : defaultPipeline(std::move(defaultPipeline)) {
639  opPipelines.push_back({});
640 
641  // Initialize the pass options with the provided arguments.
642  if (defaultPipeline) {
643  OpPassManager fakePM("__mlir_fake_pm_op");
644  defaultPipeline(fakePM);
645  llvm::raw_string_ostream strStream(defaultPipelineStr);
646  fakePM.printAsTextualPipeline(strStream);
647  }
648 }
649 
650 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
651  llvm::StringMap<OpPassManager> opPipelines)
652  : InlinerPass(std::move(defaultPipeline)) {
653  if (opPipelines.empty())
654  return;
655 
656  // Update the option for the op specific optimization pipelines.
657  for (auto &it : opPipelines)
658  opPipelineList.addValue(it.second);
659  this->opPipelines.emplace_back(std::move(opPipelines));
660 }
661 
662 void InlinerPass::runOnOperation() {
663  CallGraph &cg = getAnalysis<CallGraph>();
664  auto *context = &getContext();
665 
666  // The inliner should only be run on operations that define a symbol table,
667  // as the callgraph will need to resolve references.
668  Operation *op = getOperation();
669  if (!op->hasTrait<OpTrait::SymbolTable>()) {
670  op->emitOpError() << " was scheduled to run under the inliner, but does "
671  "not define a symbol table";
672  return signalPassFailure();
673  }
674 
675  // Run the inline transform in post-order over the SCCs in the callgraph.
676  SymbolTableCollection symbolTable;
677  Inliner inliner(context, cg, symbolTable);
678  CGUseList useList(getOperation(), cg, symbolTable);
679  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
680  return inlineSCC(inliner, useList, scc, context);
681  });
682  if (failed(result))
683  return signalPassFailure();
684 
685  // After inlining, make sure to erase any callables proven to be dead.
686  inliner.eraseDeadCallables();
687 }
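// Usage note (illustrative): this pass must be scheduled on an operation that
// defines a symbol table, e.g. a builtin.module; scheduling it directly on an
// operation without the SymbolTable trait (such as an individual func.func)
// triggers the error emitted above.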
688 
689 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
690  CallGraphSCC &currentSCC,
691  MLIRContext *context) {
692  // Continuously simplify and inline until we either reach a fixed point, or
693  // hit the maximum iteration count. Simplifying early helps to refine the cost
694  // model, and in future iterations may devirtualize new calls.
695  unsigned iterationCount = 0;
696  do {
697  if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
698  return failure();
699  if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
700  break;
701  } while (++iterationCount < maxInliningIterations);
702  return success();
703 }
704 
705 LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
706  CallGraphSCC &currentSCC,
707  MLIRContext *context) {
708  // Collect the sets of nodes to simplify.
709  SmallVector<CallGraphNode *, 4> nodesToVisit;
710  for (auto *node : currentSCC) {
711  if (node->isExternal())
712  continue;
713 
714  // Don't simplify nodes with children. Nodes with children require special
715  // handling as we may remove the node during simplification. In the future,
716  // we should be able to handle this case with proper node deletion tracking.
717  if (node->hasChildren())
718  continue;
719 
720  // We also won't apply simplifications to nodes that can't have passes
721  // scheduled on them.
722  auto *region = node->getCallableRegion();
723  if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
724  continue;
725  nodesToVisit.push_back(node);
726  }
727  if (nodesToVisit.empty())
728  return success();
729 
730  // Optimize each of the nodes within the SCC in parallel.
731  if (failed(optimizeSCCAsync(nodesToVisit, context)))
732  return failure();
733 
734  // Recompute the uses held by each of the nodes.
735  for (CallGraphNode *node : nodesToVisit)
736  useList.recomputeUses(node, cg);
737  return success();
738 }
739 
740 LogicalResult
741 InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
742  MLIRContext *ctx) {
743  // We must maintain a fixed pool of pass managers which is at least as large
744  // as the maximum parallelism of the failableParallelForEach below.
745  // Note: The number of pass managers here needs to remain constant
746  // to prevent issues with pass instrumentations that rely on having the same
747  // pass manager for the main thread.
748  size_t numThreads = ctx->getNumThreads();
749  if (opPipelines.size() < numThreads) {
750  // Reserve before resizing so that we can use a reference to the first
751  // element.
752  opPipelines.reserve(numThreads);
753  opPipelines.resize(numThreads, opPipelines.front());
754  }
755 
756  // Ensure an analysis manager has been constructed for each of the nodes.
757  // This prevents thread races when running the nested pipelines.
758  for (CallGraphNode *node : nodesToVisit)
759  getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
760 
761  // An atomic failure variable for the async executors.
762  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
763  std::fill(activePMs.begin(), activePMs.end(), false);
764  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
765  // Find a pass manager for this operation.
766  auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
767  bool expectedInactive = false;
768  return isActive.compare_exchange_strong(expectedInactive, true);
769  });
770  assert(it != activePMs.end() &&
771  "could not find inactive pass manager for thread");
772  unsigned pmIndex = it - activePMs.begin();
773 
774  // Optimize this callable node.
775  LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
776 
777  // Reset the active bit for this pass manager.
778  activePMs[pmIndex].store(false);
779  return result;
780  });
781 }
782 
783 LogicalResult
784 InlinerPass::optimizeCallable(CallGraphNode *node,
785  llvm::StringMap<OpPassManager> &pipelines) {
786  Operation *callable = node->getCallableRegion()->getParentOp();
787  StringRef opName = callable->getName().getStringRef();
788  auto pipelineIt = pipelines.find(opName);
789  if (pipelineIt == pipelines.end()) {
790  // If a pipeline didn't exist, use the default if possible.
791  if (!defaultPipeline)
792  return success();
793 
794  OpPassManager defaultPM(opName);
795  defaultPipeline(defaultPM);
796  pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
797  }
798  return runPipeline(pipelineIt->second, callable);
799 }
800 
801 LogicalResult InlinerPass::initializeOptions(StringRef options) {
802  if (failed(Pass::initializeOptions(options)))
803  return failure();
804 
805  // Initialize the default pipeline builder to use the option string.
806  // TODO: Use a generic pass manager for default pipelines, and remove this.
807  if (!defaultPipelineStr.empty()) {
808  std::string defaultPipelineCopy = defaultPipelineStr;
809  defaultPipeline = [=](OpPassManager &pm) {
810  (void)parsePassPipeline(defaultPipelineCopy, pm);
811  };
812  } else if (defaultPipelineStr.getNumOccurrences()) {
813  defaultPipeline = nullptr;
814  }
815 
816  // Initialize the op specific pass pipelines.
817  llvm::StringMap<OpPassManager> pipelines;
818  for (OpPassManager pipeline : opPipelineList)
819  if (!pipeline.empty())
820  pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
821  opPipelines.assign({std::move(pipelines)});
822 
823  return success();
824 }
825 
826 std::unique_ptr<Pass> mlir::createInlinerPass() {
827  return std::make_unique<InlinerPass>();
828 }
829 std::unique_ptr<Pass>
830 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
831  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
832  std::move(opPipelines));
833 }
834 std::unique_ptr<Pass> mlir::createInlinerPass(
835  llvm::StringMap<OpPassManager> opPipelines,
836  std::function<void(OpPassManager &)> defaultPipelineBuilder) {
837  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
838  std::move(opPipelines));
839 }
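// Example client usage (illustrative sketch; `context` and `pipelines` are
// assumed to be set up by the caller):
//
//   mlir::PassManager pm(&context);
//   // Default configuration: canonicalize callables between inlining rounds.
//   pm.addPass(mlir::createInlinerPass());
//   // Or supply custom per-op optimization pipelines:
//   llvm::StringMap<mlir::OpPassManager> pipelines;
//   pm.addPass(mlir::createInlinerPass(std::move(pipelines)));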