MLIR  19.0.0git
NodePrinter.cpp
Go to the documentation of this file.
1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 #include <optional>
16 
17 using namespace mlir;
18 using namespace mlir::pdll::ast;
19 
20 //===----------------------------------------------------------------------===//
21 // NodePrinter
22 //===----------------------------------------------------------------------===//
23 
24 namespace {
25 class NodePrinter {
26 public:
27  NodePrinter(raw_ostream &os) : os(os) {}
28 
29  /// Print the given type to the stream.
30  void print(Type type);
31 
32  /// Print the given node to the stream.
33  void print(const Node *node);
34 
35 private:
36  /// Print a range containing children of a node.
37  template <typename RangeT,
38  std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
39  * = nullptr>
40  void printChildren(RangeT &&range) {
41  if (range.empty())
42  return;
43 
44  // Print the first N-1 elements with a prefix of "|-".
45  auto it = std::begin(range);
46  for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
47  print(*it);
48 
49  // Print the last element.
50  elementIndentStack.back() = true;
51  print(*it);
52  }
53  template <typename RangeT, typename... OthersT,
54  std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
55  * = nullptr>
56  void printChildren(RangeT &&range, OthersT &&...others) {
57  printChildren(ArrayRef<const Node *>({range, others...}));
58  }
59  /// Print a range containing children of a node, nesting the children under
60  /// the given label.
61  template <typename RangeT>
62  void printChildren(StringRef label, RangeT &&range) {
63  if (range.empty())
64  return;
65  elementIndentStack.reserve(elementIndentStack.size() + 1);
66  llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
67 
68  printIndent();
69  os << label << "`\n";
70  elementIndentStack.push_back(/*isLastElt*/ false);
71  printChildren(std::forward<RangeT>(range));
72  elementIndentStack.pop_back();
73  }
74 
75  /// Print the given derived node to the stream.
76  void printImpl(const CompoundStmt *stmt);
77  void printImpl(const EraseStmt *stmt);
78  void printImpl(const LetStmt *stmt);
79  void printImpl(const ReplaceStmt *stmt);
80  void printImpl(const ReturnStmt *stmt);
81  void printImpl(const RewriteStmt *stmt);
82 
83  void printImpl(const AttributeExpr *expr);
84  void printImpl(const CallExpr *expr);
85  void printImpl(const DeclRefExpr *expr);
86  void printImpl(const MemberAccessExpr *expr);
87  void printImpl(const OperationExpr *expr);
88  void printImpl(const RangeExpr *expr);
89  void printImpl(const TupleExpr *expr);
90  void printImpl(const TypeExpr *expr);
91 
92  void printImpl(const AttrConstraintDecl *decl);
93  void printImpl(const OpConstraintDecl *decl);
94  void printImpl(const TypeConstraintDecl *decl);
95  void printImpl(const TypeRangeConstraintDecl *decl);
96  void printImpl(const UserConstraintDecl *decl);
97  void printImpl(const ValueConstraintDecl *decl);
98  void printImpl(const ValueRangeConstraintDecl *decl);
99  void printImpl(const NamedAttributeDecl *decl);
100  void printImpl(const OpNameDecl *decl);
101  void printImpl(const PatternDecl *decl);
102  void printImpl(const UserRewriteDecl *decl);
103  void printImpl(const VariableDecl *decl);
104  void printImpl(const Module *module);
105 
106  /// Print the current indent stack.
107  void printIndent() {
108  if (elementIndentStack.empty())
109  return;
110 
111  for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
112  os << (isLastElt ? " " : " |");
113  os << (elementIndentStack.back() ? " `" : " |");
114  }
115 
116  /// The raw output stream.
117  raw_ostream &os;
118 
119  /// A stack of indents and a flag indicating if the current element being
120  /// printed at that indent is the last element.
121  SmallVector<bool> elementIndentStack;
122 };
123 } // namespace
124 
125 void NodePrinter::print(Type type) {
126  // Protect against invalid inputs.
127  if (!type) {
128  os << "Type<NULL>";
129  return;
130  }
131 
132  TypeSwitch<Type>(type)
133  .Case([&](AttributeType) { os << "Attr"; })
134  .Case([&](ConstraintType) { os << "Constraint"; })
135  .Case([&](OperationType type) {
136  os << "Op";
137  if (std::optional<StringRef> name = type.getName())
138  os << "<" << *name << ">";
139  })
140  .Case([&](RangeType type) {
141  print(type.getElementType());
142  os << "Range";
143  })
144  .Case([&](RewriteType) { os << "Rewrite"; })
145  .Case([&](TupleType type) {
146  os << "Tuple<";
147  llvm::interleaveComma(
148  llvm::zip(type.getElementNames(), type.getElementTypes()), os,
149  [&](auto it) {
150  if (!std::get<0>(it).empty())
151  os << std::get<0>(it) << ": ";
152  this->print(std::get<1>(it));
153  });
154  os << ">";
155  })
156  .Case([&](TypeType) { os << "Type"; })
157  .Case([&](ValueType) { os << "Value"; })
158  .Default([](Type) { llvm_unreachable("unknown AST type"); });
159 }
160 
161 void NodePrinter::print(const Node *node) {
162  printIndent();
163  os << "-";
164 
165  elementIndentStack.push_back(/*isLastElt*/ false);
167  .Case<
168  // Statements.
169  const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
170  const ReturnStmt, const RewriteStmt,
171 
172  // Expressions.
173  const AttributeExpr, const CallExpr, const DeclRefExpr,
174  const MemberAccessExpr, const OperationExpr, const RangeExpr,
175  const TupleExpr, const TypeExpr,
176 
177  // Decls.
182  const OpNameDecl, const PatternDecl, const UserRewriteDecl,
183  const VariableDecl,
184 
185  const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
186  .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
187  elementIndentStack.pop_back();
188 }
189 
190 void NodePrinter::printImpl(const CompoundStmt *stmt) {
191  os << "CompoundStmt " << stmt << "\n";
192  printChildren(stmt->getChildren());
193 }
194 
195 void NodePrinter::printImpl(const EraseStmt *stmt) {
196  os << "EraseStmt " << stmt << "\n";
197  printChildren(stmt->getRootOpExpr());
198 }
199 
200 void NodePrinter::printImpl(const LetStmt *stmt) {
201  os << "LetStmt " << stmt << "\n";
202  printChildren(stmt->getVarDecl());
203 }
204 
205 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
206  os << "ReplaceStmt " << stmt << "\n";
207  printChildren(stmt->getRootOpExpr());
208  printChildren("ReplValues", stmt->getReplExprs());
209 }
210 
211 void NodePrinter::printImpl(const ReturnStmt *stmt) {
212  os << "ReturnStmt " << stmt << "\n";
213  printChildren(stmt->getResultExpr());
214 }
215 
216 void NodePrinter::printImpl(const RewriteStmt *stmt) {
217  os << "RewriteStmt " << stmt << "\n";
218  printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
219 }
220 
221 void NodePrinter::printImpl(const AttributeExpr *expr) {
222  os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
223 }
224 
225 void NodePrinter::printImpl(const CallExpr *expr) {
226  os << "CallExpr " << expr << " Type<";
227  print(expr->getType());
228  os << ">";
229  if (expr->getIsNegated())
230  os << " Negated";
231  os << "\n";
232  printChildren(expr->getCallableExpr());
233  printChildren("Arguments", expr->getArguments());
234 }
235 
236 void NodePrinter::printImpl(const DeclRefExpr *expr) {
237  os << "DeclRefExpr " << expr << " Type<";
238  print(expr->getType());
239  os << ">\n";
240  printChildren(expr->getDecl());
241 }
242 
243 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
244  os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
245  << "> Type<";
246  print(expr->getType());
247  os << ">\n";
248  printChildren(expr->getParentExpr());
249 }
250 
251 void NodePrinter::printImpl(const OperationExpr *expr) {
252  os << "OperationExpr " << expr << " Type<";
253  print(expr->getType());
254  os << ">\n";
255 
256  printChildren(expr->getNameDecl());
257  printChildren("Operands", expr->getOperands());
258  printChildren("Result Types", expr->getResultTypes());
259  printChildren("Attributes", expr->getAttributes());
260 }
261 
262 void NodePrinter::printImpl(const RangeExpr *expr) {
263  os << "RangeExpr " << expr << " Type<";
264  print(expr->getType());
265  os << ">\n";
266 
267  printChildren(expr->getElements());
268 }
269 
270 void NodePrinter::printImpl(const TupleExpr *expr) {
271  os << "TupleExpr " << expr << " Type<";
272  print(expr->getType());
273  os << ">\n";
274 
275  printChildren(expr->getElements());
276 }
277 
278 void NodePrinter::printImpl(const TypeExpr *expr) {
279  os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
280 }
281 
282 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
283  os << "AttrConstraintDecl " << decl << "\n";
284  if (const auto *typeExpr = decl->getTypeExpr())
285  printChildren(typeExpr);
286 }
287 
288 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
289  os << "OpConstraintDecl " << decl << "\n";
290  printChildren(decl->getNameDecl());
291 }
292 
293 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
294  os << "TypeConstraintDecl " << decl << "\n";
295 }
296 
297 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
298  os << "TypeRangeConstraintDecl " << decl << "\n";
299 }
300 
301 void NodePrinter::printImpl(const UserConstraintDecl *decl) {
302  os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
303  << "> ResultType<" << decl->getResultType() << ">";
304  if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
305  os << " Code<";
306  llvm::printEscapedString(*codeBlock, os);
307  os << ">";
308  }
309  os << "\n";
310  printChildren("Inputs", decl->getInputs());
311  printChildren("Results", decl->getResults());
312  if (const CompoundStmt *body = decl->getBody())
313  printChildren(body);
314 }
315 
316 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
317  os << "ValueConstraintDecl " << decl << "\n";
318  if (const auto *typeExpr = decl->getTypeExpr())
319  printChildren(typeExpr);
320 }
321 
322 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
323  os << "ValueRangeConstraintDecl " << decl << "\n";
324  if (const auto *typeExpr = decl->getTypeExpr())
325  printChildren(typeExpr);
326 }
327 
328 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
329  os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
330  << ">\n";
331  printChildren(decl->getValue());
332 }
333 
334 void NodePrinter::printImpl(const OpNameDecl *decl) {
335  os << "OpNameDecl " << decl;
336  if (std::optional<StringRef> name = decl->getName())
337  os << " Name<" << *name << ">";
338  os << "\n";
339 }
340 
341 void NodePrinter::printImpl(const PatternDecl *decl) {
342  os << "PatternDecl " << decl;
343  if (const Name *name = decl->getName())
344  os << " Name<" << name->getName() << ">";
345  if (std::optional<uint16_t> benefit = decl->getBenefit())
346  os << " Benefit<" << *benefit << ">";
347  if (decl->hasBoundedRewriteRecursion())
348  os << " Recursion";
349 
350  os << "\n";
351  printChildren(decl->getBody());
352 }
353 
354 void NodePrinter::printImpl(const UserRewriteDecl *decl) {
355  os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
356  << "> ResultType<" << decl->getResultType() << ">";
357  if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
358  os << " Code<";
359  llvm::printEscapedString(*codeBlock, os);
360  os << ">";
361  }
362  os << "\n";
363  printChildren("Inputs", decl->getInputs());
364  printChildren("Results", decl->getResults());
365  if (const CompoundStmt *body = decl->getBody())
366  printChildren(body);
367 }
368 
369 void NodePrinter::printImpl(const VariableDecl *decl) {
370  os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
371  << "> Type<";
372  print(decl->getType());
373  os << ">\n";
374  if (Expr *initExpr = decl->getInitExpr())
375  printChildren(initExpr);
376 
377  auto constraints =
378  llvm::map_range(decl->getConstraints(),
379  [](const ConstraintRef &ref) { return ref.constraint; });
380  printChildren("Constraints", constraints);
381 }
382 
383 void NodePrinter::printImpl(const Module *module) {
384  os << "Module " << module << "\n";
385  printChildren(module->getChildren());
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // Entry point
390 //===----------------------------------------------------------------------===//
391 
392 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
393 
394 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:749
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
Definition: Nodes.h:755
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that attribute.
Definition: Nodes.h:367
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:373
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:131
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:390
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:397
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:400
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition: Nodes.h:408
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
This class represents a PDLL type that corresponds to a constraint.
Definition: Types.h:145
This expression represents a reference to a Decl node.
Definition: Nodes.h:434
Decl * getDecl() const
Get the decl referenced by this expression.
Definition: Nodes.h:439
This statement represents the erase statement in PDLL.
Definition: Nodes.h:254
This class represents a base AST Expression node.
Definition: Nodes.h:345
This statement represents a let statement in PDLL.
Definition: Nodes.h:211
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition: Nodes.h:216
This expression represents a named member or field access of a given parent expression.
Definition: Nodes.h:455
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition: Nodes.h:462
StringRef getMemberName() const
Return the name of the member being accessed.
Definition: Nodes.h:465
This class represents a top-level AST module.
Definition: Nodes.h:1291
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1296
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:992
Expr * getValue() const
Return value of the attribute.
Definition: Nodes.h:1001
const Name & getName() const
Return the name of the attribute.
Definition: Nodes.h:998
This class represents a base AST node.
Definition: Nodes.h:108
void print(raw_ostream &os) const
Print this node to the given stream.
The class represents an Operation constraint, and constrains a variable to be an Operation.
Definition: Nodes.h:772
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:781
This Decl represents an OperationName.
Definition: Nodes.h:1016
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1022
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:512
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:540
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition: Nodes.h:548
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:525
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:532
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:158
std::optional< StringRef > getName() const
Return the name of this operation type, or std::nullopt if it doesn't have one.
Definition: Types.cpp:81
This Decl represents a single Pattern.
Definition: Nodes.h:1037
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition: Nodes.h:1051
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition: Nodes.h:1045
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition: Nodes.h:1048
This expression builds a range from a set of element values (which may be ranges themselves).
Definition: Nodes.h:586
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition: Nodes.h:592
RangeType getType() const
Return the range result type of this expression.
Definition: Nodes.h:600
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:183
Type getElementType() const
Return the element type of this range.
Definition: Types.cpp:100
This statement represents the replace statement in PDLL.
Definition: Nodes.h:269
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition: Nodes.h:275
This statement represents a return from a "callable" like decl, e.g. a Constraint or a Rewrite.
Definition: Nodes.h:321
Expr * getResultExpr()
Return the result expression of this statement.
Definition: Nodes.h:326
This statement represents an operation rewrite that contains a block of nested rewrite commands.
Definition: Nodes.h:299
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition: Nodes.h:305
This class represents a PDLL type that corresponds to a rewrite reference.
Definition: Types.h:230
This expression builds a tuple from a set of element values.
Definition: Nodes.h:619
TupleType getType() const
Return the tuple result type of this expression.
Definition: Nodes.h:633
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition: Nodes.h:625
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
Definition: Types.h:244
ArrayRef< StringRef > getElementNames() const
Return the element names of this tuple.
Definition: Types.cpp:156
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:152
The class represents a Type constraint, and constrains a variable to be a Type.
Definition: Nodes.h:797
This expression represents a literal MLIR Type, and contains the textual assembly format of that type.
Definition: Nodes.h:648
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:654
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
Definition: Nodes.h:811
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:272
void print(raw_ostream &os) const
Print this type to the given stream.
This decl represents a user defined constraint.
Definition: Nodes.h:882
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition: Nodes.h:908
const Name & getName() const
Return the name of the constraint.
Definition: Nodes.h:905
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided implementation.
Definition: Nodes.h:930
Type getResultType() const
Return the result type of this constraint.
Definition: Nodes.h:937
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition: Nodes.h:921
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr.
Definition: Nodes.h:934
This decl represents a user defined rewrite.
Definition: Nodes.h:1092
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implementation.
Definition: Nodes.h:1136
Type getResultType() const
Return the result type of this rewrite.
Definition: Nodes.h:1143
const Name & getName() const
Return the name of the rewrite.
Definition: Nodes.h:1115
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
Definition: Nodes.h:1140
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
Definition: Nodes.h:1118
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
Definition: Nodes.h:1127
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:825
Expr * getTypeExpr()
Return the optional type the value is constrained to.
Definition: Nodes.h:830
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:847
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
Definition: Nodes.h:853
This class represents a PDLL type that corresponds to an mlir::Value.
Definition: Types.h:285
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1242
const Name & getName() const
Return the name of the decl.
Definition: Nodes.h:1261
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition: Nodes.h:1258
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1249
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1264
Include the generated interface declarations.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
Definition: Nodes.h:716
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41