MLIR 22.0.0git
CPPGen.cpp
Go to the documentation of this file.
1//===- CPPGen.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a PDLL generator that outputs C++ code that defines PDLL
10// patterns as individual C++ PDLPatternModules for direct use in native code,
11// and also defines any native constraints whose bodies were defined in PDLL.
12//
13//===----------------------------------------------------------------------===//
14
17#include "mlir/IR/BuiltinOps.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/ADT/StringSet.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/ErrorHandling.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::pdll;
29
30//===----------------------------------------------------------------------===//
31// CodeGen
32//===----------------------------------------------------------------------===//
33
34namespace {
35class CodeGen {
36public:
37 CodeGen(raw_ostream &os) : os(os) {}
38
39 /// Generate C++ code for the given PDL pattern module.
40 void generate(const ast::Module &astModule, ModuleOp module);
41
42private:
43 void generate(pdl::PatternOp pattern, StringRef patternName,
44 StringSet<> &nativeFunctions);
45
46 /// Generate C++ code for all user defined constraints and rewrites with
47 /// native code.
48 void generateConstraintAndRewrites(const ast::Module &astModule,
49 ModuleOp module,
50 StringSet<> &nativeFunctions);
51 void generate(const ast::UserConstraintDecl *decl,
52 StringSet<> &nativeFunctions);
53 void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
54 void generateConstraintOrRewrite(const ast::CallableDecl *decl,
55 bool isConstraint,
56 StringSet<> &nativeFunctions);
57
58 /// Return the native name for the type of the given type.
59 StringRef getNativeTypeName(ast::Type type);
60
61 /// Return the native name for the type of the given variable decl.
62 StringRef getNativeTypeName(ast::VariableDecl *decl);
63
64 /// The stream to output to.
65 raw_ostream &os;
66};
67} // namespace
68
69void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
71 StringSet<> nativeFunctions;
72
73 // Generate code for any native functions within the module.
74 generateConstraintAndRewrites(astModule, module, nativeFunctions);
75
76 os << "namespace {\n";
77 std::string basePatternName = "GeneratedPDLLPattern";
78 int patternIndex = 0;
79 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
80 // If the pattern has a name, use that. Otherwise, generate a unique name.
81 if (std::optional<StringRef> patternName = pattern.getSymName()) {
82 patternNames.insert(patternName->str());
83 } else {
84 std::string name;
85 do {
86 name = (basePatternName + Twine(patternIndex++)).str();
87 } while (!patternNames.insert(name));
88 }
89
90 generate(pattern, patternNames.back(), nativeFunctions);
91 }
92 os << "} // end namespace\n\n";
93
94 // Emit function to add the generated matchers to the pattern list.
95 os << "template <typename... ConfigsT>\n"
96 "[[maybe_unused]] static void populateGeneratedPDLLPatterns("
97 "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
98 for (const auto &name : patternNames)
99 os << " patterns.add<" << name
100 << ">(patterns.getContext(), configs...);\n";
101 os << "}\n";
102}
103
104void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
105 StringSet<> &nativeFunctions) {
106 const char *patternClassStartStr = R"(
107struct {0} : ::mlir::PDLPatternModule {{
108 template <typename... ConfigsT>
109 {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
110 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
111)";
112 os << llvm::formatv(patternClassStartStr, patternName);
113
114 os << "R\"mlir(";
115 pattern->print(os, OpPrintingFlags().enableDebugInfo());
116 os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
117
118 // Register any native functions used within the pattern.
119 StringSet<> registeredNativeFunctions;
120 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
121 if (!nativeFunctions.count(fnName) ||
122 !registeredNativeFunctions.insert(fnName).second)
123 return;
124 os << " register" << fnType << "Function(\"" << fnName << "\", "
125 << fnName << "PDLFn);\n";
126 };
127 pattern.walk([&](Operation *op) {
128 if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
129 checkRegisterNativeFn(constraintOp.getName(), "Constraint");
130 else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
131 checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
132 });
133 os << " }\n};\n\n";
134}
135
136void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
137 ModuleOp module,
138 StringSet<> &nativeFunctions) {
139 // First check to see which constraints and rewrites are actually referenced
140 // in the module.
141 StringSet<> usedFns;
142 module.walk([&](Operation *op) {
143 TypeSwitch<Operation *>(op)
144 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
145 [&](auto op) { usedFns.insert(op.getName()); });
146 });
147
148 for (const ast::Decl *decl : astModule.getChildren()) {
150 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
151 [&](const auto *decl) {
152 // We only generate code for inline native decls that have been
153 // referenced.
154 if (decl->getCodeBlock() &&
155 usedFns.contains(decl->getName().getName()))
156 this->generate(decl, nativeFunctions);
157 });
158 }
159}
160
161void CodeGen::generate(const ast::UserConstraintDecl *decl,
162 StringSet<> &nativeFunctions) {
163 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
164 /*isConstraint=*/true, nativeFunctions);
165}
166
167void CodeGen::generate(const ast::UserRewriteDecl *decl,
168 StringSet<> &nativeFunctions) {
169 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
170 /*isConstraint=*/false, nativeFunctions);
171}
172
173StringRef CodeGen::getNativeTypeName(ast::Type type) {
174 return llvm::TypeSwitch<ast::Type, StringRef>(type)
175 .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
176 .Case([&](ast::OperationType opType) -> StringRef {
177 // Use the derived Op class when available.
178 if (const auto *odsOp = opType.getODSOperation())
179 return odsOp->getNativeClassName();
180 return "::mlir::Operation *";
181 })
182 .Case([&](ast::TypeType) { return "::mlir::Type"; })
183 .Case([&](ast::ValueType) { return "::mlir::Value"; })
184 .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
185 .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
186}
187
188StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
189 // Try to extract a type name from the variable's constraints.
190 for (ast::ConstraintRef &cst : decl->getConstraints()) {
191 if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
192 if (std::optional<StringRef> name = userCst->getNativeInputType(0))
193 return *name;
194 return getNativeTypeName(userCst->getInputs()[0]);
195 }
196 }
197
198 // Otherwise, use the type of the variable.
199 return getNativeTypeName(decl->getType());
200}
201
202void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
203 bool isConstraint,
204 StringSet<> &nativeFunctions) {
205 StringRef name = decl->getName()->getName();
206 nativeFunctions.insert(name);
207
208 os << "static ";
209
210 // TODO: Work out a proper modeling for "optionality".
211
212 // Emit the result type.
213 // If this is a constraint, we always return a LogicalResult.
214 // TODO: This will need to change if we allow Constraints to return values as
215 // well.
216 if (isConstraint) {
217 os << "::llvm::LogicalResult";
218 } else {
219 // Otherwise, generate a type based on the results of the callable.
220 // If the callable has explicit results, use those to build the result.
221 // Otherwise, use the type of the callable.
222 ArrayRef<ast::VariableDecl *> results = decl->getResults();
223 if (results.empty()) {
224 os << "void";
225 } else if (results.size() == 1) {
226 os << getNativeTypeName(results[0]);
227 } else {
228 os << "std::tuple<";
229 llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
230 os << getNativeTypeName(result);
231 });
232 os << ">";
233 }
234 }
235
236 os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
237 if (!decl->getInputs().empty()) {
238 os << ", ";
239 llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
240 os << getNativeTypeName(input) << " " << input->getName().getName();
241 });
242 }
243 os << ") {\n";
244 os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
245}
247//===----------------------------------------------------------------------===//
248// CPPGen
249//===----------------------------------------------------------------------===//
250
251void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
252 raw_ostream &os) {
253 CodeGen codegen(os);
254 codegen.generate(astModule, module);
255}
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided impleme...
Definition Nodes.h:1229
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
Definition Nodes.h:1221
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition Nodes.h:1205
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition Nodes.h:672
This class represents a top-level AST module.
Definition Nodes.h:1297
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition Types.cpp:87
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition Nodes.h:1255
Type getType() const
Return the type of the decl.
Definition Nodes.h:1270
void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os)
Definition CPPGen.cpp:246
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::StringSet< AllocatorTy > StringSet
Definition LLVM.h:133
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
StringRef getName() const
Return the raw string name.
Definition Nodes.h:41