MLIR 23.0.0git
ViewOpGraph.cpp
Go to the documentation of this file.
1//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
#include "mlir/Transforms/ViewOpGraph.h"

#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>
22
23namespace mlir {
24#define GEN_PASS_DEF_VIEWOPGRAPHPASS
25#include "mlir/Transforms/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
/// Graphviz `style` value used for control-flow edges.
static const StringRef kLineStyleControlFlow = "dashed";
/// Graphviz `style` value used for data-flow edges.
static const StringRef kLineStyleDataFlow = "solid";
/// Default Graphviz `shape` for operation and block-argument nodes.
static const StringRef kShapeNode = "Mrecord";
/// Graphviz `shape` used for the invisible anchor node of a cluster.
static const StringRef kShapeNone = "plain";
34
35/// Return the size limits for eliding large attributes.
37 // Use the default from the printer flags if possible.
38 if (std::optional<int64_t> limit =
39 OpPrintingFlags().getLargeElementsAttrLimit())
40 return *limit;
41 return 16;
42}
43
44/// Return all values printed onto a stream as a string.
45static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46 std::string buf;
47 llvm::raw_string_ostream os(buf);
48 func(os);
49 return buf;
50}
51
/// Wrap `str` in a pair of double quotes. The contents are copied verbatim;
/// no escaping is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
56
/// Escape a string for use inside a Graphviz record label.
/// From the Graphviz documentation on record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
/// In addition to those, newlines and double quotes are prefixed with a
/// backslash as well. All other characters are copied unchanged.
static std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      escaped.push_back('\\');
      break;
    default:
      break;
    }
    escaped.push_back(ch);
  }
  return escaped;
}
70
/// Attribute key -> attribute value, both already formatted as strings.
/// `std::map` keeps the entries ordered by key.
using AttributeMap = std::map<std::string, std::string>;
72
73namespace {
74
/// A node in the DOT language, identified by `id`. When the node stands for a
/// cluster (subgraph), `clusterId` additionally holds the cluster identifier.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created.
struct Node {
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id{id}, clusterId{clusterId} {}

  /// Identifier of the node.
  int id;
  /// Identifier of the cluster, if any.
  std::optional<int> clusterId;
};
90
/// A pending data-flow edge. Edges are recorded while operations are processed
/// and turned into edge statements later, once every node has been emitted.
struct DataFlowEdge {
  /// The SSA value flowing along the edge.
  Value value;
  /// The node of the operation that uses `value`.
  Node node;
  /// Port name derived from `value`; used on both ends of the edge.
  std::string port;
};
96
97/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
98/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
99/// about the Graphviz DOT language.
100class PrintOpPass : public impl::ViewOpGraphPassBase<PrintOpPass> {
101public:
102 PrintOpPass() : os(llvm::errs()) {}
103 explicit PrintOpPass(ViewOpGraphPassOptions options)
105 os(llvm::errs()) {}
106 PrintOpPass(raw_ostream &os) : os(os) {}
107 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
108
109 void runOnOperation() override {
110 initColorMapping(*getOperation());
111 emitGraph([&]() {
112 processOperation(getOperation());
113 emitAllEdgeStmts();
114 });
115 markAllAnalysesPreserved();
116 }
117
118 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
119 void emitRegionCFG(Region &region) {
120 printControlFlowEdges = true;
121 printDataFlowEdges = false;
122 initColorMapping(region);
123 emitGraph([&]() { processRegion(region); });
124 }
125
126private:
127 /// Generate a color mapping that will color every operation with the same
128 /// name the same way. It'll interpolate the hue in the HSV color-space,
129 /// using muted colors that provide good contrast for black text.
130 template <typename T>
131 void initColorMapping(T &irEntity) {
132 backgroundColors.clear();
134 irEntity.walk([&](Operation *op) {
135 auto &entry = backgroundColors[op->getName()];
136 if (entry.first == 0)
137 ops.push_back(op);
138 ++entry.first;
139 });
140 for (auto indexedOps : llvm::enumerate(ops)) {
141 double hue = ((double)indexedOps.index()) / ops.size();
142 // Use lower saturation (0.3) and higher value (0.95) for better
143 // readability
144 backgroundColors[indexedOps.value()->getName()].second =
145 std::to_string(hue) + " 0.3 0.95";
146 }
147 }
148
149 /// Emit all edges. This function should be called after all nodes have been
150 /// emitted.
151 void emitAllEdgeStmts() {
152 if (printDataFlowEdges) {
153 for (const auto &e : dataFlowEdges) {
154 emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
155 }
156 }
157
158 for (const std::string &edge : edges)
159 os << edge << ";\n";
160 edges.clear();
161 }
162
163 /// Emit a cluster (subgraph). The specified builder generates the body of the
164 /// cluster. Return the anchor node of the cluster.
165 Node emitClusterStmt(function_ref<void()> builder,
166 const std::string &label = "") {
167 int clusterId = ++counter;
168 os << "subgraph cluster_" << clusterId << " {\n";
169 os.indent();
170 // Emit invisible anchor node from/to which arrows can be drawn.
171 Node anchorNode = emitNodeStmt(" ", kShapeNone);
172 os << attrStmt("label", quoteString(label)) << ";\n";
173 builder();
174 os.unindent();
175 os << "}\n";
176 return Node(anchorNode.id, clusterId);
177 }
178
179 /// Generate an attribute statement.
180 std::string attrStmt(const Twine &key, const Twine &value) {
181 return (key + " = " + value).str();
182 }
183
184 /// Emit an attribute list.
185 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
186 os << "[";
187 interleaveComma(map, os, [&](const auto &it) {
188 os << this->attrStmt(it.first, it.second);
189 });
190 os << "]";
191 }
192
193 // Print an MLIR attribute to `os`. Large attributes are truncated.
194 void emitMlirAttr(raw_ostream &os, Attribute attr) {
195 // A value used to elide large container attribute.
196 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
197
198 // Always emit splat attributes.
199 if (isa<SplatElementsAttr>(attr)) {
200 os << escapeLabelString(
201 strFromOs([&](raw_ostream &os) { attr.print(os); }));
202 return;
203 }
204
205 // Elide "big" elements attributes.
206 auto elements = dyn_cast<ElementsAttr>(attr);
207 if (elements && elements.getNumElements() > largeAttrLimit) {
208 os << std::string(elements.getShapedType().getRank(), '[') << "..."
209 << std::string(elements.getShapedType().getRank(), ']') << " : ";
210 emitMlirType(os, elements.getType());
211 return;
212 }
213
214 auto array = dyn_cast<ArrayAttr>(attr);
215 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
216 os << "[...]";
217 return;
218 }
219
220 // Print all other attributes.
221 std::string buf;
222 llvm::raw_string_ostream ss(buf);
223 attr.print(ss);
224 os << escapeLabelString(truncateString(buf));
225 }
226
227 // Print a truncated and escaped MLIR type to `os`.
228 void emitMlirType(raw_ostream &os, Type type) {
229 std::string buf;
230 llvm::raw_string_ostream ss(buf);
231 type.print(ss);
232 os << escapeLabelString(truncateString(buf));
233 }
234
235 // Print a truncated and escaped MLIR operand to `os`.
236 void emitMlirOperand(raw_ostream &os, Value operand) {
237 operand.printAsOperand(os, OpPrintingFlags());
238 }
239
240 /// Append an edge to the list of edges.
241 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
242 void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
243 AttributeMap attrs;
244 attrs["style"] = style.str();
245 // Use `ltail` and `lhead` to draw edges between clusters.
246 if (n1.clusterId)
247 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
248 if (n2.clusterId)
249 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
250
251 edges.push_back(strFromOs([&](raw_ostream &os) {
252 os << "v" << n1.id;
253 if (!port.empty() && !n1.clusterId)
254 // Attach edge to south compass point of the result
255 os << ":res" << port << ":s";
256 os << " -> ";
257 os << "v" << n2.id;
258 if (!port.empty() && !n2.clusterId)
259 // Attach edge to north compass point of the operand
260 os << ":arg" << port << ":n";
261 emitAttrList(os, attrs);
262 }));
263 }
264
265 /// Emit a graph. The specified builder generates the body of the graph.
266 void emitGraph(function_ref<void()> builder) {
267 os << "digraph G {\n";
268 os.indent();
269 // Edges between clusters are allowed only in compound mode.
270 os << attrStmt("compound", "true") << ";\n";
271 builder();
272 os.unindent();
273 os << "}\n";
274 }
275
276 /// Emit a node statement.
277 Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode,
278 StringRef background = "") {
279 int nodeId = ++counter;
280 AttributeMap attrs;
281 attrs["label"] = quoteString(label);
282 attrs["shape"] = shape.str();
283 if (!background.empty()) {
284 attrs["style"] = "filled";
285 attrs["fillcolor"] = quoteString(background.str());
286 }
287 os << llvm::format("v%i ", nodeId);
288 emitAttrList(os, attrs);
289 os << ";\n";
290 return Node(nodeId);
291 }
292
293 std::string getValuePortName(Value operand) {
294 // Print value as an operand and omit the leading '%' character.
295 auto str = strFromOs([&](raw_ostream &os) {
296 operand.printAsOperand(os, OpPrintingFlags());
297 });
298 // Replace % and # with _
299 llvm::replace(str, '%', '_');
300 llvm::replace(str, '#', '_');
301 return str;
302 }
303
304 std::string getClusterLabel(Operation *op) {
305 return strFromOs([&](raw_ostream &os) {
306 // Print operation name and type.
307 os << op->getName();
308 if (printResultTypes) {
309 os << " : (";
310 std::string buf;
311 llvm::raw_string_ostream ss(buf);
312 interleaveComma(op->getResultTypes(), ss);
313 os << truncateString(buf) << ")";
314 }
315
316 // Print attributes.
317 if (printAttrs) {
318 os << "\\l";
319 for (const NamedAttribute &attr : op->getAttrs()) {
320 os << escapeLabelString(attr.getName().getValue().str()) << ": ";
321 emitMlirAttr(os, attr.getValue());
322 os << "\\l";
323 }
324 }
325 });
326 }
327
328 /// Generate a label for an operation.
329 std::string getRecordLabel(Operation *op) {
330 return strFromOs([&](raw_ostream &os) {
331 os << "{";
332
333 // Print operation inputs.
334 if (op->getNumOperands() > 0) {
335 os << "{";
336 auto operandToPort = [&](Value operand) {
337 os << "<arg" << getValuePortName(operand) << "> ";
338 emitMlirOperand(os, operand);
339 };
340 interleave(op->getOperands(), os, operandToPort, "|");
341 os << "}|";
342 }
343 // Print operation name and type.
344 os << op->getName() << "\\l";
345
346 // Print attributes.
347 if (printAttrs && !op->getAttrs().empty()) {
348 // Extra line break to separate attributes from the operation name.
349 os << "\\l";
350 for (const NamedAttribute &attr : op->getAttrs()) {
351 os << attr.getName().getValue() << ": ";
352 emitMlirAttr(os, attr.getValue());
353 os << "\\l";
354 }
355 }
356
357 if (op->getNumResults() > 0) {
358 os << "|{";
359 auto resultToPort = [&](Value result) {
360 os << "<res" << getValuePortName(result) << "> ";
361 emitMlirOperand(os, result);
362 if (printResultTypes) {
363 os << " ";
364 emitMlirType(os, result.getType());
365 }
366 };
367 interleave(op->getResults(), os, resultToPort, "|");
368 os << "}";
369 }
370
371 os << "}";
372 });
373 }
374
375 /// Generate a label for a block argument.
376 std::string getLabel(BlockArgument arg) {
377 return strFromOs([&](raw_ostream &os) {
378 os << "<res" << getValuePortName(arg) << "> ";
380 if (printResultTypes) {
381 os << " ";
382 emitMlirType(os, arg.getType());
383 }
384 });
385 }
386
387 /// Process a block. Emit a cluster and one node per block argument and
388 /// operation inside the cluster.
389 void processBlock(Block &block) {
390 emitClusterStmt([&]() {
391 for (BlockArgument &blockArg : block.getArguments())
392 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
393 // Emit a node for each operation.
394 std::optional<Node> prevNode;
395 for (Operation &op : block) {
396 Node nextNode = processOperation(&op);
397 if (printControlFlowEdges && prevNode)
398 emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
399 prevNode = nextNode;
400 }
401 });
402 }
403
404 /// Process an operation. If the operation has regions, emit a cluster.
405 /// Otherwise, emit a node.
406 Node processOperation(Operation *op) {
407 Node node;
408 if (op->getNumRegions() > 0) {
409 // Emit cluster for op with regions.
410 node = emitClusterStmt(
411 [&]() {
412 for (Region &region : op->getRegions())
413 processRegion(region);
414 },
415 getClusterLabel(op));
416 } else {
417 node = emitNodeStmt(getRecordLabel(op), kShapeNode,
418 backgroundColors[op->getName()].second);
419 }
420
421 // Insert data flow edges originating from each operand.
422 if (printDataFlowEdges) {
423 unsigned numOperands = op->getNumOperands();
424 for (unsigned i = 0; i < numOperands; i++) {
425 auto operand = op->getOperand(i);
426 dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
427 }
428 }
429
430 for (Value result : op->getResults())
431 valueToNode[result] = node;
432
433 return node;
434 }
435
436 /// Process a region.
437 void processRegion(Region &region) {
438 for (Block &block : region.getBlocks())
439 processBlock(block);
440 }
441
442 /// Truncate long strings.
443 std::string truncateString(std::string str) {
444 if (str.length() <= maxLabelLen)
445 return str;
446 return str.substr(0, maxLabelLen) + "...";
447 }
448
449 /// Output stream to write DOT file to.
451 /// A list of edges. For simplicity, should be emitted after all nodes were
452 /// emitted.
453 std::vector<std::string> edges;
454 /// Mapping of SSA values to Graphviz nodes/clusters.
455 DenseMap<Value, Node> valueToNode;
456 /// Output for data flow edges is delayed until the end to handle cycles
457 std::vector<DataFlowEdge> dataFlowEdges;
458 /// Counter for generating unique node/subgraph identifiers.
459 int counter = 0;
460
462};
463
464} // namespace
465
466std::unique_ptr<Pass> mlir::createViewOpGraphPass(raw_ostream &os) {
467 return std::make_unique<PrintOpPass>(os);
468}
469
470/// Generate a CFG for a region and show it in a window.
471static void llvmViewGraph(Region &region, const Twine &name) {
472 int fd;
473 std::string filename = llvm::createGraphFilename(name.str(), fd);
474 {
475 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
476 if (fd == -1) {
477 llvm::errs() << "error opening file '" << filename << "' for writing\n";
478 return;
479 }
480 PrintOpPass pass(os);
481 pass.emitRegionCFG(region);
482 }
483 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
484}
485
486void mlir::Region::viewGraph(const Twine &regionName) {
487 llvmViewGraph(*this, regionName);
488}
489
MemRefDependenceGraph::Node Node
Definition Utils.cpp:39
static llvm::ManagedStatic< PassManagerOptions > options
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
Definition Attributes.h:25
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgListType getArguments()
Definition Block.h:97
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:358
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:520
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:682
unsigned getNumOperands()
Definition Operation.h:354
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:685
result_type_range getResultTypes()
Definition Operation.h:436
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:386
result_range getResults()
Definition Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:412
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void viewGraph(const Twine &regionName)
Displays the CFG in a window.
BlockListType & getBlocks()
Definition Region.h:45
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indention a sequence of code.
raw_ostream & getOStream() const
Returns the underlying (unindented) raw_ostream.
Include the generated interface declarations.
std::unique_ptr<::mlir::Pass > createViewOpGraphPass()
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144