MLIR  20.0.0git
CPPGen.cpp
Go to the documentation of this file.
1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a PDLL generator that outputs C++ code that defines PDLL
10 // patterns as individual C++ PDLPatternModules for direct use in native code,
11 // and also defines any native constraints whose bodies were defined in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>
28 
29 using namespace mlir;
30 using namespace mlir::pdll;
31 
32 //===----------------------------------------------------------------------===//
33 // CodeGen
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 class CodeGen {
38 public:
39  CodeGen(raw_ostream &os) : os(os) {}
40 
41  /// Generate C++ code for the given PDL pattern module.
42  void generate(const ast::Module &astModule, ModuleOp module);
43 
44 private:
45  void generate(pdl::PatternOp pattern, StringRef patternName,
46  StringSet<> &nativeFunctions);
47 
48  /// Generate C++ code for all user defined constraints and rewrites with
49  /// native code.
50  void generateConstraintAndRewrites(const ast::Module &astModule,
51  ModuleOp module,
52  StringSet<> &nativeFunctions);
53  void generate(const ast::UserConstraintDecl *decl,
54  StringSet<> &nativeFunctions);
55  void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
56  void generateConstraintOrRewrite(const ast::CallableDecl *decl,
57  bool isConstraint,
58  StringSet<> &nativeFunctions);
59 
60  /// Return the native name for the type of the given type.
61  StringRef getNativeTypeName(ast::Type type);
62 
63  /// Return the native name for the type of the given variable decl.
64  StringRef getNativeTypeName(ast::VariableDecl *decl);
65 
66  /// The stream to output to.
67  raw_ostream &os;
68 };
69 } // namespace
70 
71 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
73  StringSet<> nativeFunctions;
74 
75  // Generate code for any native functions within the module.
76  generateConstraintAndRewrites(astModule, module, nativeFunctions);
77 
78  os << "namespace {\n";
79  std::string basePatternName = "GeneratedPDLLPattern";
80  int patternIndex = 0;
81  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
82  // If the pattern has a name, use that. Otherwise, generate a unique name.
83  if (std::optional<StringRef> patternName = pattern.getSymName()) {
84  patternNames.insert(patternName->str());
85  } else {
86  std::string name;
87  do {
88  name = (basePatternName + Twine(patternIndex++)).str();
89  } while (!patternNames.insert(name));
90  }
91 
92  generate(pattern, patternNames.back(), nativeFunctions);
93  }
94  os << "} // end namespace\n\n";
95 
96  // Emit function to add the generated matchers to the pattern list.
97  os << "template <typename... ConfigsT>\n"
98  "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
99  "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
100  for (const auto &name : patternNames)
101  os << " patterns.add<" << name
102  << ">(patterns.getContext(), configs...);\n";
103  os << "}\n";
104 }
105 
106 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
107  StringSet<> &nativeFunctions) {
108  const char *patternClassStartStr = R"(
109 struct {0} : ::mlir::PDLPatternModule {{
110  template <typename... ConfigsT>
111  {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
112  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
113 )";
114  os << llvm::formatv(patternClassStartStr, patternName);
115 
116  os << "R\"mlir(";
117  pattern->print(os, OpPrintingFlags().enableDebugInfo());
118  os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
119 
120  // Register any native functions used within the pattern.
121  StringSet<> registeredNativeFunctions;
122  auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
123  if (!nativeFunctions.count(fnName) ||
124  !registeredNativeFunctions.insert(fnName).second)
125  return;
126  os << " register" << fnType << "Function(\"" << fnName << "\", "
127  << fnName << "PDLFn);\n";
128  };
129  pattern.walk([&](Operation *op) {
130  if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131  checkRegisterNativeFn(constraintOp.getName(), "Constraint");
132  else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133  checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
134  });
135  os << " }\n};\n\n";
136 }
137 
138 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
139  ModuleOp module,
140  StringSet<> &nativeFunctions) {
141  // First check to see which constraints and rewrites are actually referenced
142  // in the module.
143  StringSet<> usedFns;
144  module.walk([&](Operation *op) {
146  .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147  [&](auto op) { usedFns.insert(op.getName()); });
148  });
149 
150  for (const ast::Decl *decl : astModule.getChildren()) {
153  [&](const auto *decl) {
154  // We only generate code for inline native decls that have been
155  // referenced.
156  if (decl->getCodeBlock() &&
157  usedFns.contains(decl->getName().getName()))
158  this->generate(decl, nativeFunctions);
159  });
160  }
161 }
162 
163 void CodeGen::generate(const ast::UserConstraintDecl *decl,
164  StringSet<> &nativeFunctions) {
165  return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
166  /*isConstraint=*/true, nativeFunctions);
167 }
168 
169 void CodeGen::generate(const ast::UserRewriteDecl *decl,
170  StringSet<> &nativeFunctions) {
171  return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
172  /*isConstraint=*/false, nativeFunctions);
173 }
174 
175 StringRef CodeGen::getNativeTypeName(ast::Type type) {
177  .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
178  .Case([&](ast::OperationType opType) -> StringRef {
179  // Use the derived Op class when available.
180  if (const auto *odsOp = opType.getODSOperation())
181  return odsOp->getNativeClassName();
182  return "::mlir::Operation *";
183  })
184  .Case([&](ast::TypeType) { return "::mlir::Type"; })
185  .Case([&](ast::ValueType) { return "::mlir::Value"; })
186  .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
187  .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
188 }
189 
190 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
191  // Try to extract a type name from the variable's constraints.
192  for (ast::ConstraintRef &cst : decl->getConstraints()) {
193  if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
194  if (std::optional<StringRef> name = userCst->getNativeInputType(0))
195  return *name;
196  return getNativeTypeName(userCst->getInputs()[0]);
197  }
198  }
199 
200  // Otherwise, use the type of the variable.
201  return getNativeTypeName(decl->getType());
202 }
203 
204 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
205  bool isConstraint,
206  StringSet<> &nativeFunctions) {
207  StringRef name = decl->getName()->getName();
208  nativeFunctions.insert(name);
209 
210  os << "static ";
211 
212  // TODO: Work out a proper modeling for "optionality".
213 
214  // Emit the result type.
215  // If this is a constraint, we always return a LogicalResult.
216  // TODO: This will need to change if we allow Constraints to return values as
217  // well.
218  if (isConstraint) {
219  os << "::llvm::LogicalResult";
220  } else {
221  // Otherwise, generate a type based on the results of the callable.
222  // If the callable has explicit results, use those to build the result.
223  // Otherwise, use the type of the callable.
224  ArrayRef<ast::VariableDecl *> results = decl->getResults();
225  if (results.empty()) {
226  os << "void";
227  } else if (results.size() == 1) {
228  os << getNativeTypeName(results[0]);
229  } else {
230  os << "std::tuple<";
231  llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
232  os << getNativeTypeName(result);
233  });
234  os << ">";
235  }
236  }
237 
238  os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
239  if (!decl->getInputs().empty()) {
240  os << ", ";
241  llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
242  os << getNativeTypeName(input) << " " << input->getName().getName();
243  });
244  }
245  os << ") {\n";
246  os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // CPPGen
251 //===----------------------------------------------------------------------===//
252 
253 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
254  raw_ostream &os) {
255  CodeGen codegen(os);
256  codegen.generate(astModule, module);
257 }
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:136
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1188
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided implementation.
Definition: Nodes.h:1223
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1199
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
Definition: Nodes.h:1215
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
This class represents a top-level AST module.
Definition: Nodes.h:1291
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1296
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:163
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
This class represents a PDLL type that corresponds to an mlir::TypeRange.
Definition: Types.h:203
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:277
This decl represents a user defined constraint.
Definition: Nodes.h:882
This decl represents a user defined rewrite.
Definition: Nodes.h:1092
This class represents a PDLL type that corresponds to an mlir::ValueRange.
Definition: Types.h:218
This class represents a PDLL type that corresponds to an mlir::Value.
Definition: Types.h:290
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1242
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1249
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1264
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Definition: CPPGen.cpp:253
Include the generated interface declarations.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
Definition: Nodes.h:716
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41