MLIR  22.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Transforms/ViewOpGraph.h"

#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
#include <optional>
#include <utility>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 static const StringRef kLineStyleControlFlow = "dashed";
31 static const StringRef kLineStyleDataFlow = "solid";
32 static const StringRef kShapeNode = "Mrecord";
33 static const StringRef kShapeNone = "plain";
34 
35 /// Return the size limits for eliding large attributes.
36 static int64_t getLargeAttributeSizeLimit() {
37  // Use the default from the printer flags if possible.
38  if (std::optional<int64_t> limit =
39  OpPrintingFlags().getLargeElementsAttrLimit())
40  return *limit;
41  return 16;
42 }
43 
44 /// Return all values printed onto a stream as a string.
45 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46  std::string buf;
47  llvm::raw_string_ostream os(buf);
48  func(os);
49  return buf;
50 }
51 
/// Wraps `str` in double quotes. The content is copied verbatim; no escaping
/// is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
56 
/// Escapes a string for use inside a Graphviz record label. From the Graphviz
/// documentation on record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
/// In addition to those, a backslash is also placed in front of double quotes
/// and newline characters. All other characters are copied unchanged.
std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      escaped.push_back('\\');
      break;
    default:
      break;
    }
    escaped.push_back(ch);
  }
  return escaped;
}
70 
/// Attribute list of a DOT statement: attribute name -> attribute value.
/// An ordered map, so attributes are emitted sorted by name.
using AttributeMap = std::map<std::string, std::string>;
72 
73 namespace {
74 
/// A node in the DOT language. Every node carries an identifier and,
/// optionally, the identifier of the cluster (subgraph) it stands for.
/// Note: DOT can only draw edges between nodes, not between clusters. Edges
/// can, however, be clipped to a cluster boundary through the `lhead` and
/// `ltail` attributes. For that reason each new cluster gets an invisible
/// "anchor" node.
struct Node {
  /// Identifier of the node.
  int id;
  /// Identifier of the cluster this node anchors, if any.
  std::optional<int> clusterId;

  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id{id}, clusterId{clusterId} {}
};
90 
91 struct DataFlowEdge {
92  Value value;
93  Node node;
94  std::string port;
95 };
96 
97 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
98 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
99 /// about the Graphviz DOT language.
100 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
101 public:
102  PrintOpPass(raw_ostream &os) : os(os) {}
103  PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
104 
105  void runOnOperation() override {
106  initColorMapping(*getOperation());
107  emitGraph([&]() {
108  processOperation(getOperation());
109  emitAllEdgeStmts();
110  });
111  markAllAnalysesPreserved();
112  }
113 
114  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
115  void emitRegionCFG(Region &region) {
116  printControlFlowEdges = true;
117  printDataFlowEdges = false;
118  initColorMapping(region);
119  emitGraph([&]() { processRegion(region); });
120  }
121 
122 private:
123  /// Generate a color mapping that will color every operation with the same
124  /// name the same way. It'll interpolate the hue in the HSV color-space,
125  /// using muted colors that provide good contrast for black text.
126  template <typename T>
127  void initColorMapping(T &irEntity) {
128  backgroundColors.clear();
130  irEntity.walk([&](Operation *op) {
131  auto &entry = backgroundColors[op->getName()];
132  if (entry.first == 0)
133  ops.push_back(op);
134  ++entry.first;
135  });
136  for (auto indexedOps : llvm::enumerate(ops)) {
137  double hue = ((double)indexedOps.index()) / ops.size();
138  // Use lower saturation (0.3) and higher value (0.95) for better
139  // readability
140  backgroundColors[indexedOps.value()->getName()].second =
141  std::to_string(hue) + " 0.3 0.95";
142  }
143  }
144 
145  /// Emit all edges. This function should be called after all nodes have been
146  /// emitted.
147  void emitAllEdgeStmts() {
148  if (printDataFlowEdges) {
149  for (const auto &e : dataFlowEdges) {
150  emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
151  }
152  }
153 
154  for (const std::string &edge : edges)
155  os << edge << ";\n";
156  edges.clear();
157  }
158 
159  /// Emit a cluster (subgraph). The specified builder generates the body of the
160  /// cluster. Return the anchor node of the cluster.
161  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
162  int clusterId = ++counter;
163  os << "subgraph cluster_" << clusterId << " {\n";
164  os.indent();
165  // Emit invisible anchor node from/to which arrows can be drawn.
166  Node anchorNode = emitNodeStmt(" ", kShapeNone);
167  os << attrStmt("label", quoteString(label)) << ";\n";
168  builder();
169  os.unindent();
170  os << "}\n";
171  return Node(anchorNode.id, clusterId);
172  }
173 
174  /// Generate an attribute statement.
175  std::string attrStmt(const Twine &key, const Twine &value) {
176  return (key + " = " + value).str();
177  }
178 
179  /// Emit an attribute list.
180  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
181  os << "[";
182  interleaveComma(map, os, [&](const auto &it) {
183  os << this->attrStmt(it.first, it.second);
184  });
185  os << "]";
186  }
187 
188  // Print an MLIR attribute to `os`. Large attributes are truncated.
189  void emitMlirAttr(raw_ostream &os, Attribute attr) {
190  // A value used to elide large container attribute.
191  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
192 
193  // Always emit splat attributes.
194  if (isa<SplatElementsAttr>(attr)) {
195  os << escapeLabelString(
196  strFromOs([&](raw_ostream &os) { attr.print(os); }));
197  return;
198  }
199 
200  // Elide "big" elements attributes.
201  auto elements = dyn_cast<ElementsAttr>(attr);
202  if (elements && elements.getNumElements() > largeAttrLimit) {
203  os << std::string(elements.getShapedType().getRank(), '[') << "..."
204  << std::string(elements.getShapedType().getRank(), ']') << " : ";
205  emitMlirType(os, elements.getType());
206  return;
207  }
208 
209  auto array = dyn_cast<ArrayAttr>(attr);
210  if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
211  os << "[...]";
212  return;
213  }
214 
215  // Print all other attributes.
216  std::string buf;
217  llvm::raw_string_ostream ss(buf);
218  attr.print(ss);
219  os << escapeLabelString(truncateString(buf));
220  }
221 
222  // Print a truncated and escaped MLIR type to `os`.
223  void emitMlirType(raw_ostream &os, Type type) {
224  std::string buf;
225  llvm::raw_string_ostream ss(buf);
226  type.print(ss);
227  os << escapeLabelString(truncateString(buf));
228  }
229 
230  // Print a truncated and escaped MLIR operand to `os`.
231  void emitMlirOperand(raw_ostream &os, Value operand) {
232  operand.printAsOperand(os, OpPrintingFlags());
233  }
234 
235  /// Append an edge to the list of edges.
236  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
237  void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
238  AttributeMap attrs;
239  attrs["style"] = style.str();
240  // Use `ltail` and `lhead` to draw edges between clusters.
241  if (n1.clusterId)
242  attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
243  if (n2.clusterId)
244  attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
245 
246  edges.push_back(strFromOs([&](raw_ostream &os) {
247  os << "v" << n1.id;
248  if (!port.empty() && !n1.clusterId)
249  // Attach edge to south compass point of the result
250  os << ":res" << port << ":s";
251  os << " -> ";
252  os << "v" << n2.id;
253  if (!port.empty() && !n2.clusterId)
254  // Attach edge to north compass point of the operand
255  os << ":arg" << port << ":n";
256  emitAttrList(os, attrs);
257  }));
258  }
259 
260  /// Emit a graph. The specified builder generates the body of the graph.
261  void emitGraph(function_ref<void()> builder) {
262  os << "digraph G {\n";
263  os.indent();
264  // Edges between clusters are allowed only in compound mode.
265  os << attrStmt("compound", "true") << ";\n";
266  builder();
267  os.unindent();
268  os << "}\n";
269  }
270 
271  /// Emit a node statement.
272  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
273  StringRef background = "") {
274  int nodeId = ++counter;
275  AttributeMap attrs;
276  attrs["label"] = quoteString(label);
277  attrs["shape"] = shape.str();
278  if (!background.empty()) {
279  attrs["style"] = "filled";
280  attrs["fillcolor"] = quoteString(background.str());
281  }
282  os << llvm::format("v%i ", nodeId);
283  emitAttrList(os, attrs);
284  os << ";\n";
285  return Node(nodeId);
286  }
287 
288  std::string getValuePortName(Value operand) {
289  // Print value as an operand and omit the leading '%' character.
290  auto str = strFromOs([&](raw_ostream &os) {
291  operand.printAsOperand(os, OpPrintingFlags());
292  });
293  // Replace % and # with _
294  llvm::replace(str, '%', '_');
295  llvm::replace(str, '#', '_');
296  return str;
297  }
298 
299  std::string getClusterLabel(Operation *op) {
300  return strFromOs([&](raw_ostream &os) {
301  // Print operation name and type.
302  os << op->getName();
303  if (printResultTypes) {
304  os << " : (";
305  std::string buf;
306  llvm::raw_string_ostream ss(buf);
307  interleaveComma(op->getResultTypes(), ss);
308  os << truncateString(buf) << ")";
309  }
310 
311  // Print attributes.
312  if (printAttrs) {
313  os << "\\l";
314  for (const NamedAttribute &attr : op->getAttrs()) {
315  os << escapeLabelString(attr.getName().getValue().str()) << ": ";
316  emitMlirAttr(os, attr.getValue());
317  os << "\\l";
318  }
319  }
320  });
321  }
322 
323  /// Generate a label for an operation.
324  std::string getRecordLabel(Operation *op) {
325  return strFromOs([&](raw_ostream &os) {
326  os << "{";
327 
328  // Print operation inputs.
329  if (op->getNumOperands() > 0) {
330  os << "{";
331  auto operandToPort = [&](Value operand) {
332  os << "<arg" << getValuePortName(operand) << "> ";
333  emitMlirOperand(os, operand);
334  };
335  interleave(op->getOperands(), os, operandToPort, "|");
336  os << "}|";
337  }
338  // Print operation name and type.
339  os << op->getName() << "\\l";
340 
341  // Print attributes.
342  if (printAttrs && !op->getAttrs().empty()) {
343  // Extra line break to separate attributes from the operation name.
344  os << "\\l";
345  for (const NamedAttribute &attr : op->getAttrs()) {
346  os << attr.getName().getValue() << ": ";
347  emitMlirAttr(os, attr.getValue());
348  os << "\\l";
349  }
350  }
351 
352  if (op->getNumResults() > 0) {
353  os << "|{";
354  auto resultToPort = [&](Value result) {
355  os << "<res" << getValuePortName(result) << "> ";
356  emitMlirOperand(os, result);
357  if (printResultTypes) {
358  os << " ";
359  emitMlirType(os, result.getType());
360  }
361  };
362  interleave(op->getResults(), os, resultToPort, "|");
363  os << "}";
364  }
365 
366  os << "}";
367  });
368  }
369 
370  /// Generate a label for a block argument.
371  std::string getLabel(BlockArgument arg) {
372  return strFromOs([&](raw_ostream &os) {
373  os << "<res" << getValuePortName(arg) << "> ";
374  arg.printAsOperand(os, OpPrintingFlags());
375  if (printResultTypes) {
376  os << " ";
377  emitMlirType(os, arg.getType());
378  }
379  });
380  }
381 
382  /// Process a block. Emit a cluster and one node per block argument and
383  /// operation inside the cluster.
384  void processBlock(Block &block) {
385  emitClusterStmt([&]() {
386  for (BlockArgument &blockArg : block.getArguments())
387  valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
388  // Emit a node for each operation.
389  std::optional<Node> prevNode;
390  for (Operation &op : block) {
391  Node nextNode = processOperation(&op);
392  if (printControlFlowEdges && prevNode)
393  emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
394  prevNode = nextNode;
395  }
396  });
397  }
398 
399  /// Process an operation. If the operation has regions, emit a cluster.
400  /// Otherwise, emit a node.
401  Node processOperation(Operation *op) {
402  Node node;
403  if (op->getNumRegions() > 0) {
404  // Emit cluster for op with regions.
405  node = emitClusterStmt(
406  [&]() {
407  for (Region &region : op->getRegions())
408  processRegion(region);
409  },
410  getClusterLabel(op));
411  } else {
412  node = emitNodeStmt(getRecordLabel(op), kShapeNode,
413  backgroundColors[op->getName()].second);
414  }
415 
416  // Insert data flow edges originating from each operand.
417  if (printDataFlowEdges) {
418  unsigned numOperands = op->getNumOperands();
419  for (unsigned i = 0; i < numOperands; i++) {
420  auto operand = op->getOperand(i);
421  dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
422  }
423  }
424 
425  for (Value result : op->getResults())
426  valueToNode[result] = node;
427 
428  return node;
429  }
430 
431  /// Process a region.
432  void processRegion(Region &region) {
433  for (Block &block : region.getBlocks())
434  processBlock(block);
435  }
436 
437  /// Truncate long strings.
438  std::string truncateString(std::string str) {
439  if (str.length() <= maxLabelLen)
440  return str;
441  return str.substr(0, maxLabelLen) + "...";
442  }
443 
444  /// Output stream to write DOT file to.
446  /// A list of edges. For simplicity, should be emitted after all nodes were
447  /// emitted.
448  std::vector<std::string> edges;
449  /// Mapping of SSA values to Graphviz nodes/clusters.
450  DenseMap<Value, Node> valueToNode;
451  /// Output for data flow edges is delayed until the end to handle cycles
452  std::vector<DataFlowEdge> dataFlowEdges;
453  /// Counter for generating unique node/subgraph identifiers.
454  int counter = 0;
455 
457 };
458 
459 } // namespace
460 
461 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
462  return std::make_unique<PrintOpPass>(os);
463 }
464 
465 /// Generate a CFG for a region and show it in a window.
466 static void llvmViewGraph(Region &region, const Twine &name) {
467  int fd;
468  std::string filename = llvm::createGraphFilename(name.str(), fd);
469  {
470  llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
471  if (fd == -1) {
472  llvm::errs() << "error opening file '" << filename << "' for writing\n";
473  return;
474  }
475  PrintOpPass pass(os);
476  pass.emitRegionCFG(region);
477  }
478  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
479 }
480 
481 void mlir::Region::viewGraph(const Twine &regionName) {
482  llvmViewGraph(*this, regionName);
483 }
484 
485 void mlir::Region::viewGraph() { viewGraph("region"); }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:37
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
Definition: ViewOpGraph.cpp:53
static const StringRef kLineStyleDataFlow
Definition: ViewOpGraph.cpp:31
static const StringRef kLineStyleControlFlow
Definition: ViewOpGraph.cpp:30
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
Definition: ViewOpGraph.cpp:36
std::map< std::string, std::string > AttributeMap
Definition: ViewOpGraph.cpp:71
static const StringRef kShapeNone
Definition: ViewOpGraph.cpp:33
static const StringRef kShapeNode
Definition: ViewOpGraph.cpp:32
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
Definition: ViewOpGraph.cpp:60
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Definition: ViewOpGraph.cpp:45
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
void viewGraph()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indenting a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.