20#include "llvm/ADT/StringExtras.h"
21#include "llvm/ADT/StringSet.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/ErrorHandling.h"
24#include "llvm/Support/FormatVariadic.h"
// Builds the generator around an output stream. The stream is stored in the
// member `os`, which the generate methods defined later in this file write to.
37 CodeGen(raw_ostream &os) : os(os) {}
// Emits the generated C++ for a whole module: native functions, pattern
// classes, and the pattern population helper.
40 void generate(
const ast::Module &astModule, ModuleOp module);
// Emits the C++ class for a single pdl pattern under the given class name.
// NOTE(review): the remaining parameters are cut off in this excerpt.
43 void generate(pdl::PatternOp pattern, StringRef patternName,
// Emits the native constraint/rewrite functions that the module actually uses.
// NOTE(review): the remaining parameters are cut off in this excerpt.
48 void generateConstraintAndRewrites(
const ast::Module &astModule,
// Emits the native function for a user constraint declaration.
// NOTE(review): the remaining parameters are cut off in this excerpt.
51 void generate(
const ast::UserConstraintDecl *decl,
// Emits the native function for a user rewrite declaration; `nativeFunctions`
// is the set of function names that gets updated during generation.
53 void generate(
const ast::UserRewriteDecl *decl,
StringSet<> &nativeFunctions);
// Shared implementation behind the two declaration overloads above.
// NOTE(review): the remaining parameters are cut off in this excerpt.
54 void generateConstraintOrRewrite(
const ast::CallableDecl *decl,
// Maps a PDLL AST type to the spelling of the C++ type used in generated code.
59 StringRef getNativeTypeName(ast::Type type);
// Same mapping for a variable declaration; the definition also consults the
// constraints attached to the variable before falling back to its type.
62 StringRef getNativeTypeName(ast::VariableDecl *decl);
// Emits the complete C++ output for `module`: first the native constraint and
// rewrite functions, then one pattern class per pdl pattern op wrapped in an
// anonymous namespace, and finally a populateGeneratedPDLLPatterns helper that
// adds every emitted pattern class to a RewritePatternSet.
69void CodeGen::generate(
const ast::Module &astModule, ModuleOp module) {
// Emit the native functions first. The names of the functions that get
// emitted are recorded in nativeFunctions by the callee chain.
74 generateConstraintAndRewrites(astModule, module, nativeFunctions);
// The pattern classes live in an anonymous namespace in the generated file.
76 os <<
"namespace {\n";
// Prefix used to synthesize a class name for patterns without a symbol name.
77 std::string basePatternName =
"GeneratedPDLLPattern";
79 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
// A pattern with a symbol name uses that name for its class.
81 if (std::optional<StringRef> patternName = pattern.getSymName()) {
82 patternNames.insert(patternName->str());
// Otherwise build a name from the prefix plus a running index. The loop
// repeats while insert(name) evaluates to false, presumably meaning the
// candidate name was already taken.
86 name = (basePatternName + Twine(patternIndex++)).str();
87 }
while (!patternNames.insert(name));
// The most recently inserted name is the class name for this pattern.
90 generate(pattern, patternNames.back(), nativeFunctions);
92 os <<
"} // end namespace\n\n";
// Emit the helper that registers all patterns, forwarding any extra configs
// to each pattern constructor.
95 os <<
"template <typename... ConfigsT>\n"
96 "[[maybe_unused]] static void populateGeneratedPDLLPatterns("
97 "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
98 for (
const auto &name : patternNames)
99 os <<
" patterns.add<" << name
100 <<
">(patterns.getContext(), configs...);\n";
// Emits the C++ class for one pdl pattern. The class is named `patternName`,
// derives from ::mlir::PDLPatternModule, and its constructor parses the textual
// form of the pattern (printed below with debug info enabled) back into a
// module. After the constructor header, a registration call is emitted for
// each native constraint or rewrite function that the pattern references.
104void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
106 const char *patternClassStartStr = R
"(
107struct {0} : ::mlir::PDLPatternModule {{
108 template <typename... ConfigsT>
109 {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
110 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
112 os << llvm::formatv(patternClassStartStr, patternName);
115 pattern->print(os, OpPrintingFlags().enableDebugInfo());
116 os <<
"\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
// Emits one register<fnType>Function call for fnName. The guard below skips
// names that are not in nativeFunctions and names whose insertion into
// registeredNativeFunctions reports they were already present.
120 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
121 if (!nativeFunctions.count(fnName) ||
122 !registeredNativeFunctions.insert(fnName).second)
124 os <<
" register" << fnType <<
"Function(\"" << fnName <<
"\", "
125 << fnName <<
"PDLFn);\n";
// Visit every operation nested in the pattern and register the native
// functions referenced by apply-native constraint and rewrite ops.
127 pattern.walk([&](Operation *op) {
128 if (
auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
129 checkRegisterNativeFn(constraintOp.getName(),
"Constraint");
130 else if (
auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
131 checkRegisterNativeFn(rewriteOp.getName(),
"Rewrite");
// Emits native functions only for the user constraints and rewrites that the
// PDL module actually references and that carry a code block.
136void CodeGen::generateConstraintAndRewrites(
const ast::Module &astModule,
// Pass 1: collect the names of all native constraint/rewrite functions used
// anywhere in the module.
142 module.walk([&](Operation *op) {
143 TypeSwitch<Operation *>(op)
144 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
145 [&](auto op) { usedFns.insert(op.getName()); });
// Pass 2: walk the top-level AST declarations and generate code for the
// matching user constraint and rewrite declarations.
148 for (
const ast::Decl *decl : astModule.getChildren()) {
150 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
151 [&](
const auto *decl) {
// Skip declarations without a code block, and declarations whose name
// was not collected in usedFns above.
154 if (decl->getCodeBlock() &&
155 usedFns.contains(decl->getName().getName()))
156 this->generate(decl, nativeFunctions);
// Emits the native function for a user constraint declaration by forwarding to
// generateConstraintOrRewrite with the boolean argument set to true.
// NOTE(review): the flag presumably marks the callable as a constraint; its
// parameter name is not visible in this excerpt.
161void CodeGen::generate(
const ast::UserConstraintDecl *decl,
163 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
164 true, nativeFunctions);
// Emits the native function for a user rewrite declaration by forwarding to
// generateConstraintOrRewrite with the boolean argument set to false.
// NOTE(review): the flag presumably marks the callable as a rewrite rather
// than a constraint; its parameter name is not visible in this excerpt.
167void CodeGen::generate(
const ast::UserRewriteDecl *decl,
169 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
170 false, nativeFunctions);
// Returns the spelling of the C++ type that generated code uses for the given
// PDLL AST type. Each case below maps one AST type class to a fixed string.
173StringRef CodeGen::getNativeTypeName(ast::Type type) {
174 return llvm::TypeSwitch<ast::Type, StringRef>(type)
175 .Case([&](ast::AttributeType) {
return "::mlir::Attribute"; })
// Operation types return the native class name of odsOp when that return is
// taken, and the generic operation pointer type otherwise.
// NOTE(review): the condition guarding the first return is not visible in
// this excerpt; presumably it checks that the ODS operation is known.
176 .Case([&](ast::OperationType opType) -> StringRef {
179 return odsOp->getNativeClassName();
180 return "::mlir::Operation *";
182 .Case([&](ast::TypeType) {
return "::mlir::Type"; })
183 .Case([&](ast::ValueType) {
return "::mlir::Value"; })
184 .Case([&](ast::TypeRangeType) {
return "::mlir::TypeRange"; })
185 .Case([&](ast::ValueRangeType) {
return "::mlir::ValueRange"; });
// Returns the C++ type spelling for a variable declaration. A user constraint
// attached to the variable is consulted first; if none of the early returns is
// taken, the result is derived from the declared type of the variable.
188StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
// NOTE(review): the loop that introduces `cst` is not visible in this
// excerpt; presumably it iterates over the constraints of decl.
191 if (
auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
// Query the native type recorded for the first input of the constraint.
// NOTE(review): the statement guarded by this check is not visible here.
192 if (std::optional<StringRef> name = userCst->getNativeInputType(0))
// Otherwise recurse on the first input variable of the constraint.
194 return getNativeTypeName(userCst->getInputs()[0]);
// Fallback: map the AST type of the variable itself.
199 return getNativeTypeName(decl->
getType());
// Emits the signature of the native C++ function for a user constraint or
// rewrite: the return type, the function name with a PDLFn suffix, a leading
// PatternRewriter parameter, and one parameter per declared input.
202void CodeGen::generateConstraintOrRewrite(
const ast::CallableDecl *decl,
// Record the function name so the pattern classes can register it later.
206 nativeFunctions.insert(name);
// NOTE(review): the condition selecting this return type is not visible in
// this excerpt; presumably it applies to constraints.
217 os <<
"::llvm::LogicalResult";
// The emitted return type depends on how many results are declared.
222 ArrayRef<ast::VariableDecl *> results = decl->
getResults();
223 if (results.empty()) {
225 }
else if (results.size() == 1) {
// Exactly one result: emit its native type name directly.
226 os << getNativeTypeName(results[0]);
// More than one result: emit the native type names separated by commas.
// NOTE(review): whatever wraps this list is not visible in this excerpt.
229 llvm::interleaveComma(results, os, [&](ast::VariableDecl *
result) {
230 os << getNativeTypeName(
result);
// Function name followed by the fixed rewriter parameter.
236 os <<
" " << name <<
"PDLFn(::mlir::PatternRewriter &rewriter";
// Each input becomes a parameter spelled as its native type and its name.
239 llvm::interleaveComma(decl->
getInputs(), os, [&](ast::VariableDecl *input) {
240 os << getNativeTypeName(input) <<
" " << input->getName().getName();
254 codegen.generate(astModule, module);
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided implementation.
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
This class represents a top-level AST module.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::StringSet< AllocatorTy > StringSet
llvm::TypeSwitch< T, ResultT > TypeSwitch
StringRef getName() const
Return the raw string name.