MLIR  21.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "mlir/IR/Block.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Pass/Pass.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/Support/Format.h"
19 #include "llvm/Support/GraphWriter.h"
20 #include <map>
21 #include <optional>
22 #include <utility>
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_VIEWOPGRAPH
26 #include "mlir/Transforms/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 
31 static const StringRef kLineStyleControlFlow = "dashed";
32 static const StringRef kLineStyleDataFlow = "solid";
33 static const StringRef kShapeNode = "Mrecord";
34 static const StringRef kShapeNone = "plain";
35 
36 /// Return the size limits for eliding large attributes.
37 static int64_t getLargeAttributeSizeLimit() {
38  // Use the default from the printer flags if possible.
39  if (std::optional<int64_t> limit =
40  OpPrintingFlags().getLargeElementsAttrLimit())
41  return *limit;
42  return 16;
43 }
44 
45 /// Return all values printed onto a stream as a string.
46 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
47  std::string buf;
48  llvm::raw_string_ostream os(buf);
49  func(os);
50  return buf;
51 }
52 
/// Wrap `str` in double quotes to form a DOT quoted string.
/// The contents are copied verbatim; no escaping is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
57 
/// Escape `str` for use inside a Graphviz record label.
///
/// Graphviz requires braces, vertical bars and angle brackets to be escaped
/// with a backslash to appear literally in a record label. Double quotes are
/// escaped as well because labels are emitted inside a quoted DOT string, and
/// so is '\n'. Every other character, including '\' itself, is copied as is.
///
/// NOTE(review): in a quoted DOT string a backslash immediately followed by a
/// newline is a line continuation, so an escaped '\n' probably vanishes from
/// the rendered label. Confirm that this is the intended behavior.
std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (char ch : str) {
    switch (ch) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      escaped.push_back('\\');
      break;
    default:
      break;
    }
    escaped.push_back(ch);
  }
  return escaped;
}
71 
/// Attributes of a DOT node or edge, keyed by attribute name. The map is
/// ordered by key, so attribute lists are emitted sorted by name.
using AttributeMap =
    std::map<std::string, std::string, std::less<std::string>>;
73 
74 namespace {
75 
/// A node of the emitted DOT graph, written as `v<id>`.
///
/// DOT only allows edges between nodes, not between clusters. Each cluster
/// therefore gets an invisible anchor node. The `Node` returned for a cluster
/// is that anchor with `clusterId` set, so edges attached to it can be clipped
/// at the cluster boundary via `ltail`/`lhead`.
struct Node {
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id{id}, clusterId{clusterId} {}

  /// Numeric suffix of the DOT node name (`v<id>`).
  int id;
  /// Id of the cluster this node anchors, if any.
  std::optional<int> clusterId;
};
91 
/// A data-flow edge whose emission is deferred until every node exists.
///
/// `value` is the SSA value carried by the edge. Its source node is looked up
/// in `valueToNode` when the edges are finally emitted. `node` is the node (or
/// cluster anchor) of the operation that uses `value` as an operand. `port` is
/// the port name computed by `getValuePortName(value)`, used to attach the
/// edge to the matching record fields.
///
/// Note: instances are built with aggregate initialization
/// (`{operand, node, portName}`), so the field order is significant.
struct DataFlowEdge {
  Value value;
  Node node;
  std::string port;
};
97 
98 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
99 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
100 /// about the Graphviz DOT language.
101 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
102 public:
103  PrintOpPass(raw_ostream &os) : os(os) {}
104  PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
105 
106  void runOnOperation() override {
107  initColorMapping(*getOperation());
108  emitGraph([&]() {
109  processOperation(getOperation());
110  emitAllEdgeStmts();
111  });
112  markAllAnalysesPreserved();
113  }
114 
115  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
116  void emitRegionCFG(Region &region) {
117  printControlFlowEdges = true;
118  printDataFlowEdges = false;
119  initColorMapping(region);
120  emitGraph([&]() { processRegion(region); });
121  }
122 
123 private:
124  /// Generate a color mapping that will color every operation with the same
125  /// name the same way. It'll interpolate the hue in the HSV color-space,
126  /// using muted colors that provide good contrast for black text.
127  template <typename T>
128  void initColorMapping(T &irEntity) {
129  backgroundColors.clear();
131  irEntity.walk([&](Operation *op) {
132  auto &entry = backgroundColors[op->getName()];
133  if (entry.first == 0)
134  ops.push_back(op);
135  ++entry.first;
136  });
137  for (auto indexedOps : llvm::enumerate(ops)) {
138  double hue = ((double)indexedOps.index()) / ops.size();
139  // Use lower saturation (0.3) and higher value (0.95) for better
140  // readability
141  backgroundColors[indexedOps.value()->getName()].second =
142  std::to_string(hue) + " 0.3 0.95";
143  }
144  }
145 
146  /// Emit all edges. This function should be called after all nodes have been
147  /// emitted.
148  void emitAllEdgeStmts() {
149  if (printDataFlowEdges) {
150  for (const auto &e : dataFlowEdges) {
151  emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
152  }
153  }
154 
155  for (const std::string &edge : edges)
156  os << edge << ";\n";
157  edges.clear();
158  }
159 
160  /// Emit a cluster (subgraph). The specified builder generates the body of the
161  /// cluster. Return the anchor node of the cluster.
162  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
163  int clusterId = ++counter;
164  os << "subgraph cluster_" << clusterId << " {\n";
165  os.indent();
166  // Emit invisible anchor node from/to which arrows can be drawn.
167  Node anchorNode = emitNodeStmt(" ", kShapeNone);
168  os << attrStmt("label", quoteString(label)) << ";\n";
169  builder();
170  os.unindent();
171  os << "}\n";
172  return Node(anchorNode.id, clusterId);
173  }
174 
175  /// Generate an attribute statement.
176  std::string attrStmt(const Twine &key, const Twine &value) {
177  return (key + " = " + value).str();
178  }
179 
180  /// Emit an attribute list.
181  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
182  os << "[";
183  interleaveComma(map, os, [&](const auto &it) {
184  os << this->attrStmt(it.first, it.second);
185  });
186  os << "]";
187  }
188 
189  // Print an MLIR attribute to `os`. Large attributes are truncated.
190  void emitMlirAttr(raw_ostream &os, Attribute attr) {
191  // A value used to elide large container attribute.
192  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
193 
194  // Always emit splat attributes.
195  if (isa<SplatElementsAttr>(attr)) {
196  os << escapeLabelString(
197  strFromOs([&](raw_ostream &os) { attr.print(os); }));
198  return;
199  }
200 
201  // Elide "big" elements attributes.
202  auto elements = dyn_cast<ElementsAttr>(attr);
203  if (elements && elements.getNumElements() > largeAttrLimit) {
204  os << std::string(elements.getShapedType().getRank(), '[') << "..."
205  << std::string(elements.getShapedType().getRank(), ']') << " : ";
206  emitMlirType(os, elements.getType());
207  return;
208  }
209 
210  auto array = dyn_cast<ArrayAttr>(attr);
211  if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
212  os << "[...]";
213  return;
214  }
215 
216  // Print all other attributes.
217  std::string buf;
218  llvm::raw_string_ostream ss(buf);
219  attr.print(ss);
220  os << escapeLabelString(truncateString(buf));
221  }
222 
223  // Print a truncated and escaped MLIR type to `os`.
224  void emitMlirType(raw_ostream &os, Type type) {
225  std::string buf;
226  llvm::raw_string_ostream ss(buf);
227  type.print(ss);
228  os << escapeLabelString(truncateString(buf));
229  }
230 
231  // Print a truncated and escaped MLIR operand to `os`.
232  void emitMlirOperand(raw_ostream &os, Value operand) {
233  operand.printAsOperand(os, OpPrintingFlags());
234  }
235 
236  /// Append an edge to the list of edges.
237  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
238  void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
239  AttributeMap attrs;
240  attrs["style"] = style.str();
241  // Use `ltail` and `lhead` to draw edges between clusters.
242  if (n1.clusterId)
243  attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
244  if (n2.clusterId)
245  attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
246 
247  edges.push_back(strFromOs([&](raw_ostream &os) {
248  os << "v" << n1.id;
249  if (!port.empty() && !n1.clusterId)
250  // Attach edge to south compass point of the result
251  os << ":res" << port << ":s";
252  os << " -> ";
253  os << "v" << n2.id;
254  if (!port.empty() && !n2.clusterId)
255  // Attach edge to north compass point of the operand
256  os << ":arg" << port << ":n";
257  emitAttrList(os, attrs);
258  }));
259  }
260 
261  /// Emit a graph. The specified builder generates the body of the graph.
262  void emitGraph(function_ref<void()> builder) {
263  os << "digraph G {\n";
264  os.indent();
265  // Edges between clusters are allowed only in compound mode.
266  os << attrStmt("compound", "true") << ";\n";
267  builder();
268  os.unindent();
269  os << "}\n";
270  }
271 
272  /// Emit a node statement.
273  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
274  StringRef background = "") {
275  int nodeId = ++counter;
276  AttributeMap attrs;
277  attrs["label"] = quoteString(label);
278  attrs["shape"] = shape.str();
279  if (!background.empty()) {
280  attrs["style"] = "filled";
281  attrs["fillcolor"] = quoteString(background.str());
282  }
283  os << llvm::format("v%i ", nodeId);
284  emitAttrList(os, attrs);
285  os << ";\n";
286  return Node(nodeId);
287  }
288 
289  std::string getValuePortName(Value operand) {
290  // Print value as an operand and omit the leading '%' character.
291  auto str = strFromOs([&](raw_ostream &os) {
292  operand.printAsOperand(os, OpPrintingFlags());
293  });
294  // Replace % and # with _
295  std::replace(str.begin(), str.end(), '%', '_');
296  std::replace(str.begin(), str.end(), '#', '_');
297  return str;
298  }
299 
300  std::string getClusterLabel(Operation *op) {
301  return strFromOs([&](raw_ostream &os) {
302  // Print operation name and type.
303  os << op->getName();
304  if (printResultTypes) {
305  os << " : (";
306  std::string buf;
307  llvm::raw_string_ostream ss(buf);
308  interleaveComma(op->getResultTypes(), ss);
309  os << truncateString(buf) << ")";
310  }
311 
312  // Print attributes.
313  if (printAttrs) {
314  os << "\\l";
315  for (const NamedAttribute &attr : op->getAttrs()) {
316  os << escapeLabelString(attr.getName().getValue().str()) << ": ";
317  emitMlirAttr(os, attr.getValue());
318  os << "\\l";
319  }
320  }
321  });
322  }
323 
324  /// Generate a label for an operation.
325  std::string getRecordLabel(Operation *op) {
326  return strFromOs([&](raw_ostream &os) {
327  os << "{";
328 
329  // Print operation inputs.
330  if (op->getNumOperands() > 0) {
331  os << "{";
332  auto operandToPort = [&](Value operand) {
333  os << "<arg" << getValuePortName(operand) << "> ";
334  emitMlirOperand(os, operand);
335  };
336  interleave(op->getOperands(), os, operandToPort, "|");
337  os << "}|";
338  }
339  // Print operation name and type.
340  os << op->getName() << "\\l";
341 
342  // Print attributes.
343  if (printAttrs && !op->getAttrs().empty()) {
344  // Extra line break to separate attributes from the operation name.
345  os << "\\l";
346  for (const NamedAttribute &attr : op->getAttrs()) {
347  os << attr.getName().getValue() << ": ";
348  emitMlirAttr(os, attr.getValue());
349  os << "\\l";
350  }
351  }
352 
353  if (op->getNumResults() > 0) {
354  os << "|{";
355  auto resultToPort = [&](Value result) {
356  os << "<res" << getValuePortName(result) << "> ";
357  emitMlirOperand(os, result);
358  if (printResultTypes) {
359  os << " ";
360  emitMlirType(os, result.getType());
361  }
362  };
363  interleave(op->getResults(), os, resultToPort, "|");
364  os << "}";
365  }
366 
367  os << "}";
368  });
369  }
370 
371  /// Generate a label for a block argument.
372  std::string getLabel(BlockArgument arg) {
373  return strFromOs([&](raw_ostream &os) {
374  os << "<res" << getValuePortName(arg) << "> ";
375  arg.printAsOperand(os, OpPrintingFlags());
376  if (printResultTypes) {
377  os << " ";
378  emitMlirType(os, arg.getType());
379  }
380  });
381  }
382 
383  /// Process a block. Emit a cluster and one node per block argument and
384  /// operation inside the cluster.
385  void processBlock(Block &block) {
386  emitClusterStmt([&]() {
387  for (BlockArgument &blockArg : block.getArguments())
388  valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
389  // Emit a node for each operation.
390  std::optional<Node> prevNode;
391  for (Operation &op : block) {
392  Node nextNode = processOperation(&op);
393  if (printControlFlowEdges && prevNode)
394  emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
395  prevNode = nextNode;
396  }
397  });
398  }
399 
400  /// Process an operation. If the operation has regions, emit a cluster.
401  /// Otherwise, emit a node.
402  Node processOperation(Operation *op) {
403  Node node;
404  if (op->getNumRegions() > 0) {
405  // Emit cluster for op with regions.
406  node = emitClusterStmt(
407  [&]() {
408  for (Region &region : op->getRegions())
409  processRegion(region);
410  },
411  getClusterLabel(op));
412  } else {
413  node = emitNodeStmt(getRecordLabel(op), kShapeNode,
414  backgroundColors[op->getName()].second);
415  }
416 
417  // Insert data flow edges originating from each operand.
418  if (printDataFlowEdges) {
419  unsigned numOperands = op->getNumOperands();
420  for (unsigned i = 0; i < numOperands; i++) {
421  auto operand = op->getOperand(i);
422  dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
423  }
424  }
425 
426  for (Value result : op->getResults())
427  valueToNode[result] = node;
428 
429  return node;
430  }
431 
432  /// Process a region.
433  void processRegion(Region &region) {
434  for (Block &block : region.getBlocks())
435  processBlock(block);
436  }
437 
438  /// Truncate long strings.
439  std::string truncateString(std::string str) {
440  if (str.length() <= maxLabelLen)
441  return str;
442  return str.substr(0, maxLabelLen) + "...";
443  }
444 
445  /// Output stream to write DOT file to.
447  /// A list of edges. For simplicity, should be emitted after all nodes were
448  /// emitted.
449  std::vector<std::string> edges;
450  /// Mapping of SSA values to Graphviz nodes/clusters.
451  DenseMap<Value, Node> valueToNode;
452  /// Output for data flow edges is delayed until the end to handle cycles
453  std::vector<DataFlowEdge> dataFlowEdges;
454  /// Counter for generating unique node/subgraph identifiers.
455  int counter = 0;
456 
458 };
459 
460 } // namespace
461 
462 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
463  return std::make_unique<PrintOpPass>(os);
464 }
465 
466 /// Generate a CFG for a region and show it in a window.
467 static void llvmViewGraph(Region &region, const Twine &name) {
468  int fd;
469  std::string filename = llvm::createGraphFilename(name.str(), fd);
470  {
471  llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
472  if (fd == -1) {
473  llvm::errs() << "error opening file '" << filename << "' for writing\n";
474  return;
475  }
476  PrintOpPass pass(os);
477  pass.emitRegionCFG(region);
478  }
479  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
480 }
481 
482 void mlir::Region::viewGraph(const Twine &regionName) {
483  llvmViewGraph(*this, regionName);
484 }
485 
486 void mlir::Region::viewGraph() { viewGraph("region"); }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:39
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
Definition: ViewOpGraph.cpp:54
static const StringRef kLineStyleDataFlow
Definition: ViewOpGraph.cpp:32
static const StringRef kLineStyleControlFlow
Definition: ViewOpGraph.cpp:31
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
Definition: ViewOpGraph.cpp:37
std::map< std::string, std::string > AttributeMap
Definition: ViewOpGraph.cpp:72
static const StringRef kShapeNone
Definition: ViewOpGraph.cpp:34
static const StringRef kShapeNode
Definition: ViewOpGraph.cpp:33
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
Definition: ViewOpGraph.cpp:61
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Definition: ViewOpGraph.cpp:46
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
void viewGraph()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indenting a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.