MLIR  18.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Transforms/ViewOpGraph.h"

#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_VIEWOPGRAPH
24 #include "mlir/Transforms/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 static const StringRef kLineStyleControlFlow = "dashed";
30 static const StringRef kLineStyleDataFlow = "solid";
31 static const StringRef kShapeNode = "ellipse";
32 static const StringRef kShapeNone = "plain";
33 
34 /// Return the size limits for eliding large attributes.
35 static int64_t getLargeAttributeSizeLimit() {
36  // Use the default from the printer flags if possible.
37  if (std::optional<int64_t> limit =
38  OpPrintingFlags().getLargeElementsAttrLimit())
39  return *limit;
40  return 16;
41 }
42 
43 /// Return all values printed onto a stream as a string.
44 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
45  std::string buf;
46  llvm::raw_string_ostream os(buf);
47  func(os);
48  return os.str();
49 }
50 
51 /// Escape special characters such as '\n' and quotation marks.
52 static std::string escapeString(std::string str) {
53  return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
54 }
55 
/// Wrap `str` in double quotes. No escaping is performed; callers escape the
/// content beforehand where needed.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
60 
/// A Graphviz attribute list: attribute name -> already-formatted value. The
/// ordered map makes the emitted attribute order deterministic (sorted by
/// name).
using AttributeMap =
    std::map<std::string, std::string, std::less<std::string>>;
62 
63 namespace {
64 
/// A node of the DOT graph.
/// In the DOT language edges connect nodes only, never clusters (subgraphs).
/// Every cluster therefore gets an invisible "anchor" node. Edges meant to
/// touch a cluster are drawn to/from that anchor and clipped to the cluster
/// boundary with the `lhead`/`ltail` attributes.
struct Node {
  /// Numeric identifier; the node is printed as `v<id>`.
  int id;
  /// Set only when this node is the anchor of a cluster; the cluster is
  /// printed as `cluster_<clusterId>`.
  std::optional<int> clusterId;

  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}
};
80 
81 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
82 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
83 /// about the Graphviz DOT language.
84 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
85 public:
86  PrintOpPass(raw_ostream &os) : os(os) {}
87  PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
88 
89  void runOnOperation() override {
90  initColorMapping(*getOperation());
91  emitGraph([&]() {
92  processOperation(getOperation());
93  emitAllEdgeStmts();
94  });
95  }
96 
97  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
98  void emitRegionCFG(Region &region) {
99  printControlFlowEdges = true;
100  printDataFlowEdges = false;
101  initColorMapping(region);
102  emitGraph([&]() { processRegion(region); });
103  }
104 
105 private:
106  /// Generate a color mapping that will color every operation with the same
107  /// name the same way. It'll interpolate the hue in the HSV color-space,
108  /// attempting to keep the contrast suitable for black text.
109  template <typename T>
110  void initColorMapping(T &irEntity) {
111  backgroundColors.clear();
113  irEntity.walk([&](Operation *op) {
114  auto &entry = backgroundColors[op->getName()];
115  if (entry.first == 0)
116  ops.push_back(op);
117  ++entry.first;
118  });
119  for (auto indexedOps : llvm::enumerate(ops)) {
120  double hue = ((double)indexedOps.index()) / ops.size();
121  backgroundColors[indexedOps.value()->getName()].second =
122  std::to_string(hue) + " 1.0 1.0";
123  }
124  }
125 
126  /// Emit all edges. This function should be called after all nodes have been
127  /// emitted.
128  void emitAllEdgeStmts() {
129  for (const std::string &edge : edges)
130  os << edge << ";\n";
131  edges.clear();
132  }
133 
134  /// Emit a cluster (subgraph). The specified builder generates the body of the
135  /// cluster. Return the anchor node of the cluster.
136  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
137  int clusterId = ++counter;
138  os << "subgraph cluster_" << clusterId << " {\n";
139  os.indent();
140  // Emit invisible anchor node from/to which arrows can be drawn.
141  Node anchorNode = emitNodeStmt(" ", kShapeNone);
142  os << attrStmt("label", quoteString(escapeString(std::move(label))))
143  << ";\n";
144  builder();
145  os.unindent();
146  os << "}\n";
147  return Node(anchorNode.id, clusterId);
148  }
149 
150  /// Generate an attribute statement.
151  std::string attrStmt(const Twine &key, const Twine &value) {
152  return (key + " = " + value).str();
153  }
154 
155  /// Emit an attribute list.
156  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
157  os << "[";
158  interleaveComma(map, os, [&](const auto &it) {
159  os << this->attrStmt(it.first, it.second);
160  });
161  os << "]";
162  }
163 
164  // Print an MLIR attribute to `os`. Large attributes are truncated.
165  void emitMlirAttr(raw_ostream &os, Attribute attr) {
166  // A value used to elide large container attribute.
167  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
168 
169  // Always emit splat attributes.
170  if (isa<SplatElementsAttr>(attr)) {
171  attr.print(os);
172  return;
173  }
174 
175  // Elide "big" elements attributes.
176  auto elements = dyn_cast<ElementsAttr>(attr);
177  if (elements && elements.getNumElements() > largeAttrLimit) {
178  os << std::string(elements.getShapedType().getRank(), '[') << "..."
179  << std::string(elements.getShapedType().getRank(), ']') << " : "
180  << elements.getType();
181  return;
182  }
183 
184  auto array = dyn_cast<ArrayAttr>(attr);
185  if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
186  os << "[...]";
187  return;
188  }
189 
190  // Print all other attributes.
191  std::string buf;
192  llvm::raw_string_ostream ss(buf);
193  attr.print(ss);
194  os << truncateString(ss.str());
195  }
196 
197  /// Append an edge to the list of edges.
198  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
199  void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
200  AttributeMap attrs;
201  attrs["style"] = style.str();
202  // Do not label edges that start/end at a cluster boundary. Such edges are
203  // clipped at the boundary, but labels are not. This can lead to labels
204  // floating around without any edge next to them.
205  if (!n1.clusterId && !n2.clusterId)
206  attrs["label"] = quoteString(escapeString(std::move(label)));
207  // Use `ltail` and `lhead` to draw edges between clusters.
208  if (n1.clusterId)
209  attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
210  if (n2.clusterId)
211  attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
212 
213  edges.push_back(strFromOs([&](raw_ostream &os) {
214  os << llvm::format("v%i -> v%i ", n1.id, n2.id);
215  emitAttrList(os, attrs);
216  }));
217  }
218 
219  /// Emit a graph. The specified builder generates the body of the graph.
220  void emitGraph(function_ref<void()> builder) {
221  os << "digraph G {\n";
222  os.indent();
223  // Edges between clusters are allowed only in compound mode.
224  os << attrStmt("compound", "true") << ";\n";
225  builder();
226  os.unindent();
227  os << "}\n";
228  }
229 
230  /// Emit a node statement.
231  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
232  StringRef background = "") {
233  int nodeId = ++counter;
234  AttributeMap attrs;
235  attrs["label"] = quoteString(escapeString(std::move(label)));
236  attrs["shape"] = shape.str();
237  if (!background.empty()) {
238  attrs["style"] = "filled";
239  attrs["fillcolor"] = ("\"" + background + "\"").str();
240  }
241  os << llvm::format("v%i ", nodeId);
242  emitAttrList(os, attrs);
243  os << ";\n";
244  return Node(nodeId);
245  }
246 
247  /// Generate a label for an operation.
248  std::string getLabel(Operation *op) {
249  return strFromOs([&](raw_ostream &os) {
250  // Print operation name and type.
251  os << op->getName();
252  if (printResultTypes) {
253  os << " : (";
254  std::string buf;
255  llvm::raw_string_ostream ss(buf);
256  interleaveComma(op->getResultTypes(), ss);
257  os << truncateString(ss.str()) << ")";
258  }
259 
260  // Print attributes.
261  if (printAttrs) {
262  os << "\n";
263  for (const NamedAttribute &attr : op->getAttrs()) {
264  os << '\n' << attr.getName().getValue() << ": ";
265  emitMlirAttr(os, attr.getValue());
266  }
267  }
268  });
269  }
270 
271  /// Generate a label for a block argument.
272  std::string getLabel(BlockArgument arg) {
273  return "arg" + std::to_string(arg.getArgNumber());
274  }
275 
276  /// Process a block. Emit a cluster and one node per block argument and
277  /// operation inside the cluster.
278  void processBlock(Block &block) {
279  emitClusterStmt([&]() {
280  for (BlockArgument &blockArg : block.getArguments())
281  valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
282 
283  // Emit a node for each operation.
284  std::optional<Node> prevNode;
285  for (Operation &op : block) {
286  Node nextNode = processOperation(&op);
287  if (printControlFlowEdges && prevNode)
288  emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
290  prevNode = nextNode;
291  }
292  });
293  }
294 
295  /// Process an operation. If the operation has regions, emit a cluster.
296  /// Otherwise, emit a node.
297  Node processOperation(Operation *op) {
298  Node node;
299  if (op->getNumRegions() > 0) {
300  // Emit cluster for op with regions.
301  node = emitClusterStmt(
302  [&]() {
303  for (Region &region : op->getRegions())
304  processRegion(region);
305  },
306  getLabel(op));
307  } else {
308  node = emitNodeStmt(getLabel(op), kShapeNode,
309  backgroundColors[op->getName()].second);
310  }
311 
312  // Insert data flow edges originating from each operand.
313  if (printDataFlowEdges) {
314  unsigned numOperands = op->getNumOperands();
315  for (unsigned i = 0; i < numOperands; i++)
316  emitEdgeStmt(valueToNode[op->getOperand(i)], node,
317  /*label=*/numOperands == 1 ? "" : std::to_string(i),
319  }
320 
321  for (Value result : op->getResults())
322  valueToNode[result] = node;
323 
324  return node;
325  }
326 
327  /// Process a region.
328  void processRegion(Region &region) {
329  for (Block &block : region.getBlocks())
330  processBlock(block);
331  }
332 
333  /// Truncate long strings.
334  std::string truncateString(std::string str) {
335  if (str.length() <= maxLabelLen)
336  return str;
337  return str.substr(0, maxLabelLen) + "...";
338  }
339 
340  /// Output stream to write DOT file to.
342  /// A list of edges. For simplicity, should be emitted after all nodes were
343  /// emitted.
344  std::vector<std::string> edges;
345  /// Mapping of SSA values to Graphviz nodes/clusters.
346  DenseMap<Value, Node> valueToNode;
347  /// Counter for generating unique node/subgraph identifiers.
348  int counter = 0;
349 
351 };
352 
353 } // namespace
354 
355 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
356  return std::make_unique<PrintOpPass>(os);
357 }
358 
359 /// Generate a CFG for a region and show it in a window.
360 static void llvmViewGraph(Region &region, const Twine &name) {
361  int fd;
362  std::string filename = llvm::createGraphFilename(name.str(), fd);
363  {
364  llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
365  if (fd == -1) {
366  llvm::errs() << "error opening file '" << filename << "' for writing\n";
367  return;
368  }
369  PrintOpPass pass(os);
370  pass.emitRegionCFG(region);
371  }
372  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
373 }
374 
375 void mlir::Region::viewGraph(const Twine &regionName) {
376  llvmViewGraph(*this, regionName);
377 }
378 
379 void mlir::Region::viewGraph() { viewGraph("region"); }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
Definition: ViewOpGraph.cpp:57
static const StringRef kLineStyleDataFlow
Definition: ViewOpGraph.cpp:30
static const StringRef kLineStyleControlFlow
Definition: ViewOpGraph.cpp:29
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
Definition: ViewOpGraph.cpp:35
std::map< std::string, std::string > AttributeMap
Definition: ViewOpGraph.cpp:61
static const StringRef kShapeNone
Definition: ViewOpGraph.cpp:32
static const StringRef kShapeNode
Definition: ViewOpGraph.cpp:31
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string escapeString(std::string str)
Escape special characters such as '\n' and quotation marks.
Definition: ViewOpGraph.cpp:52
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Definition: ViewOpGraph.cpp:44
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Definition: Value.h:315
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:80
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
result_range getResults()
Definition: Operation.h:410
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
void viewGraph()
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
raw_ostream subclass that simplifies indenting a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.