MLIR 22.0.0git
NodePrinter.cpp
Go to the documentation of this file.
//===- NodePrinter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
10#include "llvm/ADT/StringExtras.h"
11#include "llvm/ADT/TypeSwitch.h"
12#include "llvm/Support/SaveAndRestore.h"
13#include "llvm/Support/ScopedPrinter.h"
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::pdll::ast;
18
//===----------------------------------------------------------------------===//
// NodePrinter
//===----------------------------------------------------------------------===//
22
23namespace {
24class NodePrinter {
25public:
26 NodePrinter(raw_ostream &os) : os(os) {}
27
28 /// Print the given type to the stream.
29 void print(Type type);
30
31 /// Print the given node to the stream.
32 void print(const Node *node);
33
34private:
35 /// Print a range containing children of a node.
36 template <typename RangeT,
37 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38 * = nullptr>
39 void printChildren(RangeT &&range) {
40 if (range.empty())
41 return;
42
43 // Print the first N-1 elements with a prefix of "|-".
44 auto it = std::begin(range);
45 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46 print(*it);
47
48 // Print the last element.
49 elementIndentStack.back() = true;
50 print(*it);
51 }
52 template <typename RangeT, typename... OthersT,
53 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54 * = nullptr>
55 void printChildren(RangeT &&range, OthersT &&...others) {
56 printChildren(ArrayRef<const Node *>({range, others...}));
57 }
58 /// Print a range containing children of a node, nesting the children under
59 /// the given label.
60 template <typename RangeT>
61 void printChildren(StringRef label, RangeT &&range) {
62 if (range.empty())
63 return;
64 elementIndentStack.reserve(elementIndentStack.size() + 1);
65 llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
66
67 printIndent();
68 os << label << "`\n";
69 elementIndentStack.push_back(/*isLastElt*/ false);
70 printChildren(std::forward<RangeT>(range));
71 elementIndentStack.pop_back();
72 }
73
74 /// Print the given derived node to the stream.
75 void printImpl(const CompoundStmt *stmt);
76 void printImpl(const EraseStmt *stmt);
77 void printImpl(const LetStmt *stmt);
78 void printImpl(const ReplaceStmt *stmt);
79 void printImpl(const ReturnStmt *stmt);
80 void printImpl(const RewriteStmt *stmt);
81
82 void printImpl(const AttributeExpr *expr);
83 void printImpl(const CallExpr *expr);
84 void printImpl(const DeclRefExpr *expr);
85 void printImpl(const MemberAccessExpr *expr);
86 void printImpl(const OperationExpr *expr);
87 void printImpl(const RangeExpr *expr);
88 void printImpl(const TupleExpr *expr);
89 void printImpl(const TypeExpr *expr);
90
91 void printImpl(const AttrConstraintDecl *decl);
92 void printImpl(const OpConstraintDecl *decl);
93 void printImpl(const TypeConstraintDecl *decl);
94 void printImpl(const TypeRangeConstraintDecl *decl);
95 void printImpl(const UserConstraintDecl *decl);
96 void printImpl(const ValueConstraintDecl *decl);
97 void printImpl(const ValueRangeConstraintDecl *decl);
98 void printImpl(const NamedAttributeDecl *decl);
99 void printImpl(const OpNameDecl *decl);
100 void printImpl(const PatternDecl *decl);
101 void printImpl(const UserRewriteDecl *decl);
102 void printImpl(const VariableDecl *decl);
103 void printImpl(const Module *module);
104
105 /// Print the current indent stack.
106 void printIndent() {
107 if (elementIndentStack.empty())
108 return;
109
110 for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
111 os << (isLastElt ? " " : " |");
112 os << (elementIndentStack.back() ? " `" : " |");
113 }
114
115 /// The raw output stream.
116 raw_ostream &os;
117
118 /// A stack of indents and a flag indicating if the current element being
119 /// printed at that indent is the last element.
120 SmallVector<bool> elementIndentStack;
121};
122} // namespace
123
124void NodePrinter::print(Type type) {
125 // Protect against invalid inputs.
126 if (!type) {
127 os << "Type<NULL>";
128 return;
129 }
130
131 TypeSwitch<Type>(type)
132 .Case([&](AttributeType) { os << "Attr"; })
133 .Case([&](ConstraintType) { os << "Constraint"; })
134 .Case([&](OperationType type) {
135 os << "Op";
136 if (std::optional<StringRef> name = type.getName())
137 os << "<" << *name << ">";
138 })
139 .Case([&](RangeType type) {
140 print(type.getElementType());
141 os << "Range";
142 })
143 .Case([&](RewriteType) { os << "Rewrite"; })
144 .Case([&](TupleType type) {
145 os << "Tuple<";
146 llvm::interleaveComma(
147 llvm::zip(type.getElementNames(), type.getElementTypes()), os,
148 [&](auto it) {
149 if (!std::get<0>(it).empty())
150 os << std::get<0>(it) << ": ";
151 this->print(std::get<1>(it));
152 });
153 os << ">";
154 })
155 .Case([&](TypeType) { os << "Type"; })
156 .Case([&](ValueType) { os << "Value"; })
157 .DefaultUnreachable("unknown AST type");
158}
159
160void NodePrinter::print(const Node *node) {
161 printIndent();
162 os << "-";
163
164 elementIndentStack.push_back(/*isLastElt*/ false);
166 .Case<
167 // Statements.
168 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
169 const ReturnStmt, const RewriteStmt,
170
171 // Expressions.
172 const AttributeExpr, const CallExpr, const DeclRefExpr,
173 const MemberAccessExpr, const OperationExpr, const RangeExpr,
174 const TupleExpr, const TypeExpr,
175
176 // Decls.
177 const AttrConstraintDecl, const OpConstraintDecl,
178 const TypeConstraintDecl, const TypeRangeConstraintDecl,
179 const UserConstraintDecl, const ValueConstraintDecl,
180 const ValueRangeConstraintDecl, const NamedAttributeDecl,
181 const OpNameDecl, const PatternDecl, const UserRewriteDecl,
182 const VariableDecl,
183
184 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
185 .DefaultUnreachable("unknown AST node");
186 elementIndentStack.pop_back();
187}
188
189void NodePrinter::printImpl(const CompoundStmt *stmt) {
190 os << "CompoundStmt " << stmt << "\n";
191 printChildren(stmt->getChildren());
192}
193
194void NodePrinter::printImpl(const EraseStmt *stmt) {
195 os << "EraseStmt " << stmt << "\n";
196 printChildren(stmt->getRootOpExpr());
197}
198
199void NodePrinter::printImpl(const LetStmt *stmt) {
200 os << "LetStmt " << stmt << "\n";
201 printChildren(stmt->getVarDecl());
202}
203
204void NodePrinter::printImpl(const ReplaceStmt *stmt) {
205 os << "ReplaceStmt " << stmt << "\n";
206 printChildren(stmt->getRootOpExpr());
207 printChildren("ReplValues", stmt->getReplExprs());
208}
209
210void NodePrinter::printImpl(const ReturnStmt *stmt) {
211 os << "ReturnStmt " << stmt << "\n";
212 printChildren(stmt->getResultExpr());
213}
214
215void NodePrinter::printImpl(const RewriteStmt *stmt) {
216 os << "RewriteStmt " << stmt << "\n";
217 printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
218}
219
220void NodePrinter::printImpl(const AttributeExpr *expr) {
221 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
222}
223
224void NodePrinter::printImpl(const CallExpr *expr) {
225 os << "CallExpr " << expr << " Type<";
226 print(expr->getType());
227 os << ">";
228 if (expr->getIsNegated())
229 os << " Negated";
230 os << "\n";
231 printChildren(expr->getCallableExpr());
232 printChildren("Arguments", expr->getArguments());
233}
234
235void NodePrinter::printImpl(const DeclRefExpr *expr) {
236 os << "DeclRefExpr " << expr << " Type<";
237 print(expr->getType());
238 os << ">\n";
239 printChildren(expr->getDecl());
240}
241
242void NodePrinter::printImpl(const MemberAccessExpr *expr) {
243 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
244 << "> Type<";
245 print(expr->getType());
246 os << ">\n";
247 printChildren(expr->getParentExpr());
248}
249
250void NodePrinter::printImpl(const OperationExpr *expr) {
251 os << "OperationExpr " << expr << " Type<";
252 print(expr->getType());
253 os << ">\n";
254
255 printChildren(expr->getNameDecl());
256 printChildren("Operands", expr->getOperands());
257 printChildren("Result Types", expr->getResultTypes());
258 printChildren("Attributes", expr->getAttributes());
259}
260
261void NodePrinter::printImpl(const RangeExpr *expr) {
262 os << "RangeExpr " << expr << " Type<";
263 print(expr->getType());
264 os << ">\n";
265
266 printChildren(expr->getElements());
267}
268
269void NodePrinter::printImpl(const TupleExpr *expr) {
270 os << "TupleExpr " << expr << " Type<";
271 print(expr->getType());
272 os << ">\n";
273
274 printChildren(expr->getElements());
275}
276
277void NodePrinter::printImpl(const TypeExpr *expr) {
278 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
279}
280
281void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
282 os << "AttrConstraintDecl " << decl << "\n";
283 if (const auto *typeExpr = decl->getTypeExpr())
284 printChildren(typeExpr);
285}
286
287void NodePrinter::printImpl(const OpConstraintDecl *decl) {
288 os << "OpConstraintDecl " << decl << "\n";
289 printChildren(decl->getNameDecl());
290}
291
292void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
293 os << "TypeConstraintDecl " << decl << "\n";
294}
295
296void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
297 os << "TypeRangeConstraintDecl " << decl << "\n";
298}
299
300void NodePrinter::printImpl(const UserConstraintDecl *decl) {
301 os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
302 << "> ResultType<" << decl->getResultType() << ">";
303 if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
304 os << " Code<";
305 llvm::printEscapedString(*codeBlock, os);
306 os << ">";
307 }
308 os << "\n";
309 printChildren("Inputs", decl->getInputs());
310 printChildren("Results", decl->getResults());
311 if (const CompoundStmt *body = decl->getBody())
312 printChildren(body);
313}
314
315void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
316 os << "ValueConstraintDecl " << decl << "\n";
317 if (const auto *typeExpr = decl->getTypeExpr())
318 printChildren(typeExpr);
319}
320
321void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
322 os << "ValueRangeConstraintDecl " << decl << "\n";
323 if (const auto *typeExpr = decl->getTypeExpr())
324 printChildren(typeExpr);
325}
326
327void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
328 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
329 << ">\n";
330 printChildren(decl->getValue());
331}
332
333void NodePrinter::printImpl(const OpNameDecl *decl) {
334 os << "OpNameDecl " << decl;
335 if (std::optional<StringRef> name = decl->getName())
336 os << " Name<" << *name << ">";
337 os << "\n";
338}
339
340void NodePrinter::printImpl(const PatternDecl *decl) {
341 os << "PatternDecl " << decl;
342 if (const Name *name = decl->getName())
343 os << " Name<" << name->getName() << ">";
344 if (std::optional<uint16_t> benefit = decl->getBenefit())
345 os << " Benefit<" << *benefit << ">";
346 if (decl->hasBoundedRewriteRecursion())
347 os << " Recursion";
348
349 os << "\n";
350 printChildren(decl->getBody());
351}
352
353void NodePrinter::printImpl(const UserRewriteDecl *decl) {
354 os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
355 << "> ResultType<" << decl->getResultType() << ">";
356 if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
357 os << " Code<";
358 llvm::printEscapedString(*codeBlock, os);
359 os << ">";
360 }
361 os << "\n";
362 printChildren("Inputs", decl->getInputs());
363 printChildren("Results", decl->getResults());
364 if (const CompoundStmt *body = decl->getBody())
365 printChildren(body);
366}
367
368void NodePrinter::printImpl(const VariableDecl *decl) {
369 os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
370 << "> Type<";
371 print(decl->getType());
372 os << ">\n";
373 if (Expr *initExpr = decl->getInitExpr())
374 printChildren(initExpr);
375
376 auto constraints =
377 llvm::map_range(decl->getConstraints(),
378 [](const ConstraintRef &ref) { return ref.constraint; });
379 printChildren("Constraints", constraints);
380}
381
382void NodePrinter::printImpl(const Module *module) {
383 os << "Module " << module << "\n";
384 printChildren(module->getChildren());
385}
386
387//===----------------------------------------------------------------------===//
388// Entry point
389//===----------------------------------------------------------------------===//
390
391void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
392
393void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
MemRefDependenceGraph::Node Node
Definition Utils.cpp:38
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
Definition Nodes.h:756
StringRef getValue() const
Get the raw value of this expression.
Definition Nodes.h:376
Expr * getCallableExpr() const
Return the callable of this call.
Definition Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition Nodes.h:403
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition Nodes.h:407
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition Nodes.h:185
Decl * getDecl() const
Get the decl referenced by this expression.
Definition Nodes.h:438
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition Nodes.h:672
Type getType() const
Return the type of this expression.
Definition Nodes.h:351
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition Nodes.h:216
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition Nodes.h:461
StringRef getMemberName() const
Return the name of the member being accessed.
Definition Nodes.h:464
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition Nodes.h:1302
const Name & getName() const
Return the name of the attribute.
Definition Nodes.h:1004
Expr * getValue() const
Return value of the attribute.
Definition Nodes.h:1007
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition Nodes.h:783
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition Nodes.h:1028
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition Nodes.h:237
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition Nodes.h:532
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition Nodes.h:548
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition Nodes.h:525
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition Nodes.h:540
std::optional< StringRef > getName() const
Return the name of this operation type, or std::nullopt if it doesn't have one.
Definition Types.cpp:81
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition Nodes.h:1057
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition Nodes.h:1051
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition Nodes.h:1054
RangeType getType() const
Return the range result type of this expression.
Definition Nodes.h:600
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition Nodes.h:592
Type getElementType() const
Return the element type of this range.
Definition Types.cpp:100
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition Nodes.h:277
Expr * getResultExpr()
Return the result expression of this statement.
Definition Nodes.h:329
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition Nodes.h:308
TupleType getType() const
Return the tuple result type of this expression.
Definition Nodes.h:633
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition Nodes.h:625
ArrayRef< StringRef > getElementNames() const
Return the element names of this tuple.
Definition Types.cpp:159
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition Types.cpp:155
StringRef getValue() const
Get the raw value of this expression.
Definition Nodes.h:654
void print(raw_ostream &os) const
Print this type to the given stream.
const Name & getName() const
Return the name of the constraint.
Definition Nodes.h:911
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr.
Definition Nodes.h:940
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition Nodes.h:927
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition Nodes.h:914
Type getResultType() const
Return the result type of this constraint.
Definition Nodes.h:943
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided implementation.
Definition Nodes.h:936
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implementation.
Definition Nodes.h:1142
const Name & getName() const
Return the name of the rewrite.
Definition Nodes.h:1121
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
Definition Nodes.h:1146
Type getResultType() const
Return the result type of this rewrite.
Definition Nodes.h:1149
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
Definition Nodes.h:1133
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
Definition Nodes.h:1124
Expr * getTypeExpr()
Return the optional type the value is constrained to.
Definition Nodes.h:835
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
Definition Nodes.h:859
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition Nodes.h:1255
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition Nodes.h:1264
const Name & getName() const
Return the name of the decl.
Definition Nodes.h:1267
Type getType() const
Return the type of the decl.
Definition Nodes.h:1270
Include the generated interface declarations.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
StringRef getName() const
Return the raw string name.
Definition Nodes.h:41