21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
39 CodeGen(raw_ostream &os) : os(os) {}
42 void generate(
const ast::Module &astModule, ModuleOp module);
45 void generate(pdl::PatternOp pattern, StringRef patternName,
50 void generateConstraintAndRewrites(
const ast::Module &astModule,
61 StringRef getNativeTypeName(
ast::Type type);
71 void CodeGen::generate(
const ast::Module &astModule, ModuleOp module) {
76 generateConstraintAndRewrites(astModule, module, nativeFunctions);
78 os <<
"namespace {\n";
79 std::string basePatternName =
"GeneratedPDLLPattern";
81 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
83 if (std::optional<StringRef> patternName = pattern.getSymName()) {
84 patternNames.insert(patternName->str());
88 name = (basePatternName + Twine(patternIndex++)).str();
89 }
while (!patternNames.insert(name));
92 generate(pattern, patternNames.back(), nativeFunctions);
94 os <<
"} // end namespace\n\n";
97 os <<
"template <typename... ConfigsT>\n"
98 "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
99 "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
100 for (
const auto &name : patternNames)
101 os <<
" patterns.add<" << name
102 <<
">(patterns.getContext(), configs...);\n";
106 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
108 const char *patternClassStartStr = R
"(
109 struct {0} : ::mlir::PDLPatternModule {{
110 template <typename... ConfigsT>
111 {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
112 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
114 os << llvm::formatv(patternClassStartStr, patternName);
118 os <<
"\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
122 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
123 if (!nativeFunctions.count(fnName) ||
124 !registeredNativeFunctions.insert(fnName).second)
126 os <<
" register" << fnType <<
"Function(\"" << fnName <<
"\", "
127 << fnName <<
"PDLFn);\n";
130 if (
auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131 checkRegisterNativeFn(constraintOp.getName(),
"Constraint");
132 else if (
auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133 checkRegisterNativeFn(rewriteOp.getName(),
"Rewrite");
138 void CodeGen::generateConstraintAndRewrites(
const ast::Module &astModule,
146 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147 [&](
auto op) { usedFns.insert(op.
getName()); });
153 [&](
const auto *decl) {
156 if (decl->getCodeBlock() &&
157 usedFns.contains(decl->getName().getName()))
158 this->generate(decl, nativeFunctions);
165 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
166 true, nativeFunctions);
171 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
172 false, nativeFunctions);
175 StringRef CodeGen::getNativeTypeName(
ast::Type type) {
181 return odsOp->getNativeClassName();
182 return "::mlir::Operation *";
193 if (
auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
194 if (std::optional<StringRef> name = userCst->getNativeInputType(0))
196 return getNativeTypeName(userCst->getInputs()[0]);
201 return getNativeTypeName(decl->
getType());
208 nativeFunctions.insert(name);
219 os <<
"::llvm::LogicalResult";
225 if (results.empty()) {
227 }
else if (results.size() == 1) {
228 os << getNativeTypeName(results[0]);
232 os << getNativeTypeName(result);
238 os <<
" " << name <<
"PDLFn(::mlir::PatternRewriter &rewriter";
242 os << getNativeTypeName(input) <<
" " << input->getName().getName();
256 codegen.generate(astModule, module);
Set of flags used to control the behavior of the various IR print methods (e.g.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
This class represents a PDLL type that corresponds to an mlir::Attribute.
This decl represents a shared interface for all callable decls.
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided impleme...
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
This class represents the base Decl node.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
This class represents a PDLL type that corresponds to an mlir::Operation.
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
This class represents a PDLL type that corresponds to an mlir::TypeRange.
This class represents a PDLL type that corresponds to an mlir::Type.
This decl represents a user defined constraint.
This decl represents a user defined rewrite.
This class represents a PDLL type that corresponds to an mlir::ValueRange.
This class represents a PDLL type that corresponds to an mlir::Value.
This Decl represents the definition of a PDLL variable.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Include the generated interface declarations.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
StringRef getName() const
Return the raw string name.