MLIR 22.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "mlir/IR/Block.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/Pass/Pass.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/GraphWriter.h"
19#include <map>
20#include <optional>
21#include <utility>
22
23namespace mlir {
24#define GEN_PASS_DEF_VIEWOPGRAPH
25#include "mlir/Transforms/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30static const StringRef kLineStyleControlFlow = "dashed";
31static const StringRef kLineStyleDataFlow = "solid";
32static const StringRef kShapeNode = "Mrecord";
33static const StringRef kShapeNone = "plain";
34
35/// Return the size limits for eliding large attributes.
37 // Use the default from the printer flags if possible.
38 if (std::optional<int64_t> limit =
39 OpPrintingFlags().getLargeElementsAttrLimit())
40 return *limit;
41 return 16;
42}
43
44/// Return all values printed onto a stream as a string.
45static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46 std::string buf;
47 llvm::raw_string_ostream os(buf);
48 func(os);
49 return buf;
50}
51
/// Return `str` wrapped in double quotes. The contents are not escaped.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
56
/// Escape text for use inside a Graphviz record label.
///
/// Graphviz: "Braces, vertical bars and angle brackets must be escaped with a
/// backslash character if you wish them to appear as a literal character".
/// Double quotes (which would end the quoted DOT string) and newlines are
/// backslash-prefixed as well.
std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (char ch : str) {
    switch (ch) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      escaped.push_back('\\');
      break;
    default:
      break;
    }
    escaped.push_back(ch);
  }
  return escaped;
}
70
/// A Graphviz attribute list (key -> value). Being a `std::map`, the keys are
/// iterated in sorted order, so attribute lists are emitted deterministically.
using AttributeMap = std::map<std::string, std::string>;
72
73namespace {
74
/// A node in the DOT language, emitted as `v<id>`.
///
/// DOT edges can only connect nodes, never clusters. They can, however, be
/// clipped to a cluster's border with the `lhead`/`ltail` attributes. Every
/// cluster (subgraph) therefore gets an invisible "anchor" node, and the Node
/// returned for a cluster is that anchor together with the cluster's id.
struct Node {
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}

  /// Numeric node identifier.
  int id;
  /// Id of the cluster this node anchors; unset for ordinary nodes.
  std::optional<int> clusterId;
};
90
/// A data-flow edge whose emission is deferred until all nodes exist. It runs
/// from the node that defines `value` to `node`.
struct DataFlowEdge {
  /// The SSA value flowing along the edge.
  Value value;
  /// Node (or cluster anchor) of the operation that uses `value`.
  Node node;
  /// Port-name suffix of `value`; used for both the `res` and `arg` ports.
  std::string port;
};
96
97/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
98/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
99/// about the Graphviz DOT language.
100class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
101public:
/// Create a pass that writes its DOT output to `os`.
PrintOpPass(raw_ostream &os) : os(os) {}
/// A copy writes to the same underlying stream as `o`. Only the stream is
/// taken from `o`: the base pass is default-constructed, so this constructor
/// itself does not copy pass option values.
PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
104
105 void runOnOperation() override {
106 initColorMapping(*getOperation());
107 emitGraph([&]() {
108 processOperation(getOperation());
109 emitAllEdgeStmts();
110 });
111 markAllAnalysesPreserved();
112 }
113
114 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
115 void emitRegionCFG(Region &region) {
116 printControlFlowEdges = true;
117 printDataFlowEdges = false;
118 initColorMapping(region);
119 emitGraph([&]() { processRegion(region); });
120 }
121
122private:
123 /// Generate a color mapping that will color every operation with the same
124 /// name the same way. It'll interpolate the hue in the HSV color-space,
125 /// using muted colors that provide good contrast for black text.
126 template <typename T>
127 void initColorMapping(T &irEntity) {
128 backgroundColors.clear();
130 irEntity.walk([&](Operation *op) {
131 auto &entry = backgroundColors[op->getName()];
132 if (entry.first == 0)
133 ops.push_back(op);
134 ++entry.first;
135 });
136 for (auto indexedOps : llvm::enumerate(ops)) {
137 double hue = ((double)indexedOps.index()) / ops.size();
138 // Use lower saturation (0.3) and higher value (0.95) for better
139 // readability
140 backgroundColors[indexedOps.value()->getName()].second =
141 std::to_string(hue) + " 0.3 0.95";
142 }
143 }
144
145 /// Emit all edges. This function should be called after all nodes have been
146 /// emitted.
147 void emitAllEdgeStmts() {
148 if (printDataFlowEdges) {
149 for (const auto &e : dataFlowEdges) {
150 emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
151 }
152 }
153
154 for (const std::string &edge : edges)
155 os << edge << ";\n";
156 edges.clear();
157 }
158
159 /// Emit a cluster (subgraph). The specified builder generates the body of the
160 /// cluster. Return the anchor node of the cluster.
161 Node emitClusterStmt(function_ref<void()> builder,
162 const std::string &label = "") {
163 int clusterId = ++counter;
164 os << "subgraph cluster_" << clusterId << " {\n";
165 os.indent();
166 // Emit invisible anchor node from/to which arrows can be drawn.
167 Node anchorNode = emitNodeStmt(" ", kShapeNone);
168 os << attrStmt("label", quoteString(label)) << ";\n";
169 builder();
170 os.unindent();
171 os << "}\n";
172 return Node(anchorNode.id, clusterId);
173 }
174
175 /// Generate an attribute statement.
176 std::string attrStmt(const Twine &key, const Twine &value) {
177 return (key + " = " + value).str();
178 }
179
180 /// Emit an attribute list.
181 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
182 os << "[";
183 interleaveComma(map, os, [&](const auto &it) {
184 os << this->attrStmt(it.first, it.second);
185 });
186 os << "]";
187 }
188
189 // Print an MLIR attribute to `os`. Large attributes are truncated.
190 void emitMlirAttr(raw_ostream &os, Attribute attr) {
191 // A value used to elide large container attribute.
192 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
193
194 // Always emit splat attributes.
195 if (isa<SplatElementsAttr>(attr)) {
196 os << escapeLabelString(
197 strFromOs([&](raw_ostream &os) { attr.print(os); }));
198 return;
199 }
200
201 // Elide "big" elements attributes.
202 auto elements = dyn_cast<ElementsAttr>(attr);
203 if (elements && elements.getNumElements() > largeAttrLimit) {
204 os << std::string(elements.getShapedType().getRank(), '[') << "..."
205 << std::string(elements.getShapedType().getRank(), ']') << " : ";
206 emitMlirType(os, elements.getType());
207 return;
208 }
209
210 auto array = dyn_cast<ArrayAttr>(attr);
211 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
212 os << "[...]";
213 return;
214 }
215
216 // Print all other attributes.
217 std::string buf;
218 llvm::raw_string_ostream ss(buf);
219 attr.print(ss);
220 os << escapeLabelString(truncateString(buf));
221 }
222
223 // Print a truncated and escaped MLIR type to `os`.
224 void emitMlirType(raw_ostream &os, Type type) {
225 std::string buf;
226 llvm::raw_string_ostream ss(buf);
227 type.print(ss);
228 os << escapeLabelString(truncateString(buf));
229 }
230
231 // Print a truncated and escaped MLIR operand to `os`.
232 void emitMlirOperand(raw_ostream &os, Value operand) {
233 operand.printAsOperand(os, OpPrintingFlags());
234 }
235
236 /// Append an edge to the list of edges.
237 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
238 void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
239 AttributeMap attrs;
240 attrs["style"] = style.str();
241 // Use `ltail` and `lhead` to draw edges between clusters.
242 if (n1.clusterId)
243 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
244 if (n2.clusterId)
245 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
246
247 edges.push_back(strFromOs([&](raw_ostream &os) {
248 os << "v" << n1.id;
249 if (!port.empty() && !n1.clusterId)
250 // Attach edge to south compass point of the result
251 os << ":res" << port << ":s";
252 os << " -> ";
253 os << "v" << n2.id;
254 if (!port.empty() && !n2.clusterId)
255 // Attach edge to north compass point of the operand
256 os << ":arg" << port << ":n";
257 emitAttrList(os, attrs);
258 }));
259 }
260
261 /// Emit a graph. The specified builder generates the body of the graph.
262 void emitGraph(function_ref<void()> builder) {
263 os << "digraph G {\n";
264 os.indent();
265 // Edges between clusters are allowed only in compound mode.
266 os << attrStmt("compound", "true") << ";\n";
267 builder();
268 os.unindent();
269 os << "}\n";
270 }
271
272 /// Emit a node statement.
273 Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode,
274 StringRef background = "") {
275 int nodeId = ++counter;
276 AttributeMap attrs;
277 attrs["label"] = quoteString(label);
278 attrs["shape"] = shape.str();
279 if (!background.empty()) {
280 attrs["style"] = "filled";
281 attrs["fillcolor"] = quoteString(background.str());
282 }
283 os << llvm::format("v%i ", nodeId);
284 emitAttrList(os, attrs);
285 os << ";\n";
286 return Node(nodeId);
287 }
288
289 std::string getValuePortName(Value operand) {
290 // Print value as an operand and omit the leading '%' character.
291 auto str = strFromOs([&](raw_ostream &os) {
292 operand.printAsOperand(os, OpPrintingFlags());
293 });
294 // Replace % and # with _
295 llvm::replace(str, '%', '_');
296 llvm::replace(str, '#', '_');
297 return str;
298 }
299
300 std::string getClusterLabel(Operation *op) {
301 return strFromOs([&](raw_ostream &os) {
302 // Print operation name and type.
303 os << op->getName();
304 if (printResultTypes) {
305 os << " : (";
306 std::string buf;
307 llvm::raw_string_ostream ss(buf);
308 interleaveComma(op->getResultTypes(), ss);
309 os << truncateString(buf) << ")";
310 }
311
312 // Print attributes.
313 if (printAttrs) {
314 os << "\\l";
315 for (const NamedAttribute &attr : op->getAttrs()) {
316 os << escapeLabelString(attr.getName().getValue().str()) << ": ";
317 emitMlirAttr(os, attr.getValue());
318 os << "\\l";
319 }
320 }
321 });
322 }
323
324 /// Generate a label for an operation.
325 std::string getRecordLabel(Operation *op) {
326 return strFromOs([&](raw_ostream &os) {
327 os << "{";
328
329 // Print operation inputs.
330 if (op->getNumOperands() > 0) {
331 os << "{";
332 auto operandToPort = [&](Value operand) {
333 os << "<arg" << getValuePortName(operand) << "> ";
334 emitMlirOperand(os, operand);
335 };
336 interleave(op->getOperands(), os, operandToPort, "|");
337 os << "}|";
338 }
339 // Print operation name and type.
340 os << op->getName() << "\\l";
341
342 // Print attributes.
343 if (printAttrs && !op->getAttrs().empty()) {
344 // Extra line break to separate attributes from the operation name.
345 os << "\\l";
346 for (const NamedAttribute &attr : op->getAttrs()) {
347 os << attr.getName().getValue() << ": ";
348 emitMlirAttr(os, attr.getValue());
349 os << "\\l";
350 }
351 }
352
353 if (op->getNumResults() > 0) {
354 os << "|{";
355 auto resultToPort = [&](Value result) {
356 os << "<res" << getValuePortName(result) << "> ";
357 emitMlirOperand(os, result);
358 if (printResultTypes) {
359 os << " ";
360 emitMlirType(os, result.getType());
361 }
362 };
363 interleave(op->getResults(), os, resultToPort, "|");
364 os << "}";
365 }
366
367 os << "}";
368 });
369 }
370
371 /// Generate a label for a block argument.
372 std::string getLabel(BlockArgument arg) {
373 return strFromOs([&](raw_ostream &os) {
374 os << "<res" << getValuePortName(arg) << "> ";
376 if (printResultTypes) {
377 os << " ";
378 emitMlirType(os, arg.getType());
379 }
380 });
381 }
382
383 /// Process a block. Emit a cluster and one node per block argument and
384 /// operation inside the cluster.
385 void processBlock(Block &block) {
386 emitClusterStmt([&]() {
387 for (BlockArgument &blockArg : block.getArguments())
388 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
389 // Emit a node for each operation.
390 std::optional<Node> prevNode;
391 for (Operation &op : block) {
392 Node nextNode = processOperation(&op);
393 if (printControlFlowEdges && prevNode)
394 emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
395 prevNode = nextNode;
396 }
397 });
398 }
399
400 /// Process an operation. If the operation has regions, emit a cluster.
401 /// Otherwise, emit a node.
402 Node processOperation(Operation *op) {
403 Node node;
404 if (op->getNumRegions() > 0) {
405 // Emit cluster for op with regions.
406 node = emitClusterStmt(
407 [&]() {
408 for (Region &region : op->getRegions())
409 processRegion(region);
410 },
411 getClusterLabel(op));
412 } else {
413 node = emitNodeStmt(getRecordLabel(op), kShapeNode,
414 backgroundColors[op->getName()].second);
415 }
416
417 // Insert data flow edges originating from each operand.
418 if (printDataFlowEdges) {
419 unsigned numOperands = op->getNumOperands();
420 for (unsigned i = 0; i < numOperands; i++) {
421 auto operand = op->getOperand(i);
422 dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
423 }
424 }
425
426 for (Value result : op->getResults())
427 valueToNode[result] = node;
428
429 return node;
430 }
431
432 /// Process a region.
433 void processRegion(Region &region) {
434 for (Block &block : region.getBlocks())
435 processBlock(block);
436 }
437
438 /// Truncate long strings.
439 std::string truncateString(std::string str) {
440 if (str.length() <= maxLabelLen)
441 return str;
442 return str.substr(0, maxLabelLen) + "...";
443 }
444
445 /// Output stream to write DOT file to.
/// Edge statements, already rendered as DOT text. They are buffered and only
/// written out by `emitAllEdgeStmts`, after all nodes have been emitted.
std::vector<std::string> edges;
/// Mapping of SSA values to the Graphviz node (or cluster anchor) of the
/// operation or block argument that defines them.
DenseMap<Value, Node> valueToNode;
/// Pending data-flow edges. Output is delayed until all nodes exist in order
/// to handle cycles (uses whose defining node is emitted later).
std::vector<DataFlowEdge> dataFlowEdges;
/// Counter for generating unique node/subgraph identifiers. It is incremented
/// before each use, so the first identifier handed out is 1.
int counter = 0;
456
458};
459
460} // namespace
461
/// Create the op-graph printing pass; its DOT output is written to `os`.
std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
  return std::make_unique<PrintOpPass>(os);
}
465
466/// Generate a CFG for a region and show it in a window.
467static void llvmViewGraph(Region &region, const Twine &name) {
468 int fd;
469 std::string filename = llvm::createGraphFilename(name.str(), fd);
470 {
471 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
472 if (fd == -1) {
473 llvm::errs() << "error opening file '" << filename << "' for writing\n";
474 return;
475 }
476 PrintOpPass pass(os);
477 pass.emitRegionCFG(region);
478 }
479 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
480}
481
/// Display a CFG of this region in a window. `regionName` is used to derive
/// the name of the temporary DOT file.
void mlir::Region::viewGraph(const Twine &regionName) {
  llvmViewGraph(*this, regionName);
}
485
MemRefDependenceGraph::Node Node
Definition Utils.cpp:38
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
Definition Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgListType getArguments()
Definition Block.h:87
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
unsigned getNumOperands()
Definition Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void viewGraph(const Twine &regionName)
Displays the CFG in a window.
BlockListType & getBlocks()
Definition Region.h:45
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indention a sequence of code.
raw_ostream & getOStream() const
Returns the underlying (unindented) raw_ostream.
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152