17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
38 if (std::optional<int64_t> limit =
47 llvm::raw_string_ostream os(buf);
54 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
59 return "\"" + str +
"\"";
75 Node(
int id = 0, std::optional<int> clusterId = std::nullopt)
76 : id(
id), clusterId(clusterId) {}
79 std::optional<int> clusterId;
85 class PrintOpPass :
public impl::ViewOpGraphBase<PrintOpPass> {
87 PrintOpPass(raw_ostream &os) : os(os) {}
88 PrintOpPass(
const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
90 void runOnOperation()
override {
91 initColorMapping(*getOperation());
93 processOperation(getOperation());
99 void emitRegionCFG(
Region ®ion) {
100 printControlFlowEdges =
true;
101 printDataFlowEdges =
false;
102 initColorMapping(region);
103 emitGraph([&]() { processRegion(region); });
110 template <
typename T>
111 void initColorMapping(T &irEntity) {
112 backgroundColors.clear();
115 auto &entry = backgroundColors[op->
getName()];
116 if (entry.first == 0)
121 double hue = ((double)indexedOps.index()) / ops.size();
122 backgroundColors[indexedOps.value()->getName()].second =
123 std::to_string(hue) +
" 1.0 1.0";
129 void emitAllEdgeStmts() {
130 if (printDataFlowEdges) {
131 for (
const auto &[value, node, label] : dataFlowEdges) {
136 for (
const std::string &edge : edges)
143 Node emitClusterStmt(
function_ref<
void()> builder, std::string label =
"") {
144 int clusterId = ++counter;
145 os <<
"subgraph cluster_" << clusterId <<
" {\n";
154 return Node(anchorNode.id, clusterId);
158 std::string attrStmt(
const Twine &key,
const Twine &value) {
159 return (key +
" = " + value).str();
163 void emitAttrList(raw_ostream &os,
const AttributeMap &map) {
165 interleaveComma(map, os, [&](
const auto &it) {
166 os << this->attrStmt(it.first, it.second);
172 void emitMlirAttr(raw_ostream &os,
Attribute attr) {
177 if (isa<SplatElementsAttr>(attr)) {
183 auto elements = dyn_cast<ElementsAttr>(attr);
184 if (elements && elements.getNumElements() > largeAttrLimit) {
185 os << std::string(elements.getShapedType().getRank(),
'[') <<
"..."
186 << std::string(elements.getShapedType().getRank(),
']') <<
" : "
187 << elements.getType();
191 auto array = dyn_cast<ArrayAttr>(attr);
192 if (array &&
static_cast<int64_t
>(array.size()) > largeAttrLimit) {
199 llvm::raw_string_ostream ss(buf);
201 os << truncateString(ss.str());
206 void emitEdgeStmt(
Node n1,
Node n2, std::string label, StringRef style) {
208 attrs[
"style"] = style.str();
212 if (!n1.clusterId && !n2.clusterId)
216 attrs[
"ltail"] =
"cluster_" + std::to_string(*n1.clusterId);
218 attrs[
"lhead"] =
"cluster_" + std::to_string(*n2.clusterId);
220 edges.push_back(
strFromOs([&](raw_ostream &os) {
221 os << llvm::format(
"v%i -> v%i ", n1.id, n2.id);
222 emitAttrList(os, attrs);
228 os <<
"digraph G {\n";
231 os << attrStmt(
"compound",
"true") <<
";\n";
238 Node emitNodeStmt(std::string label, StringRef shape =
kShapeNode,
239 StringRef background =
"") {
240 int nodeId = ++counter;
243 attrs[
"shape"] = shape.str();
244 if (!background.empty()) {
245 attrs[
"style"] =
"filled";
246 attrs[
"fillcolor"] = (
"\"" + background +
"\"").str();
248 os << llvm::format(
"v%i ", nodeId);
249 emitAttrList(os, attrs);
259 if (printResultTypes) {
262 llvm::raw_string_ostream ss(buf);
264 os << truncateString(ss.str()) <<
")";
271 os <<
'\n' << attr.getName().getValue() <<
": ";
272 emitMlirAttr(os, attr.getValue());
285 void processBlock(
Block &block) {
286 emitClusterStmt([&]() {
288 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
291 std::optional<Node> prevNode;
293 Node nextNode = processOperation(&op);
294 if (printControlFlowEdges && prevNode)
295 emitEdgeStmt(*prevNode, nextNode,
"",
308 node = emitClusterStmt(
311 processRegion(region);
316 backgroundColors[op->
getName()].second);
320 if (printDataFlowEdges) {
322 for (
unsigned i = 0; i < numOperands; i++)
323 dataFlowEdges.push_back({op->getOperand(i), node,
324 numOperands == 1 ?
"" : std::to_string(i)});
328 valueToNode[result] = node;
334 void processRegion(
Region ®ion) {
340 std::string truncateString(std::string str) {
341 if (str.length() <= maxLabelLen)
343 return str.substr(0, maxLabelLen) +
"...";
350 std::vector<std::string> edges;
354 std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
364 return std::make_unique<PrintOpPass>(os);
370 std::string filename = llvm::createGraphFilename(name.str(), fd);
372 llvm::raw_fd_ostream os(fd,
true);
374 llvm::errs() <<
"error opening file '" << filename <<
"' for writing\n";
377 PrintOpPass pass(os);
378 pass.emitRegionCFG(region);
380 llvm::DisplayGraph(filename,
false, llvm::GraphProgram::DOT);
MemRefDependenceGraph::Node Node
static std::string quoteString(const std::string &str)
Put quotation marks around a given string.
static const StringRef kLineStyleDataFlow
static const StringRef kLineStyleControlFlow
static int64_t getLargeAttributeSizeLimit()
Return the size limits for eliding large attributes.
std::map< std::string, std::string > AttributeMap
static const StringRef kShapeNone
static const StringRef kShapeNode
static void llvmViewGraph(Region ®ion, const Twine &name)
Generate a CFG for a region and show it in a window.
static std::string escapeString(std::string str)
Escape special characters such as ' ' and quotation marks.
static std::string strFromOs(function_ref< void(raw_ostream &)> func)
Return all values printed onto a stream as a string.
Attributes are known-constant values of operations.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
NamedAttribute represents a combination of a name and an Attribute value.
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
raw_ostream subclass that simplifies indention a sequence of code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpGraphPass(raw_ostream &os=llvm::errs())
Creates a pass to print op graphs.