MLIR  22.0.0git
NodePrinter.cpp
Go to the documentation of this file.
1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "llvm/ADT/StringExtras.h"
11 #include "llvm/ADT/TypeSwitch.h"
12 #include "llvm/Support/SaveAndRestore.h"
13 #include "llvm/Support/ScopedPrinter.h"
14 #include <optional>
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ast;
18 
19 //===----------------------------------------------------------------------===//
20 // NodePrinter
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 class NodePrinter {
25 public:
26  NodePrinter(raw_ostream &os) : os(os) {}
27 
28  /// Print the given type to the stream.
29  void print(Type type);
30 
31  /// Print the given node to the stream.
32  void print(const Node *node);
33 
34 private:
35  /// Print a range containing children of a node.
36  template <typename RangeT,
37  std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38  * = nullptr>
39  void printChildren(RangeT &&range) {
40  if (range.empty())
41  return;
42 
43  // Print the first N-1 elements with a prefix of "|-".
44  auto it = std::begin(range);
45  for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46  print(*it);
47 
48  // Print the last element.
49  elementIndentStack.back() = true;
50  print(*it);
51  }
52  template <typename RangeT, typename... OthersT,
53  std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54  * = nullptr>
55  void printChildren(RangeT &&range, OthersT &&...others) {
56  printChildren(ArrayRef<const Node *>({range, others...}));
57  }
58  /// Print a range containing children of a node, nesting the children under
59  /// the given label.
60  template <typename RangeT>
61  void printChildren(StringRef label, RangeT &&range) {
62  if (range.empty())
63  return;
64  elementIndentStack.reserve(elementIndentStack.size() + 1);
65  llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
66 
67  printIndent();
68  os << label << "`\n";
69  elementIndentStack.push_back(/*isLastElt*/ false);
70  printChildren(std::forward<RangeT>(range));
71  elementIndentStack.pop_back();
72  }
73 
74  /// Print the given derived node to the stream.
75  void printImpl(const CompoundStmt *stmt);
76  void printImpl(const EraseStmt *stmt);
77  void printImpl(const LetStmt *stmt);
78  void printImpl(const ReplaceStmt *stmt);
79  void printImpl(const ReturnStmt *stmt);
80  void printImpl(const RewriteStmt *stmt);
81 
82  void printImpl(const AttributeExpr *expr);
83  void printImpl(const CallExpr *expr);
84  void printImpl(const DeclRefExpr *expr);
85  void printImpl(const MemberAccessExpr *expr);
86  void printImpl(const OperationExpr *expr);
87  void printImpl(const RangeExpr *expr);
88  void printImpl(const TupleExpr *expr);
89  void printImpl(const TypeExpr *expr);
90 
91  void printImpl(const AttrConstraintDecl *decl);
92  void printImpl(const OpConstraintDecl *decl);
93  void printImpl(const TypeConstraintDecl *decl);
94  void printImpl(const TypeRangeConstraintDecl *decl);
95  void printImpl(const UserConstraintDecl *decl);
96  void printImpl(const ValueConstraintDecl *decl);
97  void printImpl(const ValueRangeConstraintDecl *decl);
98  void printImpl(const NamedAttributeDecl *decl);
99  void printImpl(const OpNameDecl *decl);
100  void printImpl(const PatternDecl *decl);
101  void printImpl(const UserRewriteDecl *decl);
102  void printImpl(const VariableDecl *decl);
103  void printImpl(const Module *module);
104 
105  /// Print the current indent stack.
106  void printIndent() {
107  if (elementIndentStack.empty())
108  return;
109 
110  for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
111  os << (isLastElt ? " " : " |");
112  os << (elementIndentStack.back() ? " `" : " |");
113  }
114 
115  /// The raw output stream.
116  raw_ostream &os;
117 
118  /// A stack of indents and a flag indicating if the current element being
119  /// printed at that indent is the last element.
120  SmallVector<bool> elementIndentStack;
121 };
122 } // namespace
123 
124 void NodePrinter::print(Type type) {
125  // Protect against invalid inputs.
126  if (!type) {
127  os << "Type<NULL>";
128  return;
129  }
130 
131  TypeSwitch<Type>(type)
132  .Case([&](AttributeType) { os << "Attr"; })
133  .Case([&](ConstraintType) { os << "Constraint"; })
134  .Case([&](OperationType type) {
135  os << "Op";
136  if (std::optional<StringRef> name = type.getName())
137  os << "<" << *name << ">";
138  })
139  .Case([&](RangeType type) {
140  print(type.getElementType());
141  os << "Range";
142  })
143  .Case([&](RewriteType) { os << "Rewrite"; })
144  .Case([&](TupleType type) {
145  os << "Tuple<";
146  llvm::interleaveComma(
147  llvm::zip(type.getElementNames(), type.getElementTypes()), os,
148  [&](auto it) {
149  if (!std::get<0>(it).empty())
150  os << std::get<0>(it) << ": ";
151  this->print(std::get<1>(it));
152  });
153  os << ">";
154  })
155  .Case([&](TypeType) { os << "Type"; })
156  .Case([&](ValueType) { os << "Value"; })
157  .Default([](Type) { llvm_unreachable("unknown AST type"); });
158 }
159 
160 void NodePrinter::print(const Node *node) {
161  printIndent();
162  os << "-";
163 
164  elementIndentStack.push_back(/*isLastElt*/ false);
166  .Case<
167  // Statements.
168  const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
169  const ReturnStmt, const RewriteStmt,
170 
171  // Expressions.
172  const AttributeExpr, const CallExpr, const DeclRefExpr,
173  const MemberAccessExpr, const OperationExpr, const RangeExpr,
174  const TupleExpr, const TypeExpr,
175 
176  // Decls.
181  const OpNameDecl, const PatternDecl, const UserRewriteDecl,
182  const VariableDecl,
183 
184  const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
185  .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
186  elementIndentStack.pop_back();
187 }
188 
189 void NodePrinter::printImpl(const CompoundStmt *stmt) {
190  os << "CompoundStmt " << stmt << "\n";
191  printChildren(stmt->getChildren());
192 }
193 
194 void NodePrinter::printImpl(const EraseStmt *stmt) {
195  os << "EraseStmt " << stmt << "\n";
196  printChildren(stmt->getRootOpExpr());
197 }
198 
199 void NodePrinter::printImpl(const LetStmt *stmt) {
200  os << "LetStmt " << stmt << "\n";
201  printChildren(stmt->getVarDecl());
202 }
203 
204 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
205  os << "ReplaceStmt " << stmt << "\n";
206  printChildren(stmt->getRootOpExpr());
207  printChildren("ReplValues", stmt->getReplExprs());
208 }
209 
210 void NodePrinter::printImpl(const ReturnStmt *stmt) {
211  os << "ReturnStmt " << stmt << "\n";
212  printChildren(stmt->getResultExpr());
213 }
214 
215 void NodePrinter::printImpl(const RewriteStmt *stmt) {
216  os << "RewriteStmt " << stmt << "\n";
217  printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
218 }
219 
220 void NodePrinter::printImpl(const AttributeExpr *expr) {
221  os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
222 }
223 
224 void NodePrinter::printImpl(const CallExpr *expr) {
225  os << "CallExpr " << expr << " Type<";
226  print(expr->getType());
227  os << ">";
228  if (expr->getIsNegated())
229  os << " Negated";
230  os << "\n";
231  printChildren(expr->getCallableExpr());
232  printChildren("Arguments", expr->getArguments());
233 }
234 
235 void NodePrinter::printImpl(const DeclRefExpr *expr) {
236  os << "DeclRefExpr " << expr << " Type<";
237  print(expr->getType());
238  os << ">\n";
239  printChildren(expr->getDecl());
240 }
241 
242 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
243  os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
244  << "> Type<";
245  print(expr->getType());
246  os << ">\n";
247  printChildren(expr->getParentExpr());
248 }
249 
250 void NodePrinter::printImpl(const OperationExpr *expr) {
251  os << "OperationExpr " << expr << " Type<";
252  print(expr->getType());
253  os << ">\n";
254 
255  printChildren(expr->getNameDecl());
256  printChildren("Operands", expr->getOperands());
257  printChildren("Result Types", expr->getResultTypes());
258  printChildren("Attributes", expr->getAttributes());
259 }
260 
261 void NodePrinter::printImpl(const RangeExpr *expr) {
262  os << "RangeExpr " << expr << " Type<";
263  print(expr->getType());
264  os << ">\n";
265 
266  printChildren(expr->getElements());
267 }
268 
269 void NodePrinter::printImpl(const TupleExpr *expr) {
270  os << "TupleExpr " << expr << " Type<";
271  print(expr->getType());
272  os << ">\n";
273 
274  printChildren(expr->getElements());
275 }
276 
277 void NodePrinter::printImpl(const TypeExpr *expr) {
278  os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
279 }
280 
281 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
282  os << "AttrConstraintDecl " << decl << "\n";
283  if (const auto *typeExpr = decl->getTypeExpr())
284  printChildren(typeExpr);
285 }
286 
287 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
288  os << "OpConstraintDecl " << decl << "\n";
289  printChildren(decl->getNameDecl());
290 }
291 
292 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
293  os << "TypeConstraintDecl " << decl << "\n";
294 }
295 
296 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
297  os << "TypeRangeConstraintDecl " << decl << "\n";
298 }
299 
300 void NodePrinter::printImpl(const UserConstraintDecl *decl) {
301  os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
302  << "> ResultType<" << decl->getResultType() << ">";
303  if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
304  os << " Code<";
305  llvm::printEscapedString(*codeBlock, os);
306  os << ">";
307  }
308  os << "\n";
309  printChildren("Inputs", decl->getInputs());
310  printChildren("Results", decl->getResults());
311  if (const CompoundStmt *body = decl->getBody())
312  printChildren(body);
313 }
314 
315 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
316  os << "ValueConstraintDecl " << decl << "\n";
317  if (const auto *typeExpr = decl->getTypeExpr())
318  printChildren(typeExpr);
319 }
320 
321 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
322  os << "ValueRangeConstraintDecl " << decl << "\n";
323  if (const auto *typeExpr = decl->getTypeExpr())
324  printChildren(typeExpr);
325 }
326 
327 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
328  os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
329  << ">\n";
330  printChildren(decl->getValue());
331 }
332 
333 void NodePrinter::printImpl(const OpNameDecl *decl) {
334  os << "OpNameDecl " << decl;
335  if (std::optional<StringRef> name = decl->getName())
336  os << " Name<" << *name << ">";
337  os << "\n";
338 }
339 
340 void NodePrinter::printImpl(const PatternDecl *decl) {
341  os << "PatternDecl " << decl;
342  if (const Name *name = decl->getName())
343  os << " Name<" << name->getName() << ">";
344  if (std::optional<uint16_t> benefit = decl->getBenefit())
345  os << " Benefit<" << *benefit << ">";
346  if (decl->hasBoundedRewriteRecursion())
347  os << " Recursion";
348 
349  os << "\n";
350  printChildren(decl->getBody());
351 }
352 
353 void NodePrinter::printImpl(const UserRewriteDecl *decl) {
354  os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
355  << "> ResultType<" << decl->getResultType() << ">";
356  if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
357  os << " Code<";
358  llvm::printEscapedString(*codeBlock, os);
359  os << ">";
360  }
361  os << "\n";
362  printChildren("Inputs", decl->getInputs());
363  printChildren("Results", decl->getResults());
364  if (const CompoundStmt *body = decl->getBody())
365  printChildren(body);
366 }
367 
368 void NodePrinter::printImpl(const VariableDecl *decl) {
369  os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
370  << "> Type<";
371  print(decl->getType());
372  os << ">\n";
373  if (Expr *initExpr = decl->getInitExpr())
374  printChildren(initExpr);
375 
376  auto constraints =
377  llvm::map_range(decl->getConstraints(),
378  [](const ConstraintRef &ref) { return ref.constraint; });
379  printChildren("Constraints", constraints);
380 }
381 
382 void NodePrinter::printImpl(const Module *module) {
383  os << "Module " << module << "\n";
384  printChildren(module->getChildren());
385 }
386 
387 //===----------------------------------------------------------------------===//
388 // Entry point
389 //===----------------------------------------------------------------------===//
390 
391 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
392 
393 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:750
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
Definition: Nodes.h:756
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
Definition: Nodes.h:370
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:376
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:107
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:393
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:403
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition: Nodes.h:407
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
This class represents a PDLL type that corresponds to a constraint.
Definition: Types.h:121
This expression represents a reference to a Decl node.
Definition: Nodes.h:433
Decl * getDecl() const
Get the decl referenced by this expression.
Definition: Nodes.h:438
This statement represents the erase statement in PDLL.
Definition: Nodes.h:255
This class represents a base AST Expression node.
Definition: Nodes.h:348
This statement represents a let statement in PDLL.
Definition: Nodes.h:211
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition: Nodes.h:216
This expression represents a named member or field access of a given parent expression.
Definition: Nodes.h:454
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition: Nodes.h:461
StringRef getMemberName() const
Return the name of the member being accessed.
Definition: Nodes.h:464
This class represents a top-level AST module.
Definition: Nodes.h:1297
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1302
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:998
Expr * getValue() const
Return value of the attribute.
Definition: Nodes.h:1007
const Name & getName() const
Return the name of the attribute.
Definition: Nodes.h:1004
This class represents a base AST node.
Definition: Nodes.h:108
void print(raw_ostream &os) const
Print this node to the given stream.
The class represents an Operation constraint, and constrains a variable to be an Operation.
Definition: Nodes.h:774
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:783
This Decl represents an OperationName.
Definition: Nodes.h:1022
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1028
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:512
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:540
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition: Nodes.h:548
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:525
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:532
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:134
std::optional< StringRef > getName() const
Return the name of this operation type, or std::nullopt if it doesn't have one.
Definition: Types.cpp:81
This Decl represents a single Pattern.
Definition: Nodes.h:1043
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition: Nodes.h:1057
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition: Nodes.h:1051
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition: Nodes.h:1054
This expression builds a range from a set of element values (which may be ranges themselves).
Definition: Nodes.h:586
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition: Nodes.h:592
RangeType getType() const
Return the range result type of this expression.
Definition: Nodes.h:600
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:159
Type getElementType() const
Return the element type of this range.
Definition: Types.cpp:100
This statement represents the replace statement in PDLL.
Definition: Nodes.h:271
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition: Nodes.h:277
This statement represents a return from a "callable" like decl, e.g.
Definition: Nodes.h:324
Expr * getResultExpr()
Return the result expression of this statement.
Definition: Nodes.h:329
This statement represents an operation rewrite that contains a block of nested rewrite commands.
Definition: Nodes.h:302
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition: Nodes.h:308
This class represents a PDLL type that corresponds to a rewrite reference.
Definition: Types.h:208
This expression builds a tuple from a set of element values.
Definition: Nodes.h:619
TupleType getType() const
Return the tuple result type of this expression.
Definition: Nodes.h:633
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition: Nodes.h:625
This class represents a PDLL tuple type, i.e.
Definition: Types.h:222
ArrayRef< StringRef > getElementNames() const
Return the element names of this tuple.
Definition: Types.cpp:159
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:155
The class represents a Type constraint, and constrains a variable to be a Type.
Definition: Nodes.h:800
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
Definition: Nodes.h:648
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:654
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
Definition: Nodes.h:815
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:249
void print(raw_ostream &os) const
Print this type to the given stream.
This decl represents a user defined constraint.
Definition: Nodes.h:888
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition: Nodes.h:914
const Name & getName() const
Return the name of the constraint.
Definition: Nodes.h:911
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided imp...
Definition: Nodes.h:936
Type getResultType() const
Return the result type of this constraint.
Definition: Nodes.h:943
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition: Nodes.h:927
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr...
Definition: Nodes.h:940
This decl represents a user defined rewrite.
Definition: Nodes.h:1098
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implement...
Definition: Nodes.h:1142
Type getResultType() const
Return the result type of this rewrite.
Definition: Nodes.h:1149
const Name & getName() const
Return the name of the rewrite.
Definition: Nodes.h:1121
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
Definition: Nodes.h:1146
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
Definition: Nodes.h:1124
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
Definition: Nodes.h:1133
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:830
Expr * getTypeExpr()
Return the optional type the value is constrained to.
Definition: Nodes.h:835
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:853
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
Definition: Nodes.h:859
This class represents a PDLL type that corresponds to an mlir::Value.
Definition: Types.h:262
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1248
const Name & getName() const
Return the name of the decl.
Definition: Nodes.h:1267
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition: Nodes.h:1264
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1255
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1270
Include the generated interface declarations.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:716
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41