17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
38 if (std::optional<int64_t> limit =
47 llvm::raw_string_ostream os(buf);
54 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
59 return "\"" + str +
"\"";
75 Node(
int id = 0, std::optional<int> clusterId = std::nullopt)
76 : id(
id), clusterId(clusterId) {}
79 std::optional<int> clusterId;
85 class PrintOpPass :
public impl::ViewOpGraphBase<PrintOpPass> {
87 PrintOpPass(raw_ostream &os) : os(os) {}
88 PrintOpPass(
const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
90 void runOnOperation()
override {
91 initColorMapping(*getOperation());
93 processOperation(getOperation());
96 markAllAnalysesPreserved();
100 void emitRegionCFG(
Region ®ion) {
101 printControlFlowEdges =
true;
102 printDataFlowEdges =
false;
103 initColorMapping(region);
104 emitGraph([&]() { processRegion(region); });
111 template <
typename T>
112 void initColorMapping(T &irEntity) {
113 backgroundColors.clear();
116 auto &entry = backgroundColors[op->
getName()];
117 if (entry.first == 0)
122 double hue = ((double)indexedOps.index()) / ops.size();
123 backgroundColors[indexedOps.value()->getName()].second =
124 std::to_string(hue) +
" 1.0 1.0";
130 void emitAllEdgeStmts() {
131 if (printDataFlowEdges) {
132 for (
const auto &[value, node, label] : dataFlowEdges) {
137 for (
const std::string &edge : edges)
144 Node emitClusterStmt(
function_ref<
void()> builder, std::string label =
"") {
145 int clusterId = ++counter;
146 os <<
"subgraph cluster_" << clusterId <<
" {\n";
155 return Node(anchorNode.id, clusterId);
159 std::string attrStmt(
const Twine &key,
const Twine &value) {
160 return (key +
" = " + value).str();
164 void emitAttrList(raw_ostream &os,
const AttributeMap &map) {
166 interleaveComma(map, os, [&](
const auto &it) {
167 os << this->attrStmt(it.first, it.second);
173 void emitMlirAttr(raw_ostream &os,
Attribute attr) {
178 if (isa<SplatElementsAttr>(attr)) {
184 auto elements = dyn_cast<ElementsAttr>(attr);
185 if (elements && elements.getNumElements() > largeAttrLimit) {
186 os << std::string(elements.getShapedType().getRank(),
'[') <<
"..."
187 << std::string(elements.getShapedType().getRank(),
']') <<
" : "
188 << elements.getType();
192 auto array = dyn_cast<ArrayAttr>(attr);
193 if (array &&
static_cast<int64_t
>(array.size()) > largeAttrLimit) {
200 llvm::raw_string_ostream ss(buf);
202 os << truncateString(buf);
207 void emitEdgeStmt(
Node n1,
Node n2, std::string label, StringRef style) {
209 attrs[
"style"] = style.str();
213 if (!n1.clusterId && !n2.clusterId)
217 attrs[
"ltail"] =
"cluster_" + std::to_string(*n1.clusterId);
219 attrs[
"lhead"] =
"cluster_" + std::to_string(*n2.clusterId);
221 edges.push_back(
strFromOs([&](raw_ostream &os) {
222 os << llvm::format(
"v%i -> v%i ", n1.id, n2.id);
223 emitAttrList(os, attrs);
229 os <<
"digraph G {\n";
232 os << attrStmt(
"compound",
"true") <<
";\n";
239 Node emitNodeStmt(std::string label, StringRef shape =
kShapeNode,
240 StringRef background =
"") {
241 int nodeId = ++counter;
244 attrs[
"shape"] = shape.str();
245 if (!background.empty()) {
246 attrs[
"style"] =
"filled";
247 attrs[
"fillcolor"] = (
"\"" + background +
"\"").str();
249 os << llvm::format(
"v%i ", nodeId);
250 emitAttrList(os, attrs);
260 if (printResultTypes) {
263 llvm::raw_string_ostream ss(buf);
265 os << truncateString(buf) <<
")";
272 os <<
'\n' << attr.getName().getValue() <<
": ";
273 emitMlirAttr(os, attr.getValue());
286 void processBlock(
Block &block) {
287 emitClusterStmt([&]() {
289 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
292 std::optional<Node> prevNode;
294 Node nextNode = processOperation(&op);
295 if (printControlFlowEdges && prevNode)
296 emitEdgeStmt(*prevNode, nextNode,
"",
309 node = emitClusterStmt(
312 processRegion(region);
317 backgroundColors[op->
getName()].second);
321 if (printDataFlowEdges) {
323 for (
unsigned i = 0; i < numOperands; i++)
324 dataFlowEdges.push_back({op->getOperand(i), node,
325 numOperands == 1 ?
"" : std::to_string(i)});
329 valueToNode[result] = node;
335 void processRegion(
Region ®ion) {
341 std::string truncateString(std::string str) {
342 if (str.length() <= maxLabelLen)
344 return str.substr(0, maxLabelLen) +
"...";
351 std::vector<std::string> edges;
355 std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
365 return std::make_unique<PrintOpPass>(os);
371 std::string filename = llvm::createGraphFilename(name.str(), fd);
373 llvm::raw_fd_ostream os(fd,
true);
375 llvm::errs() <<
"error opening file '" << filename <<
"' for writing\n";
378 PrintOpPass pass(os);
379 pass.emitRegionCFG(region);
381 llvm::DisplayGraph(filename,
false, llvm::GraphProgram::DOT);
MemRefDependenceGraph::Node Node
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
static void llvmViewGraph(Region &region, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string escapeString(std::string str)
Escape special characters such as '\n' and quotation marks.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
NamedAttribute represents a combination of a name and an Attribute value.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
raw_ostream subclass that simplifies indenting a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.