16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
38 if (std::optional<int64_t> limit =
47 llvm::raw_string_ostream os(buf);
54 return "\"" + str +
"\"";
62 llvm::raw_string_ostream os(buf);
64 if (llvm::is_contained({
'{',
'|',
'<',
'}',
'>',
'\n',
'"'}, c))
84 Node(
int id = 0, std::optional<int> clusterId = std::nullopt)
85 : id(
id), clusterId(clusterId) {}
88 std::optional<int> clusterId;
100 class PrintOpPass :
public impl::ViewOpGraphBase<PrintOpPass> {
102 PrintOpPass(raw_ostream &os) : os(os) {}
103 PrintOpPass(
const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
105 void runOnOperation()
override {
106 initColorMapping(*getOperation());
108 processOperation(getOperation());
111 markAllAnalysesPreserved();
115 void emitRegionCFG(
Region ®ion) {
116 printControlFlowEdges =
true;
117 printDataFlowEdges =
false;
118 initColorMapping(region);
119 emitGraph([&]() { processRegion(region); });
126 template <
typename T>
127 void initColorMapping(T &irEntity) {
128 backgroundColors.clear();
131 auto &entry = backgroundColors[op->
getName()];
132 if (entry.first == 0)
137 double hue = ((double)indexedOps.index()) / ops.size();
140 backgroundColors[indexedOps.value()->getName()].second =
141 std::to_string(hue) +
" 0.3 0.95";
147 void emitAllEdgeStmts() {
148 if (printDataFlowEdges) {
149 for (
const auto &e : dataFlowEdges) {
154 for (
const std::string &edge : edges)
161 Node emitClusterStmt(
function_ref<
void()> builder, std::string label =
"") {
162 int clusterId = ++counter;
163 os <<
"subgraph cluster_" << clusterId <<
" {\n";
167 os << attrStmt(
"label",
quoteString(label)) <<
";\n";
171 return Node(anchorNode.id, clusterId);
175 std::string attrStmt(
const Twine &key,
const Twine &value) {
176 return (key +
" = " + value).str();
180 void emitAttrList(raw_ostream &os,
const AttributeMap &map) {
182 interleaveComma(map, os, [&](
const auto &it) {
183 os << this->attrStmt(it.first, it.second);
189 void emitMlirAttr(raw_ostream &os,
Attribute attr) {
194 if (isa<SplatElementsAttr>(attr)) {
201 auto elements = dyn_cast<ElementsAttr>(attr);
202 if (elements && elements.getNumElements() > largeAttrLimit) {
203 os << std::string(elements.getShapedType().getRank(),
'[') <<
"..."
204 << std::string(elements.getShapedType().getRank(),
']') <<
" : ";
205 emitMlirType(os, elements.getType());
209 auto array = dyn_cast<ArrayAttr>(attr);
210 if (array &&
static_cast<int64_t
>(array.size()) > largeAttrLimit) {
217 llvm::raw_string_ostream ss(buf);
223 void emitMlirType(raw_ostream &os,
Type type) {
225 llvm::raw_string_ostream ss(buf);
231 void emitMlirOperand(raw_ostream &os,
Value operand) {
237 void emitEdgeStmt(
Node n1,
Node n2, std::string port, StringRef style) {
239 attrs[
"style"] = style.str();
242 attrs[
"ltail"] =
"cluster_" + std::to_string(*n1.clusterId);
244 attrs[
"lhead"] =
"cluster_" + std::to_string(*n2.clusterId);
246 edges.push_back(
strFromOs([&](raw_ostream &os) {
248 if (!port.empty() && !n1.clusterId)
250 os <<
":res" << port <<
":s";
253 if (!port.empty() && !n2.clusterId)
255 os <<
":arg" << port <<
":n";
256 emitAttrList(os, attrs);
262 os <<
"digraph G {\n";
265 os << attrStmt(
"compound",
"true") <<
";\n";
272 Node emitNodeStmt(std::string label, StringRef shape =
kShapeNode,
273 StringRef background =
"") {
274 int nodeId = ++counter;
277 attrs[
"shape"] = shape.str();
278 if (!background.empty()) {
279 attrs[
"style"] =
"filled";
280 attrs[
"fillcolor"] =
quoteString(background.str());
282 os << llvm::format(
"v%i ", nodeId);
283 emitAttrList(os, attrs);
288 std::string getValuePortName(
Value operand) {
290 auto str =
strFromOs([&](raw_ostream &os) {
294 llvm::replace(str,
'%',
'_');
295 llvm::replace(str,
'#',
'_');
299 std::string getClusterLabel(
Operation *op) {
303 if (printResultTypes) {
306 llvm::raw_string_ostream ss(buf);
308 os << truncateString(buf) <<
")";
316 emitMlirAttr(os, attr.getValue());
324 std::string getRecordLabel(
Operation *op) {
331 auto operandToPort = [&](Value operand) {
332 os <<
"<arg" << getValuePortName(operand) <<
"> ";
333 emitMlirOperand(os, operand);
335 interleave(op->
getOperands(), os, operandToPort,
"|");
342 if (printAttrs && !op->
getAttrs().empty()) {
345 for (const NamedAttribute &attr : op->getAttrs()) {
346 os << attr.getName().getValue() <<
": ";
347 emitMlirAttr(os, attr.getValue());
354 auto resultToPort = [&](Value result) {
355 os <<
"<res" << getValuePortName(result) <<
"> ";
356 emitMlirOperand(os, result);
357 if (printResultTypes) {
359 emitMlirType(os, result.getType());
362 interleave(op->
getResults(), os, resultToPort,
"|");
373 os <<
"<res" << getValuePortName(arg) <<
"> ";
375 if (printResultTypes) {
377 emitMlirType(os, arg.
getType());
384 void processBlock(
Block &block) {
385 emitClusterStmt([&]() {
387 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
389 std::optional<Node> prevNode;
391 Node nextNode = processOperation(&op);
392 if (printControlFlowEdges && prevNode)
405 node = emitClusterStmt(
408 processRegion(region);
410 getClusterLabel(op));
412 node = emitNodeStmt(getRecordLabel(op),
kShapeNode,
413 backgroundColors[op->
getName()].second);
417 if (printDataFlowEdges) {
419 for (
unsigned i = 0; i < numOperands; i++) {
421 dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
426 valueToNode[result] = node;
432 void processRegion(
Region ®ion) {
438 std::string truncateString(std::string str) {
439 if (str.length() <= maxLabelLen)
441 return str.substr(0, maxLabelLen) +
"...";
448 std::vector<std::string> edges;
452 std::vector<DataFlowEdge> dataFlowEdges;
462 return std::make_unique<PrintOpPass>(os);
468 std::string filename = llvm::createGraphFilename(name.str(), fd);
470 llvm::raw_fd_ostream os(fd,
true);
472 llvm::errs() <<
"error opening file '" << filename <<
"' for writing\n";
475 PrintOpPass pass(os);
476 pass.emitRegionCFG(region);
478 llvm::DisplayGraph(filename,
false, llvm::GraphProgram::DOT);
MemRefDependenceGraph::Node Node
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
std::string escapeLabelString(const std::string &str)
For Graphviz record nodes: " Braces, vertical bars and angle brackets must be escaped with a backslas...
static void llvmViewGraph(Region ®ion, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
NamedAttribute represents a combination of a name and an Attribute value.
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void print(raw_ostream &os) const
Print the current type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
raw_ostream subclass that simplifies indention a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.