MLIR  22.0.0git
CPPGen.cpp
Go to the documentation of this file.
1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a PDLL generator that outputs C++ code that defines PDLL
10 // patterns as individual C++ PDLPatternModules for direct use in native code,
11 // and also defines any native constraints and rewrites whose bodies were
11 // defined in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::pdll;
29 
30 //===----------------------------------------------------------------------===//
31 // CodeGen
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 class CodeGen {
36 public:
37  CodeGen(raw_ostream &os) : os(os) {}
38 
39  /// Generate C++ code for the given PDL pattern module.
40  void generate(const ast::Module &astModule, ModuleOp module);
41 
42 private:
43  void generate(pdl::PatternOp pattern, StringRef patternName,
44  StringSet<> &nativeFunctions);
45 
46  /// Generate C++ code for all user defined constraints and rewrites with
47  /// native code.
48  void generateConstraintAndRewrites(const ast::Module &astModule,
49  ModuleOp module,
50  StringSet<> &nativeFunctions);
51  void generate(const ast::UserConstraintDecl *decl,
52  StringSet<> &nativeFunctions);
53  void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
54  void generateConstraintOrRewrite(const ast::CallableDecl *decl,
55  bool isConstraint,
56  StringSet<> &nativeFunctions);
57 
58  /// Return the native name for the type of the given type.
59  StringRef getNativeTypeName(ast::Type type);
60 
61  /// Return the native name for the type of the given variable decl.
62  StringRef getNativeTypeName(ast::VariableDecl *decl);
63 
64  /// The stream to output to.
65  raw_ostream &os;
66 };
67 } // namespace
68 
69 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
71  StringSet<> nativeFunctions;
72 
73  // Generate code for any native functions within the module.
74  generateConstraintAndRewrites(astModule, module, nativeFunctions);
75 
76  os << "namespace {\n";
77  std::string basePatternName = "GeneratedPDLLPattern";
78  int patternIndex = 0;
79  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
80  // If the pattern has a name, use that. Otherwise, generate a unique name.
81  if (std::optional<StringRef> patternName = pattern.getSymName()) {
82  patternNames.insert(patternName->str());
83  } else {
84  std::string name;
85  do {
86  name = (basePatternName + Twine(patternIndex++)).str();
87  } while (!patternNames.insert(name));
88  }
89 
90  generate(pattern, patternNames.back(), nativeFunctions);
91  }
92  os << "} // end namespace\n\n";
93 
94  // Emit function to add the generated matchers to the pattern list.
95  os << "template <typename... ConfigsT>\n"
96  "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
97  "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
98  for (const auto &name : patternNames)
99  os << " patterns.add<" << name
100  << ">(patterns.getContext(), configs...);\n";
101  os << "}\n";
102 }
103 
104 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
105  StringSet<> &nativeFunctions) {
106  const char *patternClassStartStr = R"(
107 struct {0} : ::mlir::PDLPatternModule {{
108  template <typename... ConfigsT>
109  {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
110  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
111 )";
112  os << llvm::formatv(patternClassStartStr, patternName);
113 
114  os << "R\"mlir(";
115  pattern->print(os, OpPrintingFlags().enableDebugInfo());
116  os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
117 
118  // Register any native functions used within the pattern.
119  StringSet<> registeredNativeFunctions;
120  auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
121  if (!nativeFunctions.count(fnName) ||
122  !registeredNativeFunctions.insert(fnName).second)
123  return;
124  os << " register" << fnType << "Function(\"" << fnName << "\", "
125  << fnName << "PDLFn);\n";
126  };
127  pattern.walk([&](Operation *op) {
128  if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
129  checkRegisterNativeFn(constraintOp.getName(), "Constraint");
130  else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
131  checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
132  });
133  os << " }\n};\n\n";
134 }
135 
136 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
137  ModuleOp module,
138  StringSet<> &nativeFunctions) {
139  // First check to see which constraints and rewrites are actually referenced
140  // in the module.
141  StringSet<> usedFns;
142  module.walk([&](Operation *op) {
144  .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
145  [&](auto op) { usedFns.insert(op.getName()); });
146  });
147 
148  for (const ast::Decl *decl : astModule.getChildren()) {
151  [&](const auto *decl) {
152  // We only generate code for inline native decls that have been
153  // referenced.
154  if (decl->getCodeBlock() &&
155  usedFns.contains(decl->getName().getName()))
156  this->generate(decl, nativeFunctions);
157  });
158  }
159 }
160 
161 void CodeGen::generate(const ast::UserConstraintDecl *decl,
162  StringSet<> &nativeFunctions) {
163  return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
164  /*isConstraint=*/true, nativeFunctions);
165 }
166 
167 void CodeGen::generate(const ast::UserRewriteDecl *decl,
168  StringSet<> &nativeFunctions) {
169  return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
170  /*isConstraint=*/false, nativeFunctions);
171 }
172 
173 StringRef CodeGen::getNativeTypeName(ast::Type type) {
175  .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
176  .Case([&](ast::OperationType opType) -> StringRef {
177  // Use the derived Op class when available.
178  if (const auto *odsOp = opType.getODSOperation())
179  return odsOp->getNativeClassName();
180  return "::mlir::Operation *";
181  })
182  .Case([&](ast::TypeType) { return "::mlir::Type"; })
183  .Case([&](ast::ValueType) { return "::mlir::Value"; })
184  .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
185  .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
186 }
187 
188 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
189  // Try to extract a type name from the variable's constraints.
190  for (ast::ConstraintRef &cst : decl->getConstraints()) {
191  if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
192  if (std::optional<StringRef> name = userCst->getNativeInputType(0))
193  return *name;
194  return getNativeTypeName(userCst->getInputs()[0]);
195  }
196  }
197 
198  // Otherwise, use the type of the variable.
199  return getNativeTypeName(decl->getType());
200 }
201 
202 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
203  bool isConstraint,
204  StringSet<> &nativeFunctions) {
205  StringRef name = decl->getName()->getName();
206  nativeFunctions.insert(name);
207 
208  os << "static ";
209 
210  // TODO: Work out a proper modeling for "optionality".
211 
212  // Emit the result type.
213  // If this is a constraint, we always return a LogicalResult.
214  // TODO: This will need to change if we allow Constraints to return values as
215  // well.
216  if (isConstraint) {
217  os << "::llvm::LogicalResult";
218  } else {
219  // Otherwise, generate a type based on the results of the callable.
220  // If the callable has explicit results, use those to build the result.
221  // Otherwise, use the type of the callable.
222  ArrayRef<ast::VariableDecl *> results = decl->getResults();
223  if (results.empty()) {
224  os << "void";
225  } else if (results.size() == 1) {
226  os << getNativeTypeName(results[0]);
227  } else {
228  os << "std::tuple<";
229  llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
230  os << getNativeTypeName(result);
231  });
232  os << ">";
233  }
234  }
235 
236  os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
237  if (!decl->getInputs().empty()) {
238  os << ", ";
239  llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
240  os << getNativeTypeName(input) << " " << input->getName().getName();
241  });
242  }
243  os << ") {\n";
244  os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // CPPGen
249 //===----------------------------------------------------------------------===//
250 
251 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
252  raw_ostream &os) {
253  CodeGen codegen(os);
254  codegen.generate(astModule, module);
255 }
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:107
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1194
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided implementation.
Definition: Nodes.h:1229
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1205
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
Definition: Nodes.h:1221
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
This class represents a top-level AST module.
Definition: Nodes.h:1297
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1302
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:134
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
This class represents a PDLL type that corresponds to an mlir::TypeRange.
Definition: Types.h:175
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:249
This decl represents a user defined constraint.
Definition: Nodes.h:888
This decl represents a user defined rewrite.
Definition: Nodes.h:1098
This class represents a PDLL type that corresponds to an mlir::ValueRange.
Definition: Types.h:191
This class represents a PDLL type that corresponds to an mlir::Value.
Definition: Types.h:262
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1248
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1255
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1270
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Definition: CPPGen.cpp:251
Include the generated interface declarations.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
Definition: Nodes.h:716
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41