MLIR  16.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Transforms/ViewOpGraph.h"

#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <utility>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_VIEWOPGRAPH
23 #include "mlir/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 static const StringRef kLineStyleControlFlow = "dashed";
29 static const StringRef kLineStyleDataFlow = "solid";
30 static const StringRef kShapeNode = "ellipse";
31 static const StringRef kShapeNone = "plain";
32 
33 /// Return the size limits for eliding large attributes.
34 static int64_t getLargeAttributeSizeLimit() {
35  // Use the default from the printer flags if possible.
36  if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
37  return *limit;
38  return 16;
39 }
40 
41 /// Return all values printed onto a stream as a string.
42 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
43  std::string buf;
44  llvm::raw_string_ostream os(buf);
45  func(os);
46  return os.str();
47 }
48 
49 /// Escape special characters such as '\n' and quotation marks.
50 static std::string escapeString(std::string str) {
51  return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
52 }
53 
/// Wrap `str` in double quotation marks. The contents are not escaped; callers
/// are expected to pass an already-escaped string.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
58 
59 using AttributeMap = llvm::StringMap<std::string>;
60 
61 namespace {
62 
63 /// This struct represents a node in the DOT language. Each node has an
64 /// identifier and an optional identifier for the cluster (subgraph) that
65 /// contains the node.
66 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
67 /// not between clusters. However, edges can be clipped to the boundary of a
68 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
69 /// cluster, an invisible "anchor" node is created.
70 struct Node {
71 public:
72  Node(int id = 0, Optional<int> clusterId = llvm::None)
73  : id(id), clusterId(clusterId) {}
74 
75  int id;
76  Optional<int> clusterId;
77 };
78 
79 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
80 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
81 /// about the Graphviz DOT language.
82 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
83 public:
84  PrintOpPass(raw_ostream &os) : os(os) {}
85  PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
86 
87  void runOnOperation() override {
88  emitGraph([&]() {
89  processOperation(getOperation());
90  emitAllEdgeStmts();
91  });
92  }
93 
94  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
95  void emitRegionCFG(Region &region) {
96  printControlFlowEdges = true;
97  printDataFlowEdges = false;
98  emitGraph([&]() { processRegion(region); });
99  }
100 
101 private:
102  /// Emit all edges. This function should be called after all nodes have been
103  /// emitted.
104  void emitAllEdgeStmts() {
105  for (const std::string &edge : edges)
106  os << edge << ";\n";
107  edges.clear();
108  }
109 
110  /// Emit a cluster (subgraph). The specified builder generates the body of the
111  /// cluster. Return the anchor node of the cluster.
112  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
113  int clusterId = ++counter;
114  os << "subgraph cluster_" << clusterId << " {\n";
115  os.indent();
116  // Emit invisible anchor node from/to which arrows can be drawn.
117  Node anchorNode = emitNodeStmt(" ", kShapeNone);
118  os << attrStmt("label", quoteString(escapeString(std::move(label))))
119  << ";\n";
120  builder();
121  os.unindent();
122  os << "}\n";
123  return Node(anchorNode.id, clusterId);
124  }
125 
126  /// Generate an attribute statement.
127  std::string attrStmt(const Twine &key, const Twine &value) {
128  return (key + " = " + value).str();
129  }
130 
131  /// Emit an attribute list.
132  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
133  os << "[";
134  interleaveComma(map, os, [&](const auto &it) {
135  os << this->attrStmt(it.getKey(), it.getValue());
136  });
137  os << "]";
138  }
139 
140  // Print an MLIR attribute to `os`. Large attributes are truncated.
141  void emitMlirAttr(raw_ostream &os, Attribute attr) {
142  // A value used to elide large container attribute.
143  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
144 
145  // Always emit splat attributes.
146  if (attr.isa<SplatElementsAttr>()) {
147  attr.print(os);
148  return;
149  }
150 
151  // Elide "big" elements attributes.
152  auto elements = attr.dyn_cast<ElementsAttr>();
153  if (elements && elements.getNumElements() > largeAttrLimit) {
154  os << std::string(elements.getType().getRank(), '[') << "..."
155  << std::string(elements.getType().getRank(), ']') << " : "
156  << elements.getType();
157  return;
158  }
159 
160  auto array = attr.dyn_cast<ArrayAttr>();
161  if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
162  os << "[...]";
163  return;
164  }
165 
166  // Print all other attributes.
167  std::string buf;
168  llvm::raw_string_ostream ss(buf);
169  attr.print(ss);
170  os << truncateString(ss.str());
171  }
172 
173  /// Append an edge to the list of edges.
174  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
175  void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
176  AttributeMap attrs;
177  attrs["style"] = style.str();
178  // Do not label edges that start/end at a cluster boundary. Such edges are
179  // clipped at the boundary, but labels are not. This can lead to labels
180  // floating around without any edge next to them.
181  if (!n1.clusterId && !n2.clusterId)
182  attrs["label"] = quoteString(escapeString(std::move(label)));
183  // Use `ltail` and `lhead` to draw edges between clusters.
184  if (n1.clusterId)
185  attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
186  if (n2.clusterId)
187  attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
188 
189  edges.push_back(strFromOs([&](raw_ostream &os) {
190  os << llvm::format("v%i -> v%i ", n1.id, n2.id);
191  emitAttrList(os, attrs);
192  }));
193  }
194 
195  /// Emit a graph. The specified builder generates the body of the graph.
196  void emitGraph(function_ref<void()> builder) {
197  os << "digraph G {\n";
198  os.indent();
199  // Edges between clusters are allowed only in compound mode.
200  os << attrStmt("compound", "true") << ";\n";
201  builder();
202  os.unindent();
203  os << "}\n";
204  }
205 
206  /// Emit a node statement.
207  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
208  int nodeId = ++counter;
209  AttributeMap attrs;
210  attrs["label"] = quoteString(escapeString(std::move(label)));
211  attrs["shape"] = shape.str();
212  os << llvm::format("v%i ", nodeId);
213  emitAttrList(os, attrs);
214  os << ";\n";
215  return Node(nodeId);
216  }
217 
218  /// Generate a label for an operation.
219  std::string getLabel(Operation *op) {
220  return strFromOs([&](raw_ostream &os) {
221  // Print operation name and type.
222  os << op->getName();
223  if (printResultTypes) {
224  os << " : (";
225  std::string buf;
226  llvm::raw_string_ostream ss(buf);
227  interleaveComma(op->getResultTypes(), ss);
228  os << truncateString(ss.str()) << ")";
229  os << ")";
230  }
231 
232  // Print attributes.
233  if (printAttrs) {
234  os << "\n";
235  for (const NamedAttribute &attr : op->getAttrs()) {
236  os << '\n' << attr.getName().getValue() << ": ";
237  emitMlirAttr(os, attr.getValue());
238  }
239  }
240  });
241  }
242 
243  /// Generate a label for a block argument.
244  std::string getLabel(BlockArgument arg) {
245  return "arg" + std::to_string(arg.getArgNumber());
246  }
247 
248  /// Process a block. Emit a cluster and one node per block argument and
249  /// operation inside the cluster.
250  void processBlock(Block &block) {
251  emitClusterStmt([&]() {
252  for (BlockArgument &blockArg : block.getArguments())
253  valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
254 
255  // Emit a node for each operation.
256  Optional<Node> prevNode;
257  for (Operation &op : block) {
258  Node nextNode = processOperation(&op);
259  if (printControlFlowEdges && prevNode)
260  emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
262  prevNode = nextNode;
263  }
264  });
265  }
266 
267  /// Process an operation. If the operation has regions, emit a cluster.
268  /// Otherwise, emit a node.
269  Node processOperation(Operation *op) {
270  Node node;
271  if (op->getNumRegions() > 0) {
272  // Emit cluster for op with regions.
273  node = emitClusterStmt(
274  [&]() {
275  for (Region &region : op->getRegions())
276  processRegion(region);
277  },
278  getLabel(op));
279  } else {
280  node = emitNodeStmt(getLabel(op));
281  }
282 
283  // Insert data flow edges originating from each operand.
284  if (printDataFlowEdges) {
285  unsigned numOperands = op->getNumOperands();
286  for (unsigned i = 0; i < numOperands; i++)
287  emitEdgeStmt(valueToNode[op->getOperand(i)], node,
288  /*label=*/numOperands == 1 ? "" : std::to_string(i),
290  }
291 
292  for (Value result : op->getResults())
293  valueToNode[result] = node;
294 
295  return node;
296  }
297 
298  /// Process a region.
299  void processRegion(Region &region) {
300  for (Block &block : region.getBlocks())
301  processBlock(block);
302  }
303 
304  /// Truncate long strings.
305  std::string truncateString(std::string str) {
306  if (str.length() <= maxLabelLen)
307  return str;
308  return str.substr(0, maxLabelLen) + "...";
309  }
310 
311  /// Output stream to write DOT file to.
313  /// A list of edges. For simplicity, should be emitted after all nodes were
314  /// emitted.
315  std::vector<std::string> edges;
316  /// Mapping of SSA values to Graphviz nodes/clusters.
317  DenseMap<Value, Node> valueToNode;
318  /// Counter for generating unique node/subgraph identifiers.
319  int counter = 0;
320 };
321 
322 } // namespace
323 
324 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
325  return std::make_unique<PrintOpPass>(os);
326 }
327 
328 /// Generate a CFG for a region and show it in a window.
329 static void llvmViewGraph(Region &region, const Twine &name) {
330  int fd;
331  std::string filename = llvm::createGraphFilename(name.str(), fd);
332  {
333  llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
334  if (fd == -1) {
335  llvm::errs() << "error opening file '" << filename << "' for writing\n";
336  return;
337  }
338  PrintOpPass pass(os);
339  pass.emitRegionCFG(region);
340  }
341  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
342 }
343 
344 void mlir::Region::viewGraph(const Twine &regionName) {
345  llvmViewGraph(*this, regionName);
346 }
347 
348 void mlir::Region::viewGraph() { viewGraph("region"); }
static constexpr const bool value
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
Definition: ViewOpGraph.cpp:55
llvm::StringMap< std::string > AttributeMap
Definition: ViewOpGraph.cpp:59
static const StringRef kLineStyleDataFlow
Definition: ViewOpGraph.cpp:29
static const StringRef kLineStyleControlFlow
Definition: ViewOpGraph.cpp:28
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
Definition: ViewOpGraph.cpp:34
static const StringRef kShapeNone
Definition: ViewOpGraph.cpp:31
static const StringRef kShapeNode
Definition: ViewOpGraph.cpp:30
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string escapeString(std::string str)
Escape special characters such as '\n' and quotation marks.
Definition: ViewOpGraph.cpp:50
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Definition: ViewOpGraph.cpp:42
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
bool isa() const
Casting utility functions.
Definition: Attributes.h:117
This class represents an argument of a Block.
Definition: Value.h:296
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:308
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::Print).
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
unsigned getNumOperands()
Definition: Operation.h:263
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
result_type_range getResultTypes()
Definition: Operation.h:345
result_range getResults()
Definition: Operation.h:332
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
void viewGraph()
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the elements have the same value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
raw_ostream subclass that simplifies indenting a sequence of code.
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.