MLIR  16.0.0git
CallGraph.cpp
Go to the documentation of this file.
1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/PointerUnion.h"
18 #include "llvm/ADT/SCCIterator.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // CallGraphNode
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns true if this node refers to the indirect/external node.
28 bool CallGraphNode::isExternal() const { return !callableRegion; }
29 
30 /// Return the callable region this node represents. This can only be called
31 /// on non-external nodes.
33  assert(!isExternal() && "the external node has no callable region");
34  return callableRegion;
35 }
36 
37 /// Adds an reference edge to the given node. This is only valid on the
38 /// external node.
40  assert(isExternal() && "abstract edges are only valid on external nodes");
41  addEdge(node, Edge::Kind::Abstract);
42 }
43 
44 /// Add an outgoing call edge from this node.
46  addEdge(node, Edge::Kind::Call);
47 }
48 
49 /// Adds a reference edge to the given child node.
51  addEdge(child, Edge::Kind::Child);
52 }
53 
54 /// Returns true if this node has any child edges.
56  return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
57 }
58 
59 /// Add an edge to 'node' with the given kind.
60 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
61  edges.insert({node, kind});
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // CallGraph
66 //===----------------------------------------------------------------------===//
67 
68 /// Recursively compute the callgraph edges for the given operation. Computed
69 /// edges are placed into the given callgraph object.
70 static void computeCallGraph(Operation *op, CallGraph &cg,
71  SymbolTableCollection &symbolTable,
72  CallGraphNode *parentNode, bool resolveCalls) {
73  if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
74  // If there is no parent node, we ignore this operation. Even if this
75  // operation was a call, there would be no callgraph node to attribute it
76  // to.
77  if (resolveCalls && parentNode)
78  parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
79  return;
80  }
81 
82  // Compute the callgraph nodes and edges for each of the nested operations.
83  if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
84  if (auto *callableRegion = callable.getCallableRegion())
85  parentNode = cg.getOrAddNode(callableRegion, parentNode);
86  else
87  return;
88  }
89 
90  for (Region &region : op->getRegions())
91  for (Operation &nested : region.getOps())
92  computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
93 }
94 
96  : externalCallerNode(/*callableRegion=*/nullptr),
97  unknownCalleeNode(/*callableRegion=*/nullptr) {
98  // Make two passes over the graph, one to compute the callables and one to
99  // resolve the calls. We split these up as we may have nested callable objects
100  // that need to be reserved before the calls.
101  SymbolTableCollection symbolTable;
102  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
103  /*resolveCalls=*/false);
104  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
105  /*resolveCalls=*/true);
106 }
107 
108 /// Get or add a call graph node for the given region.
110  CallGraphNode *parentNode) {
111  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
112  "expected parent operation to be callable");
113  std::unique_ptr<CallGraphNode> &node = nodes[region];
114  if (!node) {
115  node.reset(new CallGraphNode(region));
116 
117  // Add this node to the given parent node if necessary.
118  if (parentNode) {
119  parentNode->addChildEdge(node.get());
120  } else {
121  // Otherwise, connect all callable nodes to the external node, this allows
122  // for conservatively including all callable nodes within the graph.
123  // FIXME This isn't correct, this is only necessary for callable nodes
124  // that *could* be called from external sources. This requires extending
125  // the interface for callables to check if they may be referenced
126  // externally.
127  externalCallerNode.addAbstractEdge(node.get());
128  }
129  }
130  return node.get();
131 }
132 
133 /// Lookup a call graph node for the given region, or nullptr if none is
134 /// registered.
136  auto it = nodes.find(region);
137  return it == nodes.end() ? nullptr : it->second.get();
138 }
139 
140 /// Resolve the callable for given callee to a node in the callgraph, or the
141 /// unknown callee node if a valid node was not resolved.
143 CallGraph::resolveCallable(CallOpInterface call,
144  SymbolTableCollection &symbolTable) const {
145  Operation *callable = call.resolveCallable(&symbolTable);
146  if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
147  if (auto *node = lookupNode(callableOp.getCallableRegion()))
148  return node;
149 
150  return getUnknownCalleeNode();
151 }
152 
153 /// Erase the given node from the callgraph.
155  // Erase any children of this node first.
156  if (node->hasChildren()) {
157  for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
158  if (edge.isChild())
159  eraseNode(edge.getTarget());
160  }
161  // Erase any edges to this node from any other nodes.
162  for (auto &it : nodes) {
163  it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
164  return edge.getTarget() == node;
165  });
166  }
167  nodes.erase(node->getCallableRegion());
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // Printing
172 
173 /// Dump the graph in a human readable format.
174 void CallGraph::dump() const { print(llvm::errs()); }
175 void CallGraph::print(raw_ostream &os) const {
176  os << "// ---- CallGraph ----\n";
177 
178  // Functor used to output the name for the given node.
179  auto emitNodeName = [&](const CallGraphNode *node) {
180  if (node == getExternalCallerNode()) {
181  os << "<External-Caller-Node>";
182  return;
183  }
184  if (node == getUnknownCalleeNode()) {
185  os << "<Unknown-Callee-Node>";
186  return;
187  }
188 
189  auto *callableRegion = node->getCallableRegion();
190  auto *parentOp = callableRegion->getParentOp();
191  os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
192  << callableRegion->getRegionNumber();
193  auto attrs = parentOp->getAttrDictionary();
194  if (!attrs.empty())
195  os << " : " << attrs;
196  };
197 
198  for (auto &nodeIt : nodes) {
199  const CallGraphNode *node = nodeIt.second.get();
200 
201  // Dump the header for this node.
202  os << "// - Node : ";
203  emitNodeName(node);
204  os << "\n";
205 
206  // Emit each of the edges.
207  for (auto &edge : *node) {
208  os << "// -- ";
209  if (edge.isCall())
210  os << "Call";
211  else if (edge.isChild())
212  os << "Child";
213 
214  os << "-Edge : ";
215  emitNodeName(edge.getTarget());
216  os << "\n";
217  }
218  os << "//\n";
219  }
220 
221  os << "// -- SCCs --\n";
222 
223  for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
224  os << "// - SCC : \n";
225  for (auto &node : scc) {
226  os << "// -- Node :";
227  emitNodeName(node);
228  os << "\n";
229  }
230  os << "\n";
231  }
232 
233  os << "// -------------------\n";
234 }
static void computeCallGraph(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, CallGraphNode *parentNode, bool resolveCalls)
Recursively compute the callgraph edges for the given operation.
Definition: CallGraph.cpp:70
This class represents a directed edge between two nodes in the callgraph.
Definition: CallGraph.h:43
bool isChild() const
Returns true if this edge represents a Child edge.
Definition: CallGraph.h:70
CallGraphNode * getTarget() const
Returns the target node for this edge.
Definition: CallGraph.h:73
This class represents a single callable in the callgraph.
Definition: CallGraph.h:40
bool isExternal() const
Returns true if this node is an external node.
Definition: CallGraph.cpp:28
void addAbstractEdge(CallGraphNode *node)
Adds an abstract reference edge to the given node.
Definition: CallGraph.cpp:39
void addChildEdge(CallGraphNode *child)
Adds a reference edge to the given child node.
Definition: CallGraph.cpp:50
bool hasChildren() const
Returns true if this node has any child edges.
Definition: CallGraph.cpp:55
void addCallEdge(CallGraphNode *node)
Add an outgoing call edge from this node.
Definition: CallGraph.cpp:45
Region * getCallableRegion() const
Returns the callable region this node represents.
Definition: CallGraph.cpp:32
void eraseNode(CallGraphNode *node)
Erase the given node from the callgraph.
Definition: CallGraph.cpp:154
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the unknown callee node if a valid node was not resolved.
Definition: CallGraph.cpp:143
CallGraphNode * getUnknownCalleeNode() const
Return the callgraph node representing an indirect callee.
Definition: CallGraph.h:193
CallGraph(Operation *op)
Definition: CallGraph.cpp:95
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
Definition: CallGraph.cpp:135
void dump() const
Dump the graph in a human readable format.
Definition: CallGraph.cpp:174
CallGraphNode * getOrAddNode(Region *region, CallGraphNode *parentNode)
Get or add a call graph node for the given region.
Definition: CallGraph.cpp:109
void print(raw_ostream &os) const
Definition: CallGraph.cpp:175
CallGraphNode * getExternalCallerNode() const
Return the callgraph node representing an external caller.
Definition: CallGraph.h:188
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:245
Kind
Tensor expression kind.
Definition: Merger.h:25
Include the generated interface declarations.