16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
38 if (std::optional<int64_t> limit =
47 llvm::raw_string_ostream os(buf);
54 return "\"" + str +
"\"";
62 llvm::raw_string_ostream os(buf);
64 if (llvm::is_contained({
'{',
'|',
'<',
'}',
'>',
'\n',
'"'}, c))
84 Node(
int id = 0, std::optional<int> clusterId = std::nullopt)
85 : id(
id), clusterId(clusterId) {}
88 std::optional<int> clusterId;
100 class PrintOpPass :
public impl::ViewOpGraphBase<PrintOpPass> {
102 PrintOpPass(raw_ostream &os) : os(os) {}
103 PrintOpPass(
const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
105 void runOnOperation()
override {
106 initColorMapping(*getOperation());
108 processOperation(getOperation());
111 markAllAnalysesPreserved();
115 void emitRegionCFG(
Region ®ion) {
116 printControlFlowEdges =
true;
117 printDataFlowEdges =
false;
118 initColorMapping(region);
119 emitGraph([&]() { processRegion(region); });
126 template <
typename T>
127 void initColorMapping(T &irEntity) {
128 backgroundColors.clear();
131 auto &entry = backgroundColors[op->
getName()];
132 if (entry.first == 0)
137 double hue = ((double)indexedOps.index()) / ops.size();
140 backgroundColors[indexedOps.value()->getName()].second =
141 std::to_string(hue) +
" 0.3 0.95";
147 void emitAllEdgeStmts() {
148 if (printDataFlowEdges) {
149 for (
const auto &e : dataFlowEdges) {
154 for (
const std::string &edge : edges)
162 const std::string &label =
"") {
163 int clusterId = ++counter;
164 os <<
"subgraph cluster_" << clusterId <<
" {\n";
168 os << attrStmt(
"label",
quoteString(label)) <<
";\n";
172 return Node(anchorNode.id, clusterId);
176 std::string attrStmt(
const Twine &key,
const Twine &value) {
177 return (key +
" = " + value).str();
181 void emitAttrList(raw_ostream &os,
const AttributeMap &map) {
183 interleaveComma(map, os, [&](
const auto &it) {
184 os << this->attrStmt(it.first, it.second);
190 void emitMlirAttr(raw_ostream &os,
Attribute attr) {
195 if (isa<SplatElementsAttr>(attr)) {
202 auto elements = dyn_cast<ElementsAttr>(attr);
203 if (elements && elements.getNumElements() > largeAttrLimit) {
204 os << std::string(elements.getShapedType().getRank(),
'[') <<
"..."
205 << std::string(elements.getShapedType().getRank(),
']') <<
" : ";
206 emitMlirType(os, elements.getType());
210 auto array = dyn_cast<ArrayAttr>(attr);
211 if (array &&
static_cast<int64_t
>(array.size()) > largeAttrLimit) {
218 llvm::raw_string_ostream ss(buf);
224 void emitMlirType(raw_ostream &os,
Type type) {
226 llvm::raw_string_ostream ss(buf);
232 void emitMlirOperand(raw_ostream &os,
Value operand) {
238 void emitEdgeStmt(
Node n1,
Node n2, std::string port, StringRef style) {
240 attrs[
"style"] = style.str();
243 attrs[
"ltail"] =
"cluster_" + std::to_string(*n1.clusterId);
245 attrs[
"lhead"] =
"cluster_" + std::to_string(*n2.clusterId);
247 edges.push_back(
strFromOs([&](raw_ostream &os) {
249 if (!port.empty() && !n1.clusterId)
251 os <<
":res" << port <<
":s";
254 if (!port.empty() && !n2.clusterId)
256 os <<
":arg" << port <<
":n";
257 emitAttrList(os, attrs);
263 os <<
"digraph G {\n";
266 os << attrStmt(
"compound",
"true") <<
";\n";
273 Node emitNodeStmt(
const std::string &label, StringRef shape =
kShapeNode,
274 StringRef background =
"") {
275 int nodeId = ++counter;
278 attrs[
"shape"] = shape.str();
279 if (!background.empty()) {
280 attrs[
"style"] =
"filled";
281 attrs[
"fillcolor"] =
quoteString(background.str());
283 os << llvm::format(
"v%i ", nodeId);
284 emitAttrList(os, attrs);
289 std::string getValuePortName(
Value operand) {
291 auto str =
strFromOs([&](raw_ostream &os) {
295 llvm::replace(str,
'%',
'_');
296 llvm::replace(str,
'#',
'_');
300 std::string getClusterLabel(
Operation *op) {
304 if (printResultTypes) {
307 llvm::raw_string_ostream ss(buf);
309 os << truncateString(buf) <<
")";
317 emitMlirAttr(os, attr.getValue());
325 std::string getRecordLabel(
Operation *op) {
332 auto operandToPort = [&](Value operand) {
333 os <<
"<arg" << getValuePortName(operand) <<
"> ";
334 emitMlirOperand(os, operand);
336 interleave(op->
getOperands(), os, operandToPort,
"|");
343 if (printAttrs && !op->
getAttrs().empty()) {
346 for (const NamedAttribute &attr : op->getAttrs()) {
347 os << attr.getName().getValue() <<
": ";
348 emitMlirAttr(os, attr.getValue());
355 auto resultToPort = [&](Value result) {
356 os <<
"<res" << getValuePortName(result) <<
"> ";
357 emitMlirOperand(os, result);
358 if (printResultTypes) {
360 emitMlirType(os, result.getType());
363 interleave(op->
getResults(), os, resultToPort,
"|");
374 os <<
"<res" << getValuePortName(arg) <<
"> ";
376 if (printResultTypes) {
378 emitMlirType(os, arg.
getType());
385 void processBlock(
Block &block) {
386 emitClusterStmt([&]() {
388 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
390 std::optional<Node> prevNode;
392 Node nextNode = processOperation(&op);
393 if (printControlFlowEdges && prevNode)
406 node = emitClusterStmt(
409 processRegion(region);
411 getClusterLabel(op));
413 node = emitNodeStmt(getRecordLabel(op),
kShapeNode,
414 backgroundColors[op->
getName()].second);
418 if (printDataFlowEdges) {
420 for (
unsigned i = 0; i < numOperands; i++) {
422 dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
427 valueToNode[result] = node;
433 void processRegion(
Region ®ion) {
439 std::string truncateString(std::string str) {
440 if (str.length() <= maxLabelLen)
442 return str.substr(0, maxLabelLen) +
"...";
449 std::vector<std::string> edges;
453 std::vector<DataFlowEdge> dataFlowEdges;
463 return std::make_unique<PrintOpPass>(os);
469 std::string filename = llvm::createGraphFilename(name.str(), fd);
471 llvm::raw_fd_ostream os(fd,
true);
473 llvm::errs() <<
"error opening file '" << filename <<
"' for writing\n";
476 PrintOpPass pass(os);
477 pass.emitRegionCFG(region);
479 llvm::DisplayGraph(filename,
false, llvm::GraphProgram::DOT);
MemRefDependenceGraph::Node Node
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limit above which large attributes are elided.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: "Braces, vertical bars and angle brackets must be escaped with a backslash character if you wish them to appear as a literal character."
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
NamedAttribute represents a combination of a name and an Attribute value.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indenting a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.