17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/Support/Format.h"
19 #include "llvm/Support/GraphWriter.h"
25 #define GEN_PASS_DEF_VIEWOPGRAPH
26 #include "mlir/Transforms/Passes.h.inc"
39 if (std::optional<int64_t> limit =
48 llvm::raw_string_ostream os(buf);
55 return "\"" + str +
"\"";
63 llvm::raw_string_ostream os(buf);
65 if (llvm::is_contained({
'{',
'|',
'<',
'}',
'>',
'\n',
'"'}, c))
85 Node(
int id = 0, std::optional<int> clusterId = std::nullopt)
86 : id(
id), clusterId(clusterId) {}
89 std::optional<int> clusterId;
101 class PrintOpPass :
public impl::ViewOpGraphBase<PrintOpPass> {
103 PrintOpPass(raw_ostream &os) : os(os) {}
104 PrintOpPass(
const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
106 void runOnOperation()
override {
107 initColorMapping(*getOperation());
109 processOperation(getOperation());
112 markAllAnalysesPreserved();
116 void emitRegionCFG(
Region ®ion) {
117 printControlFlowEdges =
true;
118 printDataFlowEdges =
false;
119 initColorMapping(region);
120 emitGraph([&]() { processRegion(region); });
127 template <
typename T>
128 void initColorMapping(T &irEntity) {
129 backgroundColors.clear();
132 auto &entry = backgroundColors[op->
getName()];
133 if (entry.first == 0)
138 double hue = ((double)indexedOps.index()) / ops.size();
141 backgroundColors[indexedOps.value()->getName()].second =
142 std::to_string(hue) +
" 0.3 0.95";
148 void emitAllEdgeStmts() {
149 if (printDataFlowEdges) {
150 for (
const auto &e : dataFlowEdges) {
155 for (
const std::string &edge : edges)
162 Node emitClusterStmt(
function_ref<
void()> builder, std::string label =
"") {
163 int clusterId = ++counter;
164 os <<
"subgraph cluster_" << clusterId <<
" {\n";
168 os << attrStmt(
"label",
quoteString(label)) <<
";\n";
172 return Node(anchorNode.id, clusterId);
176 std::string attrStmt(
const Twine &key,
const Twine &value) {
177 return (key +
" = " + value).str();
181 void emitAttrList(raw_ostream &os,
const AttributeMap &map) {
183 interleaveComma(map, os, [&](
const auto &it) {
184 os << this->attrStmt(it.first, it.second);
190 void emitMlirAttr(raw_ostream &os,
Attribute attr) {
195 if (isa<SplatElementsAttr>(attr)) {
202 auto elements = dyn_cast<ElementsAttr>(attr);
203 if (elements && elements.getNumElements() > largeAttrLimit) {
204 os << std::string(elements.getShapedType().getRank(),
'[') <<
"..."
205 << std::string(elements.getShapedType().getRank(),
']') <<
" : ";
206 emitMlirType(os, elements.getType());
210 auto array = dyn_cast<ArrayAttr>(attr);
211 if (array &&
static_cast<int64_t
>(array.size()) > largeAttrLimit) {
218 llvm::raw_string_ostream ss(buf);
224 void emitMlirType(raw_ostream &os,
Type type) {
226 llvm::raw_string_ostream ss(buf);
232 void emitMlirOperand(raw_ostream &os,
Value operand) {
238 void emitEdgeStmt(
Node n1,
Node n2, std::string port, StringRef style) {
240 attrs[
"style"] = style.str();
243 attrs[
"ltail"] =
"cluster_" + std::to_string(*n1.clusterId);
245 attrs[
"lhead"] =
"cluster_" + std::to_string(*n2.clusterId);
247 edges.push_back(
strFromOs([&](raw_ostream &os) {
249 if (!port.empty() && !n1.clusterId)
251 os <<
":res" << port <<
":s";
254 if (!port.empty() && !n2.clusterId)
256 os <<
":arg" << port <<
":n";
257 emitAttrList(os, attrs);
263 os <<
"digraph G {\n";
266 os << attrStmt(
"compound",
"true") <<
";\n";
273 Node emitNodeStmt(std::string label, StringRef shape =
kShapeNode,
274 StringRef background =
"") {
275 int nodeId = ++counter;
278 attrs[
"shape"] = shape.str();
279 if (!background.empty()) {
280 attrs[
"style"] =
"filled";
281 attrs[
"fillcolor"] =
quoteString(background.str());
283 os << llvm::format(
"v%i ", nodeId);
284 emitAttrList(os, attrs);
289 std::string getValuePortName(
Value operand) {
291 auto str =
strFromOs([&](raw_ostream &os) {
295 std::replace(str.begin(), str.end(),
'%',
'_');
296 std::replace(str.begin(), str.end(),
'#',
'_');
300 std::string getClusterLabel(
Operation *op) {
304 if (printResultTypes) {
307 llvm::raw_string_ostream ss(buf);
309 os << truncateString(buf) <<
")";
317 emitMlirAttr(os, attr.getValue());
325 std::string getRecordLabel(
Operation *op) {
332 auto operandToPort = [&](Value operand) {
333 os <<
"<arg" << getValuePortName(operand) <<
"> ";
334 emitMlirOperand(os, operand);
336 interleave(op->
getOperands(), os, operandToPort,
"|");
343 if (printAttrs && !op->
getAttrs().empty()) {
346 for (const NamedAttribute &attr : op->getAttrs()) {
347 os << attr.getName().getValue() <<
": ";
348 emitMlirAttr(os, attr.getValue());
355 auto resultToPort = [&](Value result) {
356 os <<
"<res" << getValuePortName(result) <<
"> ";
357 emitMlirOperand(os, result);
358 if (printResultTypes) {
360 emitMlirType(os, result.getType());
363 interleave(op->
getResults(), os, resultToPort,
"|");
374 os <<
"<res" << getValuePortName(arg) <<
"> ";
376 if (printResultTypes) {
378 emitMlirType(os, arg.
getType());
385 void processBlock(
Block &block) {
386 emitClusterStmt([&]() {
388 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
390 std::optional<Node> prevNode;
392 Node nextNode = processOperation(&op);
393 if (printControlFlowEdges && prevNode)
406 node = emitClusterStmt(
409 processRegion(region);
411 getClusterLabel(op));
413 node = emitNodeStmt(getRecordLabel(op),
kShapeNode,
414 backgroundColors[op->
getName()].second);
418 if (printDataFlowEdges) {
420 for (
unsigned i = 0; i < numOperands; i++) {
422 dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
427 valueToNode[result] = node;
433 void processRegion(
Region ®ion) {
439 std::string truncateString(std::string str) {
440 if (str.length() <= maxLabelLen)
442 return str.substr(0, maxLabelLen) +
"...";
449 std::vector<std::string> edges;
453 std::vector<DataFlowEdge> dataFlowEdges;
463 return std::make_unique<PrintOpPass>(os);
469 std::string filename = llvm::createGraphFilename(name.str(), fd);
471 llvm::raw_fd_ostream os(fd,
true);
473 llvm::errs() <<
"error opening file '" << filename <<
"' for writing\n";
476 PrintOpPass pass(os);
477 pass.emitRegionCFG(region);
479 llvm::DisplayGraph(filename,
false, llvm::GraphProgram::DOT);
MemRefDependenceGraph::Node Node
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
static void llvmViewGraph(Region ®ion, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
NamedAttribute represents a combination of a name and an Attribute value.
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indention a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.