16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/GraphWriter.h"
23 #define GEN_PASS_DEF_VIEWOPGRAPH
24 #include "mlir/Transforms/Passes.h.inc"
37 if (std::optional<int64_t> limit =
46 llvm::raw_string_ostream os(buf);
53 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
58 return "\"" + str +
"\"";
74 Node(
int id = 0, std::optional<int> clusterId = std::nullopt)
75 : id(
id), clusterId(clusterId) {}
78 std::optional<int> clusterId;
84 class PrintOpPass :
public impl::ViewOpGraphBase<PrintOpPass> {
86 PrintOpPass(raw_ostream &os) : os(os) {}
87 PrintOpPass(
const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
89 void runOnOperation()
override {
90 initColorMapping(*getOperation());
92 processOperation(getOperation());
98 void emitRegionCFG(
Region ®ion) {
99 printControlFlowEdges =
true;
100 printDataFlowEdges =
false;
101 initColorMapping(region);
102 emitGraph([&]() { processRegion(region); });
109 template <
typename T>
110 void initColorMapping(T &irEntity) {
111 backgroundColors.clear();
114 auto &entry = backgroundColors[op->
getName()];
115 if (entry.first == 0)
120 double hue = ((double)indexedOps.index()) / ops.size();
121 backgroundColors[indexedOps.value()->getName()].second =
122 std::to_string(hue) +
" 1.0 1.0";
128 void emitAllEdgeStmts() {
129 for (
const std::string &edge : edges)
136 Node emitClusterStmt(
function_ref<
void()> builder, std::string label =
"") {
137 int clusterId = ++counter;
138 os <<
"subgraph cluster_" << clusterId <<
" {\n";
147 return Node(anchorNode.id, clusterId);
151 std::string attrStmt(
const Twine &key,
const Twine &value) {
152 return (key +
" = " + value).str();
156 void emitAttrList(raw_ostream &os,
const AttributeMap &map) {
158 interleaveComma(map, os, [&](
const auto &it) {
159 os << this->attrStmt(it.first, it.second);
165 void emitMlirAttr(raw_ostream &os,
Attribute attr) {
170 if (isa<SplatElementsAttr>(attr)) {
176 auto elements = dyn_cast<ElementsAttr>(attr);
177 if (elements && elements.getNumElements() > largeAttrLimit) {
178 os << std::string(elements.getShapedType().getRank(),
'[') <<
"..."
179 << std::string(elements.getShapedType().getRank(),
']') <<
" : "
180 << elements.getType();
184 auto array = dyn_cast<ArrayAttr>(attr);
185 if (array &&
static_cast<int64_t
>(array.size()) > largeAttrLimit) {
192 llvm::raw_string_ostream ss(buf);
194 os << truncateString(ss.str());
199 void emitEdgeStmt(
Node n1,
Node n2, std::string label, StringRef style) {
201 attrs[
"style"] = style.str();
205 if (!n1.clusterId && !n2.clusterId)
209 attrs[
"ltail"] =
"cluster_" + std::to_string(*n1.clusterId);
211 attrs[
"lhead"] =
"cluster_" + std::to_string(*n2.clusterId);
213 edges.push_back(
strFromOs([&](raw_ostream &os) {
214 os << llvm::format(
"v%i -> v%i ", n1.id, n2.id);
215 emitAttrList(os, attrs);
221 os <<
"digraph G {\n";
224 os << attrStmt(
"compound",
"true") <<
";\n";
231 Node emitNodeStmt(std::string label, StringRef shape =
kShapeNode,
232 StringRef background =
"") {
233 int nodeId = ++counter;
236 attrs[
"shape"] = shape.str();
237 if (!background.empty()) {
238 attrs[
"style"] =
"filled";
239 attrs[
"fillcolor"] = (
"\"" + background +
"\"").str();
241 os << llvm::format(
"v%i ", nodeId);
242 emitAttrList(os, attrs);
252 if (printResultTypes) {
255 llvm::raw_string_ostream ss(buf);
257 os << truncateString(ss.str()) <<
")";
264 os <<
'\n' << attr.getName().getValue() <<
": ";
265 emitMlirAttr(os, attr.getValue());
278 void processBlock(
Block &block) {
279 emitClusterStmt([&]() {
281 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
284 std::optional<Node> prevNode;
286 Node nextNode = processOperation(&op);
287 if (printControlFlowEdges && prevNode)
288 emitEdgeStmt(*prevNode, nextNode,
"",
301 node = emitClusterStmt(
304 processRegion(region);
309 backgroundColors[op->
getName()].second);
313 if (printDataFlowEdges) {
315 for (
unsigned i = 0; i < numOperands; i++)
316 emitEdgeStmt(valueToNode[op->
getOperand(i)], node,
317 numOperands == 1 ?
"" : std::to_string(i),
322 valueToNode[result] = node;
328 void processRegion(
Region ®ion) {
334 std::string truncateString(std::string str) {
335 if (str.length() <= maxLabelLen)
337 return str.substr(0, maxLabelLen) +
"...";
344 std::vector<std::string> edges;
356 return std::make_unique<PrintOpPass>(os);
362 std::string filename = llvm::createGraphFilename(name.str(), fd);
364 llvm::raw_fd_ostream os(fd,
true);
366 llvm::errs() <<
"error opening file '" << filename <<
"' for writing\n";
369 PrintOpPass pass(os);
370 pass.emitRegionCFG(region);
372 llvm::DisplayGraph(filename,
false, llvm::GraphProgram::DOT);
MemRefDependenceGraph::Node Node
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
static void llvmViewGraph(Region ®ion, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string escapeString(std::string str)
Escape special characters such as ' ' and quotation marks.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
NamedAttribute represents a combination of a name and an Attribute value.
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
raw_ostream subclass that simplifies indention a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.