MLIR  14.0.0git
Inliner.cpp
Go to the documentation of this file.
1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a basic inlining algorithm that operates bottom up over
10 // the Strongly Connected Components (SCCs) of the CallGraph. This enables a
11 // more incremental propagation of inlining decisions from the leaves to the
12 // roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "PassDetail.h"
18 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/PassManager.h"
22 #include "mlir/Transforms/Passes.h"
23 #include "llvm/ADT/SCCIterator.h"
24 #include "llvm/Support/Debug.h"
25 
26 #define DEBUG_TYPE "inlining"
27 
28 using namespace mlir;
29 
30 /// This function implements the default inliner optimization pipeline.
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // Symbol Use Tracking
37 //===----------------------------------------------------------------------===//
38 
39 /// Walk all of the used symbol callgraph nodes referenced with the given op.
41  Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
43  function_ref<void(CallGraphNode *, Operation *)> callback) {
44  auto symbolUses = SymbolTable::getSymbolUses(op);
45  assert(symbolUses && "expected uses to be valid");
46 
47  Operation *symbolTableOp = op->getParentOp();
48  for (const SymbolTable::SymbolUse &use : *symbolUses) {
49  auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
50  CallGraphNode *&node = refIt.first->second;
51 
52  // If this is the first instance of this reference, try to resolve a
53  // callgraph node for it.
54  if (refIt.second) {
55  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
56  use.getSymbolRef());
57  auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
58  if (!callableOp)
59  continue;
60  node = cg.lookupNode(callableOp.getCallableRegion());
61  }
62  if (node)
63  callback(node, use.getUser());
64  }
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // CGUseList
69 
70 namespace {
71 /// This struct tracks the uses of callgraph nodes that can be dropped when
72 /// use_empty. It directly tracks and manages a use-list for all of the
73 /// call-graph nodes. This is necessary because many callgraph nodes are
74 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
75 /// class.
76 struct CGUseList {
77  /// This struct tracks the uses of callgraph nodes within a specific
78  /// operation.
79  struct CGUser {
80  /// Any nodes referenced in the top-level attribute list of this user. We
81  /// use a set here because the number of references does not matter.
82  DenseSet<CallGraphNode *> topLevelUses;
83 
84  /// Uses of nodes referenced by nested operations.
86  };
87 
88  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
89 
90  /// Drop uses of nodes referred to by the given call operation that resides
91  /// within 'userNode'.
92  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
93 
94  /// Remove the given node from the use list.
95  void eraseNode(CallGraphNode *node);
96 
97  /// Returns true if the given callgraph node has no uses and can be pruned.
98  bool isDead(CallGraphNode *node) const;
99 
100  /// Returns true if the given callgraph node has a single use and can be
101  /// discarded.
102  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
103 
104  /// Recompute the uses held by the given callgraph node.
105  void recomputeUses(CallGraphNode *node, CallGraph &cg);
106 
107  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
108  /// of 'lhs' into 'rhs'.
109  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
110 
111 private:
112  /// Decrement the uses of discardable nodes referenced by the given user.
113  void decrementDiscardableUses(CGUser &uses);
114 
115  /// A mapping between a discardable callgraph node (that is a symbol) and the
116  /// number of uses for this node.
117  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
118 
119  /// A mapping between a callgraph node and the symbol callgraph nodes that it
120  /// uses.
122 
123  /// A symbol table to use when resolving call lookups.
124  SymbolTableCollection &symbolTable;
125 };
126 } // namespace
127 
128 CGUseList::CGUseList(Operation *op, CallGraph &cg,
129  SymbolTableCollection &symbolTable)
130  : symbolTable(symbolTable) {
131  /// A set of callgraph nodes that are always known to be live during inlining.
132  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
133 
134  // Walk each of the symbol tables looking for discardable callgraph nodes.
135  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
136  for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
137  // If this is a callgraph operation, check to see if it is discardable.
138  if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
139  if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
140  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
141  if (symbol && (allUsesVisible || symbol.isPrivate()) &&
142  symbol.canDiscardOnUseEmpty()) {
143  discardableSymNodeUses.try_emplace(node, 0);
144  }
145  continue;
146  }
147  }
148  // Otherwise, check for any referenced nodes. These will be always-live.
149  walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
150  [](CallGraphNode *, Operation *) {});
151  }
152  };
153  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
154  walkFn);
155 
156  // Drop the use information for any discardable nodes that are always live.
157  for (auto &it : alwaysLiveNodes)
158  discardableSymNodeUses.erase(it.second);
159 
160  // Compute the uses for each of the callable nodes in the graph.
161  for (CallGraphNode *node : cg)
162  recomputeUses(node, cg);
163 }
164 
165 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
166  CallGraph &cg) {
167  auto &userRefs = nodeUses[userNode].innerUses;
168  auto walkFn = [&](CallGraphNode *node, Operation *user) {
169  auto parentIt = userRefs.find(node);
170  if (parentIt == userRefs.end())
171  return;
172  --parentIt->second;
173  --discardableSymNodeUses[node];
174  };
176  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
177 }
178 
179 void CGUseList::eraseNode(CallGraphNode *node) {
180  // Drop all child nodes.
181  for (auto &edge : *node)
182  if (edge.isChild())
183  eraseNode(edge.getTarget());
184 
185  // Drop the uses held by this node and erase it.
186  auto useIt = nodeUses.find(node);
187  assert(useIt != nodeUses.end() && "expected node to be valid");
188  decrementDiscardableUses(useIt->getSecond());
189  nodeUses.erase(useIt);
190  discardableSymNodeUses.erase(node);
191 }
192 
193 bool CGUseList::isDead(CallGraphNode *node) const {
194  // If the parent operation isn't a symbol, simply check normal SSA deadness.
195  Operation *nodeOp = node->getCallableRegion()->getParentOp();
196  if (!isa<SymbolOpInterface>(nodeOp))
197  return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
198 
199  // Otherwise, check the number of symbol uses.
200  auto symbolIt = discardableSymNodeUses.find(node);
201  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
202 }
203 
204 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
205  // If this isn't a symbol node, check for side-effects and SSA use count.
206  Operation *nodeOp = node->getCallableRegion()->getParentOp();
207  if (!isa<SymbolOpInterface>(nodeOp))
208  return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
209 
210  // Otherwise, check the number of symbol uses.
211  auto symbolIt = discardableSymNodeUses.find(node);
212  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
213 }
214 
215 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
216  Operation *parentOp = node->getCallableRegion()->getParentOp();
217  CGUser &uses = nodeUses[node];
218  decrementDiscardableUses(uses);
219 
220  // Collect the new discardable uses within this node.
221  uses = CGUser();
223  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
224  auto discardSymIt = discardableSymNodeUses.find(refNode);
225  if (discardSymIt == discardableSymNodeUses.end())
226  return;
227 
228  if (user != parentOp)
229  ++uses.innerUses[refNode];
230  else if (!uses.topLevelUses.insert(refNode).second)
231  return;
232  ++discardSymIt->second;
233  };
234  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
235 }
236 
237 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
238  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
239  for (auto &useIt : lhsUses.innerUses) {
240  rhsUses.innerUses[useIt.first] += useIt.second;
241  discardableSymNodeUses[useIt.first] += useIt.second;
242  }
243 }
244 
245 void CGUseList::decrementDiscardableUses(CGUser &uses) {
246  for (CallGraphNode *node : uses.topLevelUses)
247  --discardableSymNodeUses[node];
248  for (auto &it : uses.innerUses)
249  discardableSymNodeUses[it.first] -= it.second;
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // CallGraph traversal
254 //===----------------------------------------------------------------------===//
255 
256 namespace {
257 /// This class represents a specific callgraph SCC.
258 class CallGraphSCC {
259 public:
260  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
261  : parentIterator(parentIterator) {}
262  /// Return a range over the nodes within this SCC.
263  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
264  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
265 
266  /// Reset the nodes of this SCC with those provided.
267  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
268 
269  /// Remove the given node from this SCC.
270  void remove(CallGraphNode *node) {
271  auto it = llvm::find(nodes, node);
272  if (it != nodes.end()) {
273  nodes.erase(it);
274  parentIterator.ReplaceNode(node, nullptr);
275  }
276  }
277 
278 private:
279  std::vector<CallGraphNode *> nodes;
280  llvm::scc_iterator<const CallGraph *> &parentIterator;
281 };
282 } // namespace
283 
284 /// Run a given transformation over the SCCs of the callgraph in a bottom up
285 /// traversal.
287  const CallGraph &cg,
288  function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
289  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
290  CallGraphSCC currentSCC(cgi);
291  while (!cgi.isAtEnd()) {
292  // Copy the current SCC and increment so that the transformer can modify the
293  // SCC without invalidating our iterator.
294  currentSCC.reset(*cgi);
295  ++cgi;
296  if (failed(sccTransformer(currentSCC)))
297  return failure();
298  }
299  return success();
300 }
301 
302 namespace {
303 /// This struct represents a resolved call to a given callgraph node. Given that
304 /// the call does not actually contain a direct reference to the
305 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
306 /// explicitly.
307 struct ResolvedCall {
308  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
309  CallGraphNode *targetNode)
310  : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
311  CallOpInterface call;
312  CallGraphNode *sourceNode, *targetNode;
313 };
314 } // namespace
315 
316 /// Collect all of the callable operations within the given range of blocks. If
317 /// `traverseNestedCGNodes` is true, this will also collect call operations
318 /// inside of nested callgraph nodes.
320  CallGraphNode *sourceNode, CallGraph &cg,
321  SymbolTableCollection &symbolTable,
323  bool traverseNestedCGNodes) {
325  auto addToWorklist = [&](CallGraphNode *node,
327  for (Block &block : blocks)
328  worklist.emplace_back(&block, node);
329  };
330 
331  addToWorklist(sourceNode, blocks);
332  while (!worklist.empty()) {
333  Block *block;
334  std::tie(block, sourceNode) = worklist.pop_back_val();
335 
336  for (Operation &op : *block) {
337  if (auto call = dyn_cast<CallOpInterface>(op)) {
338  // TODO: Support inlining nested call references.
339  CallInterfaceCallable callable = call.getCallableForCallee();
340  if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
341  if (!symRef.isa<FlatSymbolRefAttr>())
342  continue;
343  }
344 
345  CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
346  if (!targetNode->isExternal())
347  calls.emplace_back(call, sourceNode, targetNode);
348  continue;
349  }
350 
351  // If this is not a call, traverse the nested regions. If
352  // `traverseNestedCGNodes` is false, then don't traverse nested call graph
353  // regions.
354  for (auto &nestedRegion : op.getRegions()) {
355  CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
356  if (traverseNestedCGNodes || !nestedNode)
357  addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
358  }
359  }
360  }
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // Inliner
365 //===----------------------------------------------------------------------===//
366 namespace {
367 /// This class provides a specialization of the main inlining interface.
368 struct Inliner : public InlinerInterface {
369  Inliner(MLIRContext *context, CallGraph &cg,
370  SymbolTableCollection &symbolTable)
371  : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
372 
373  /// Process a set of blocks that have been inlined. This callback is invoked
374  /// *before* inlined terminator operations have been processed.
375  void
376  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
377  // Find the closest callgraph node from the first block.
378  CallGraphNode *node;
379  Region *region = inlinedBlocks.begin()->getParent();
380  while (!(node = cg.lookupNode(region))) {
381  region = region->getParentRegion();
382  assert(region && "expected valid parent node");
383  }
384 
385  collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
386  /*traverseNestedCGNodes=*/true);
387  }
388 
389  /// Mark the given callgraph node for deletion.
390  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
391 
392  /// This method properly disposes of callables that became dead during
393  /// inlining. This should not be called while iterating over the SCCs.
394  void eraseDeadCallables() {
395  for (CallGraphNode *node : deadNodes)
396  node->getCallableRegion()->getParentOp()->erase();
397  }
398 
399  /// The set of callables known to be dead.
401 
402  /// The current set of call instructions to consider for inlining.
404 
405  /// The callgraph being operated on.
406  CallGraph &cg;
407 
408  /// A symbol table to use when resolving call lookups.
409  SymbolTableCollection &symbolTable;
410 };
411 } // namespace
412 
413 /// Returns true if the given call should be inlined.
414 static bool shouldInline(ResolvedCall &resolvedCall) {
415  // Don't allow inlining terminator calls. We currently don't support this
416  // case.
417  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
418  return false;
419 
420  // Don't allow inlining if the target is an ancestor of the call. This
421  // prevents inlining recursively.
422  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
423  resolvedCall.call->getParentRegion()))
424  return false;
425 
426  // Otherwise, inline.
427  return true;
428 }
429 
430 /// Attempt to inline calls within the given scc. This function returns
431 /// success if any calls were inlined, failure otherwise.
432 static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
433  CallGraphSCC &currentSCC) {
434  CallGraph &cg = inliner.cg;
435  auto &calls = inliner.calls;
436 
437  // A set of dead nodes to remove after inlining.
438  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
439 
440  // Collect all of the direct calls within the nodes of the current SCC. We
441  // don't traverse nested callgraph nodes, because they are handled separately
442  // likely within a different SCC.
443  for (CallGraphNode *node : currentSCC) {
444  if (node->isExternal())
445  continue;
446 
447  // Don't collect calls if the node is already dead.
448  if (useList.isDead(node)) {
449  deadNodes.insert(node);
450  } else {
451  collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
452  calls, /*traverseNestedCGNodes=*/false);
453  }
454  }
455 
456  // Try to inline each of the call operations. Don't cache the end iterator
457  // here as more calls may be added during inlining.
458  bool inlinedAnyCalls = false;
459  for (unsigned i = 0; i != calls.size(); ++i) {
460  if (deadNodes.contains(calls[i].sourceNode))
461  continue;
462  ResolvedCall it = calls[i];
463  bool doInline = shouldInline(it);
464  CallOpInterface call = it.call;
465  LLVM_DEBUG({
466  if (doInline)
467  llvm::dbgs() << "* Inlining call: " << call << "\n";
468  else
469  llvm::dbgs() << "* Not inlining call: " << call << "\n";
470  });
471  if (!doInline)
472  continue;
473  Region *targetRegion = it.targetNode->getCallableRegion();
474 
475  // If this is the last call to the target node and the node is discardable,
476  // then inline it in-place and delete the node if successful.
477  bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
478 
479  LogicalResult inlineResult = inlineCall(
480  inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
481  targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
482  if (failed(inlineResult)) {
483  LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
484  continue;
485  }
486  inlinedAnyCalls = true;
487 
488  // If the inlining was successful, Merge the new uses into the source node.
489  useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
490  useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
491 
492  // then erase the call.
493  call.erase();
494 
495  // If we inlined in place, mark the node for deletion.
496  if (inlineInPlace) {
497  useList.eraseNode(it.targetNode);
498  deadNodes.insert(it.targetNode);
499  }
500  }
501 
502  for (CallGraphNode *node : deadNodes) {
503  currentSCC.remove(node);
504  inliner.markForDeletion(node);
505  }
506  calls.clear();
507  return success(inlinedAnyCalls);
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // InlinerPass
512 //===----------------------------------------------------------------------===//
513 
514 namespace {
515 class InlinerPass : public InlinerBase<InlinerPass> {
516 public:
517  InlinerPass();
518  InlinerPass(const InlinerPass &) = default;
519  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
520  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
521  llvm::StringMap<OpPassManager> opPipelines);
522  void runOnOperation() override;
523 
524 private:
525  /// Attempt to inline calls within the given scc, and run simplifications,
526  /// until a fixed point is reached. This allows for the inlining of newly
527  /// devirtualized calls. Returns failure if there was a fatal error during
528  /// inlining.
529  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
530  CallGraphSCC &currentSCC, MLIRContext *context);
531 
532  /// Optimize the nodes within the given SCC with one of the held optimization
533  /// pass pipelines. Returns failure if an error occurred during the
534  /// optimization of the SCC, success otherwise.
535  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
536  CallGraphSCC &currentSCC, MLIRContext *context);
537 
538  /// Optimize the nodes within the given SCC in parallel. Returns failure if an
539  /// error occurred during the optimization of the SCC, success otherwise.
540  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
541  MLIRContext *context);
542 
543  /// Optimize the given callable node with one of the pass managers provided
544  /// with `pipelines`, or the default pipeline. Returns failure if an error
545  /// occurred during the optimization of the callable, success otherwise.
546  LogicalResult optimizeCallable(CallGraphNode *node,
547  llvm::StringMap<OpPassManager> &pipelines);
548 
549  /// Attempt to initialize the options of this pass from the given string.
550  /// Derived classes may override this method to hook into the point at which
551  /// options are initialized, but should generally always invoke this base
552  /// class variant.
553  LogicalResult initializeOptions(StringRef options) override;
554 
555  /// An optional function that constructs a default optimization pipeline for
556  /// a given operation.
557  std::function<void(OpPassManager &)> defaultPipeline;
558  /// A map of operation names to pass pipelines to use when optimizing
559  /// callable operations of these types. This provides a specialized pipeline
560  /// instead of the default. The vector size is the number of threads used
561  /// during optimization.
563 };
564 } // namespace
565 
566 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
567 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
568  : defaultPipeline(std::move(defaultPipeline)) {
569  opPipelines.push_back({});
570 
571  // Initialize the pass options with the provided arguments.
572  if (defaultPipeline) {
573  OpPassManager fakePM("__mlir_fake_pm_op");
574  defaultPipeline(fakePM);
575  llvm::raw_string_ostream strStream(defaultPipelineStr);
576  fakePM.printAsTextualPipeline(strStream);
577  }
578 }
579 
580 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
581  llvm::StringMap<OpPassManager> opPipelines)
582  : InlinerPass(std::move(defaultPipeline)) {
583  if (opPipelines.empty())
584  return;
585 
586  // Update the option for the op specific optimization pipelines.
587  for (auto &it : opPipelines) {
588  std::string pipeline;
589  llvm::raw_string_ostream pipelineOS(pipeline);
590  pipelineOS << it.getKey() << "(";
591  it.second.printAsTextualPipeline(pipelineOS);
592  pipelineOS << ")";
593  opPipelineStrs.addValue(pipeline);
594  }
595  this->opPipelines.emplace_back(std::move(opPipelines));
596 }
597 
598 void InlinerPass::runOnOperation() {
599  CallGraph &cg = getAnalysis<CallGraph>();
600  auto *context = &getContext();
601 
602  // The inliner should only be run on operations that define a symbol table,
603  // as the callgraph will need to resolve references.
604  Operation *op = getOperation();
605  if (!op->hasTrait<OpTrait::SymbolTable>()) {
606  op->emitOpError() << " was scheduled to run under the inliner, but does "
607  "not define a symbol table";
608  return signalPassFailure();
609  }
610 
611  // Run the inline transform in post-order over the SCCs in the callgraph.
612  SymbolTableCollection symbolTable;
613  Inliner inliner(context, cg, symbolTable);
614  CGUseList useList(getOperation(), cg, symbolTable);
615  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
616  return inlineSCC(inliner, useList, scc, context);
617  });
618  if (failed(result))
619  return signalPassFailure();
620 
621  // After inlining, make sure to erase any callables proven to be dead.
622  inliner.eraseDeadCallables();
623 }
624 
625 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
626  CallGraphSCC &currentSCC,
627  MLIRContext *context) {
628  // Continuously simplify and inline until we either reach a fixed point, or
629  // hit the maximum iteration count. Simplifying early helps to refine the cost
630  // model, and in future iterations may devirtualize new calls.
631  unsigned iterationCount = 0;
632  do {
633  if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
634  return failure();
635  if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
636  break;
637  } while (++iterationCount < maxInliningIterations);
638  return success();
639 }
640 
641 LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
642  CallGraphSCC &currentSCC,
643  MLIRContext *context) {
644  // Collect the sets of nodes to simplify.
645  SmallVector<CallGraphNode *, 4> nodesToVisit;
646  for (auto *node : currentSCC) {
647  if (node->isExternal())
648  continue;
649 
650  // Don't simplify nodes with children. Nodes with children require special
651  // handling as we may remove the node during simplification. In the future,
652  // we should be able to handle this case with proper node deletion tracking.
653  if (node->hasChildren())
654  continue;
655 
656  // We also won't apply simplifications to nodes that can't have passes
657  // scheduled on them.
658  auto *region = node->getCallableRegion();
659  if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
660  continue;
661  nodesToVisit.push_back(node);
662  }
663  if (nodesToVisit.empty())
664  return success();
665 
666  // Optimize each of the nodes within the SCC in parallel.
667  if (failed(optimizeSCCAsync(nodesToVisit, context)))
668  return failure();
669 
670  // Recompute the uses held by each of the nodes.
671  for (CallGraphNode *node : nodesToVisit)
672  useList.recomputeUses(node, cg);
673  return success();
674 }
675 
677 InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
678  MLIRContext *ctx) {
679  // We must maintain a fixed pool of pass managers which is at least as large
680  // as the maximum parallelism of the failableParallelForEach below.
681  // Note: The number of pass managers here needs to remain constant
682  // to prevent issues with pass instrumentations that rely on having the same
683  // pass manager for the main thread.
684  size_t numThreads = ctx->getNumThreads();
685  if (opPipelines.size() < numThreads) {
686  // Reserve before resizing so that we can use a reference to the first
687  // element.
688  opPipelines.reserve(numThreads);
689  opPipelines.resize(numThreads, opPipelines.front());
690  }
691 
692  // Ensure an analysis manager has been constructed for each of the nodes.
693  // This prevents thread races when running the nested pipelines.
694  for (CallGraphNode *node : nodesToVisit)
695  getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
696 
697  // An atomic failure variable for the async executors.
698  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
699  std::fill(activePMs.begin(), activePMs.end(), false);
700  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
701  // Find a pass manager for this operation.
702  auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
703  bool expectedInactive = false;
704  return isActive.compare_exchange_strong(expectedInactive, true);
705  });
706  assert(it != activePMs.end() &&
707  "could not find inactive pass manager for thread");
708  unsigned pmIndex = it - activePMs.begin();
709 
710  // Optimize this callable node.
711  LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
712 
713  // Reset the active bit for this pass manager.
714  activePMs[pmIndex].store(false);
715  return result;
716  });
717 }
718 
720 InlinerPass::optimizeCallable(CallGraphNode *node,
721  llvm::StringMap<OpPassManager> &pipelines) {
722  Operation *callable = node->getCallableRegion()->getParentOp();
723  StringRef opName = callable->getName().getStringRef();
724  auto pipelineIt = pipelines.find(opName);
725  if (pipelineIt == pipelines.end()) {
726  // If a pipeline didn't exist, use the default if possible.
727  if (!defaultPipeline)
728  return success();
729 
730  OpPassManager defaultPM(opName);
731  defaultPipeline(defaultPM);
732  pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
733  }
734  return runPipeline(pipelineIt->second, callable);
735 }
736 
737 LogicalResult InlinerPass::initializeOptions(StringRef options) {
738  if (failed(Pass::initializeOptions(options)))
739  return failure();
740 
741  // Initialize the default pipeline builder to use the option string.
742  if (!defaultPipelineStr.empty()) {
743  std::string defaultPipelineCopy = defaultPipelineStr;
744  defaultPipeline = [=](OpPassManager &pm) {
745  (void)parsePassPipeline(defaultPipelineCopy, pm);
746  };
747  } else if (defaultPipelineStr.getNumOccurrences()) {
748  defaultPipeline = nullptr;
749  }
750 
751  // Initialize the op specific pass pipelines.
752  llvm::StringMap<OpPassManager> pipelines;
753  for (StringRef pipeline : opPipelineStrs) {
754  // Skip empty pipelines.
755  if (pipeline.empty())
756  continue;
758  if (failed(pm))
759  return failure();
760  pipelines.try_emplace(pm->getOpName(), std::move(*pm));
761  }
762  opPipelines.assign({std::move(pipelines)});
763 
764  return success();
765 }
766 
767 std::unique_ptr<Pass> mlir::createInlinerPass() {
768  return std::make_unique<InlinerPass>();
769 }
770 std::unique_ptr<Pass>
771 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
772  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
773  std::move(opPipelines));
774 }
775 std::unique_ptr<Pass> mlir::createInlinerPass(
776  llvm::StringMap<OpPassManager> opPipelines,
777  std::function<void(OpPassManager &)> defaultPipelineBuilder) {
778  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
779  std::move(opPipelines));
780 }
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
Definition: CallGraph.cpp:133
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Definition: Inliner.cpp:286
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:423
Block represents an ordered list of Operations.
Definition: Block.h:29
static void defaultInlinerOptPipeline(OpPassManager &pm)
This function implements the default inliner optimization pipeline.
Definition: Inliner.cpp:31
A symbol reference with a reference path containing a single element.
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:567
static Optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:706
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode *> &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
Definition: Inliner.cpp:40
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
Definition: CallGraph.cpp:141
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
std::unique_ptr< Pass > createInlinerPass()
Creates a pass which inlines calls and callable operations as defined by the CallGraph.
Definition: Inliner.cpp:767
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition: Threading.h:36
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
iterator_range< OpIterator > getOps()
Definition: Region.h:172
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:77
virtual LogicalResult initializeOptions(StringRef options)
Attempt to initialize the options of this pass from the given string.
Definition: Pass.cpp:42
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:570
This interface provides the hooks into the inlining interface.
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Definition: Inliner.cpp:319
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a single callable in the callgraph.
Definition: CallGraph.h:40
bool isExternal() const
Returns true if this node is an external node.
Definition: CallGraph.cpp:28
This class provides the API for ops that are known to be isolated from above.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void printAsTextualPipeline(raw_ostream &os)
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:296
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:266
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC)
Attempt to inline calls within the given scc.
Definition: Inliner.cpp:432
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:518
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
This class represents a specific symbol use.
Definition: SymbolTable.h:144
Region * getCallableRegion() const
Returns the callable region this node represents.
Definition: CallGraph.cpp:32
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:429
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overrid...
This class represents a pass manager that runs passes on a specific operation type.
Definition: PassManager.h:51
bool hasChildren() const
Returns true if this node has any child edges.
Definition: CallGraph.cpp:55
static bool shouldInline(ResolvedCall &resolvedCall)
Returns true if the given call should be inlined.
Definition: Inliner.cpp:414