MLIR  22.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
19 #include <map>
20 #include <optional>
21 #include <utility>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 static const StringRef kLineStyleControlFlow = "dashed";
31 static const StringRef kLineStyleDataFlow = "solid";
32 static const StringRef kShapeNode = "Mrecord";
33 static const StringRef kShapeNone = "plain";
34 
35 /// Return the size limits for eliding large attributes.
36 static int64_t getLargeAttributeSizeLimit() {
37  // Use the default from the printer flags if possible.
38  if (std::optional<int64_t> limit =
39  OpPrintingFlags().getLargeElementsAttrLimit())
40  return *limit;
41  return 16;
42 }
43 
44 /// Return all values printed onto a stream as a string.
45 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46  std::string buf;
47  llvm::raw_string_ostream os(buf);
48  func(os);
49  return buf;
50 }
51 
/// Wrap `str` in double quotes. The contents are copied verbatim; no
/// escaping is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
56 
/// Escape a string for use inside a quoted Graphviz record label.
///
/// For Graphviz record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
///
/// In addition to those characters, double quotes and newlines are prefixed
/// with a backslash, and so is the backslash itself: otherwise a backslash in
/// the input would be read by Graphviz as the start of an escape sequence
/// instead of being shown literally.
std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (char c : str) {
    switch (c) {
    case '\\':
    case '{':
    case '}':
    case '|':
    case '<':
    case '>':
    case '\n':
    case '"':
      escaped += '\\';
      break;
    default:
      break;
    }
    escaped += c;
  }
  return escaped;
}
70 
71 using AttributeMap = std::map<std::string, std::string>;
72 
73 namespace {
74 
/// A node in the DOT language, identified by `id`. When the node stands for a
/// cluster (subgraph), `clusterId` holds the identifier of that cluster.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created.
struct Node {
  /// Both arguments are optional; a default-constructed node has id 0 and no
  /// cluster.
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}

  /// Identifier of the node.
  int id;
  /// Identifier of the associated cluster, if any.
  std::optional<int> clusterId;
};
90 
/// A data-flow edge whose emission is deferred until all nodes are known.
/// Kept as a plain aggregate because it is created with brace initialization.
struct DataFlowEdge {
  /// The SSA value the edge originates from; it is resolved to a node through
  /// the value-to-node mapping when the edge is finally emitted.
  Value value;
  /// The node of the operation that uses `value` as an operand.
  Node node;
  /// Port name derived from `value`, used to attach the edge to record fields.
  std::string port;
};
96 
97 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
98 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
99 /// about the Graphviz DOT language.
100 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
101 public:
102  PrintOpPass(raw_ostream &os) : os(os) {}
103  PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
104 
105  void runOnOperation() override {
106  initColorMapping(*getOperation());
107  emitGraph([&]() {
108  processOperation(getOperation());
109  emitAllEdgeStmts();
110  });
111  markAllAnalysesPreserved();
112  }
113 
114  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
115  void emitRegionCFG(Region &region) {
116  printControlFlowEdges = true;
117  printDataFlowEdges = false;
118  initColorMapping(region);
119  emitGraph([&]() { processRegion(region); });
120  }
121 
122 private:
123  /// Generate a color mapping that will color every operation with the same
124  /// name the same way. It'll interpolate the hue in the HSV color-space,
125  /// using muted colors that provide good contrast for black text.
126  template <typename T>
127  void initColorMapping(T &irEntity) {
128  backgroundColors.clear();
130  irEntity.walk([&](Operation *op) {
131  auto &entry = backgroundColors[op->getName()];
132  if (entry.first == 0)
133  ops.push_back(op);
134  ++entry.first;
135  });
136  for (auto indexedOps : llvm::enumerate(ops)) {
137  double hue = ((double)indexedOps.index()) / ops.size();
138  // Use lower saturation (0.3) and higher value (0.95) for better
139  // readability
140  backgroundColors[indexedOps.value()->getName()].second =
141  std::to_string(hue) + " 0.3 0.95";
142  }
143  }
144 
145  /// Emit all edges. This function should be called after all nodes have been
146  /// emitted.
147  void emitAllEdgeStmts() {
148  if (printDataFlowEdges) {
149  for (const auto &e : dataFlowEdges) {
150  emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
151  }
152  }
153 
154  for (const std::string &edge : edges)
155  os << edge << ";\n";
156  edges.clear();
157  }
158 
159  /// Emit a cluster (subgraph). The specified builder generates the body of the
160  /// cluster. Return the anchor node of the cluster.
161  Node emitClusterStmt(function_ref<void()> builder,
162  const std::string &label = "") {
163  int clusterId = ++counter;
164  os << "subgraph cluster_" << clusterId << " {\n";
165  os.indent();
166  // Emit invisible anchor node from/to which arrows can be drawn.
167  Node anchorNode = emitNodeStmt(" ", kShapeNone);
168  os << attrStmt("label", quoteString(label)) << ";\n";
169  builder();
170  os.unindent();
171  os << "}\n";
172  return Node(anchorNode.id, clusterId);
173  }
174 
175  /// Generate an attribute statement.
176  std::string attrStmt(const Twine &key, const Twine &value) {
177  return (key + " = " + value).str();
178  }
179 
180  /// Emit an attribute list.
181  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
182  os << "[";
183  interleaveComma(map, os, [&](const auto &it) {
184  os << this->attrStmt(it.first, it.second);
185  });
186  os << "]";
187  }
188 
189  // Print an MLIR attribute to `os`. Large attributes are truncated.
190  void emitMlirAttr(raw_ostream &os, Attribute attr) {
191  // A value used to elide large container attribute.
192  int64_t largeAttrLimit = getLargeAttributeSizeLimit();
193 
194  // Always emit splat attributes.
195  if (isa<SplatElementsAttr>(attr)) {
196  os << escapeLabelString(
197  strFromOs([&](raw_ostream &os) { attr.print(os); }));
198  return;
199  }
200 
201  // Elide "big" elements attributes.
202  auto elements = dyn_cast<ElementsAttr>(attr);
203  if (elements && elements.getNumElements() > largeAttrLimit) {
204  os << std::string(elements.getShapedType().getRank(), '[') << "..."
205  << std::string(elements.getShapedType().getRank(), ']') << " : ";
206  emitMlirType(os, elements.getType());
207  return;
208  }
209 
210  auto array = dyn_cast<ArrayAttr>(attr);
211  if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
212  os << "[...]";
213  return;
214  }
215 
216  // Print all other attributes.
217  std::string buf;
218  llvm::raw_string_ostream ss(buf);
219  attr.print(ss);
220  os << escapeLabelString(truncateString(buf));
221  }
222 
223  // Print a truncated and escaped MLIR type to `os`.
224  void emitMlirType(raw_ostream &os, Type type) {
225  std::string buf;
226  llvm::raw_string_ostream ss(buf);
227  type.print(ss);
228  os << escapeLabelString(truncateString(buf));
229  }
230 
231  // Print a truncated and escaped MLIR operand to `os`.
232  void emitMlirOperand(raw_ostream &os, Value operand) {
233  operand.printAsOperand(os, OpPrintingFlags());
234  }
235 
236  /// Append an edge to the list of edges.
237  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
238  void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
239  AttributeMap attrs;
240  attrs["style"] = style.str();
241  // Use `ltail` and `lhead` to draw edges between clusters.
242  if (n1.clusterId)
243  attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
244  if (n2.clusterId)
245  attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
246 
247  edges.push_back(strFromOs([&](raw_ostream &os) {
248  os << "v" << n1.id;
249  if (!port.empty() && !n1.clusterId)
250  // Attach edge to south compass point of the result
251  os << ":res" << port << ":s";
252  os << " -> ";
253  os << "v" << n2.id;
254  if (!port.empty() && !n2.clusterId)
255  // Attach edge to north compass point of the operand
256  os << ":arg" << port << ":n";
257  emitAttrList(os, attrs);
258  }));
259  }
260 
  /// Emit a graph. The specified builder generates the body of the graph.
  /// The body is wrapped in `digraph G { ... }` and indented by one level.
  /// Note: this function does not write buffered edges itself; the builder is
  /// responsible for everything that appears inside the graph.
  void emitGraph(function_ref<void()> builder) {
    os << "digraph G {\n";
    os.indent();
    // Edges between clusters are allowed only in compound mode.
    os << attrStmt("compound", "true") << ";\n";
    builder();
    os.unindent();
    os << "}\n";
  }
271 
272  /// Emit a node statement.
273  Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode,
274  StringRef background = "") {
275  int nodeId = ++counter;
276  AttributeMap attrs;
277  attrs["label"] = quoteString(label);
278  attrs["shape"] = shape.str();
279  if (!background.empty()) {
280  attrs["style"] = "filled";
281  attrs["fillcolor"] = quoteString(background.str());
282  }
283  os << llvm::format("v%i ", nodeId);
284  emitAttrList(os, attrs);
285  os << ";\n";
286  return Node(nodeId);
287  }
288 
289  std::string getValuePortName(Value operand) {
290  // Print value as an operand and omit the leading '%' character.
291  auto str = strFromOs([&](raw_ostream &os) {
292  operand.printAsOperand(os, OpPrintingFlags());
293  });
294  // Replace % and # with _
295  llvm::replace(str, '%', '_');
296  llvm::replace(str, '#', '_');
297  return str;
298  }
299 
300  std::string getClusterLabel(Operation *op) {
301  return strFromOs([&](raw_ostream &os) {
302  // Print operation name and type.
303  os << op->getName();
304  if (printResultTypes) {
305  os << " : (";
306  std::string buf;
307  llvm::raw_string_ostream ss(buf);
308  interleaveComma(op->getResultTypes(), ss);
309  os << truncateString(buf) << ")";
310  }
311 
312  // Print attributes.
313  if (printAttrs) {
314  os << "\\l";
315  for (const NamedAttribute &attr : op->getAttrs()) {
316  os << escapeLabelString(attr.getName().getValue().str()) << ": ";
317  emitMlirAttr(os, attr.getValue());
318  os << "\\l";
319  }
320  }
321  });
322  }
323 
324  /// Generate a label for an operation.
325  std::string getRecordLabel(Operation *op) {
326  return strFromOs([&](raw_ostream &os) {
327  os << "{";
328 
329  // Print operation inputs.
330  if (op->getNumOperands() > 0) {
331  os << "{";
332  auto operandToPort = [&](Value operand) {
333  os << "<arg" << getValuePortName(operand) << "> ";
334  emitMlirOperand(os, operand);
335  };
336  interleave(op->getOperands(), os, operandToPort, "|");
337  os << "}|";
338  }
339  // Print operation name and type.
340  os << op->getName() << "\\l";
341 
342  // Print attributes.
343  if (printAttrs && !op->getAttrs().empty()) {
344  // Extra line break to separate attributes from the operation name.
345  os << "\\l";
346  for (const NamedAttribute &attr : op->getAttrs()) {
347  os << attr.getName().getValue() << ": ";
348  emitMlirAttr(os, attr.getValue());
349  os << "\\l";
350  }
351  }
352 
353  if (op->getNumResults() > 0) {
354  os << "|{";
355  auto resultToPort = [&](Value result) {
356  os << "<res" << getValuePortName(result) << "> ";
357  emitMlirOperand(os, result);
358  if (printResultTypes) {
359  os << " ";
360  emitMlirType(os, result.getType());
361  }
362  };
363  interleave(op->getResults(), os, resultToPort, "|");
364  os << "}";
365  }
366 
367  os << "}";
368  });
369  }
370 
371  /// Generate a label for a block argument.
372  std::string getLabel(BlockArgument arg) {
373  return strFromOs([&](raw_ostream &os) {
374  os << "<res" << getValuePortName(arg) << "> ";
375  arg.printAsOperand(os, OpPrintingFlags());
376  if (printResultTypes) {
377  os << " ";
378  emitMlirType(os, arg.getType());
379  }
380  });
381  }
382 
383  /// Process a block. Emit a cluster and one node per block argument and
384  /// operation inside the cluster.
385  void processBlock(Block &block) {
386  emitClusterStmt([&]() {
387  for (BlockArgument &blockArg : block.getArguments())
388  valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
389  // Emit a node for each operation.
390  std::optional<Node> prevNode;
391  for (Operation &op : block) {
392  Node nextNode = processOperation(&op);
393  if (printControlFlowEdges && prevNode)
394  emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
395  prevNode = nextNode;
396  }
397  });
398  }
399 
400  /// Process an operation. If the operation has regions, emit a cluster.
401  /// Otherwise, emit a node.
402  Node processOperation(Operation *op) {
403  Node node;
404  if (op->getNumRegions() > 0) {
405  // Emit cluster for op with regions.
406  node = emitClusterStmt(
407  [&]() {
408  for (Region &region : op->getRegions())
409  processRegion(region);
410  },
411  getClusterLabel(op));
412  } else {
413  node = emitNodeStmt(getRecordLabel(op), kShapeNode,
414  backgroundColors[op->getName()].second);
415  }
416 
417  // Insert data flow edges originating from each operand.
418  if (printDataFlowEdges) {
419  unsigned numOperands = op->getNumOperands();
420  for (unsigned i = 0; i < numOperands; i++) {
421  auto operand = op->getOperand(i);
422  dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
423  }
424  }
425 
426  for (Value result : op->getResults())
427  valueToNode[result] = node;
428 
429  return node;
430  }
431 
432  /// Process a region.
433  void processRegion(Region &region) {
434  for (Block &block : region.getBlocks())
435  processBlock(block);
436  }
437 
438  /// Truncate long strings.
439  std::string truncateString(std::string str) {
440  if (str.length() <= maxLabelLen)
441  return str;
442  return str.substr(0, maxLabelLen) + "...";
443  }
444 
445  /// Output stream to write DOT file to.
  /// A list of fully formatted edge statements. For simplicity, they are
  /// buffered here and written out after all nodes were emitted.
  std::vector<std::string> edges;
  /// Mapping of SSA values to the Graphviz nodes/clusters that produce them.
  DenseMap<Value, Node> valueToNode;
  /// Output for data flow edges is delayed until the end to handle cycles
  std::vector<DataFlowEdge> dataFlowEdges;
  /// Counter for generating unique node/subgraph identifiers. Nodes and
  /// clusters draw their identifiers from this same counter.
  int counter = 0;
456 
458 };
459 
460 } // namespace
461 
/// Create a pass that prints op graphs in DOT format to `os`.
std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
  return std::make_unique<PrintOpPass>(os);
}
465 
466 /// Generate a CFG for a region and show it in a window.
467 static void llvmViewGraph(Region &region, const Twine &name) {
468  int fd;
469  std::string filename = llvm::createGraphFilename(name.str(), fd);
470  {
471  llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
472  if (fd == -1) {
473  llvm::errs() << "error opening file '" << filename << "' for writing\n";
474  return;
475  }
476  PrintOpPass pass(os);
477  pass.emitRegionCFG(region);
478  }
479  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
480 }
481 
/// Display the CFG of this region in a window; `regionName` is used to derive
/// the name of the generated graph file.
void mlir::Region::viewGraph(const Twine &regionName) {
  llvmViewGraph(*this, regionName);
}
485 
486 void mlir::Region::viewGraph() { viewGraph("region"); }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
Definition: ViewOpGraph.cpp:53
static const StringRef kLineStyleDataFlow
Definition: ViewOpGraph.cpp:31
static const StringRef kLineStyleControlFlow
Definition: ViewOpGraph.cpp:30
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
Definition: ViewOpGraph.cpp:36
std::map< std::string, std::string > AttributeMap
Definition: ViewOpGraph.cpp:71
static const StringRef kShapeNone
Definition: ViewOpGraph.cpp:33
static const StringRef kShapeNode
Definition: ViewOpGraph.cpp:32
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
Definition: ViewOpGraph.cpp:60
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Definition: ViewOpGraph.cpp:45
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
void viewGraph()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indenting a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.