20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/FormatVariadic.h"
37 CodeGen(raw_ostream &os) : os(os) {}
40 void generate(
const ast::Module &astModule, ModuleOp module);
43 void generate(pdl::PatternOp pattern, StringRef patternName,
48 void generateConstraintAndRewrites(
const ast::Module &astModule,
59 StringRef getNativeTypeName(
ast::Type type);
69 void CodeGen::generate(
const ast::Module &astModule, ModuleOp module) {
74 generateConstraintAndRewrites(astModule, module, nativeFunctions);
76 os <<
"namespace {\n";
77 std::string basePatternName =
"GeneratedPDLLPattern";
79 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
81 if (std::optional<StringRef> patternName = pattern.getSymName()) {
82 patternNames.insert(patternName->str());
86 name = (basePatternName + Twine(patternIndex++)).str();
87 }
while (!patternNames.insert(name));
90 generate(pattern, patternNames.back(), nativeFunctions);
92 os <<
"} // end namespace\n\n";
95 os <<
"template <typename... ConfigsT>\n"
96 "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
97 "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
98 for (
const auto &name : patternNames)
99 os <<
" patterns.add<" << name
100 <<
">(patterns.getContext(), configs...);\n";
104 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
106 const char *patternClassStartStr = R
"(
107 struct {0} : ::mlir::PDLPatternModule {{
108 template <typename... ConfigsT>
109 {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
110 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
112 os << llvm::formatv(patternClassStartStr, patternName);
116 os <<
"\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
120 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
121 if (!nativeFunctions.count(fnName) ||
122 !registeredNativeFunctions.insert(fnName).second)
124 os <<
" register" << fnType <<
"Function(\"" << fnName <<
"\", "
125 << fnName <<
"PDLFn);\n";
128 if (
auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
129 checkRegisterNativeFn(constraintOp.getName(),
"Constraint");
130 else if (
auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
131 checkRegisterNativeFn(rewriteOp.getName(),
"Rewrite");
136 void CodeGen::generateConstraintAndRewrites(
const ast::Module &astModule,
144 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
145 [&](
auto op) { usedFns.insert(op.
getName()); });
151 [&](
const auto *decl) {
154 if (decl->getCodeBlock() &&
155 usedFns.contains(decl->getName().getName()))
156 this->generate(decl, nativeFunctions);
163 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
164 true, nativeFunctions);
169 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
170 false, nativeFunctions);
173 StringRef CodeGen::getNativeTypeName(
ast::Type type) {
179 return odsOp->getNativeClassName();
180 return "::mlir::Operation *";
191 if (
auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
192 if (std::optional<StringRef> name = userCst->getNativeInputType(0))
194 return getNativeTypeName(userCst->getInputs()[0]);
199 return getNativeTypeName(decl->
getType());
206 nativeFunctions.insert(name);
217 os <<
"::llvm::LogicalResult";
223 if (results.empty()) {
225 }
else if (results.size() == 1) {
226 os << getNativeTypeName(results[0]);
230 os << getNativeTypeName(result);
236 os <<
" " << name <<
"PDLFn(::mlir::PatternRewriter &rewriter";
240 os << getNativeTypeName(input) <<
" " << input->getName().getName();
254 codegen.generate(astModule, module);
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::Print).
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
This class represents a PDLL type that corresponds to an mlir::Attribute.
This decl represents a shared interface for all callable decls.
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided implementation.
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
This class represents a PDLL type that corresponds to an mlir::TypeRange.
This class represents a PDLL type that corresponds to an mlir::Type.
This decl represents a user defined constraint.
This decl represents a user defined rewrite.
This class represents a PDLL type that corresponds to an mlir::ValueRange.
This class represents a PDLL type that corresponds to an mlir::Value.
This Decl represents the definition of a PDLL variable.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Include the generated interface declarations.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
StringRef getName() const
Return the raw string name.