12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
25 #include "Templates/Header.txt"
34 struct DialectStrings {
35 std::string dialectName;
36 std::string dialectCppName;
37 std::string dialectCppShortName;
38 std::string dialectBaseTypeName;
40 std::string namespaceOpen;
41 std::string namespaceClose;
42 std::string namespacePath;
48 std::string typeCppName;
54 std::string opCppName;
60 std::string nameArray;
61 llvm::raw_string_ostream nameArrayStream(nameArray);
62 nameArrayStream <<
"{\"" << llvm::join(names,
"\", \"") <<
"\"}";
68 static std::string typeToCppName(irdl::TypeOp type) {
69 return llvm::formatv(
"{0}Type",
70 convertToCamelFromSnakeCase(type.getSymName(),
true));
74 static std::string opToCppName(irdl::OperationOp op) {
75 return llvm::formatv(
"{0}Op",
76 convertToCamelFromSnakeCase(op.getSymName(),
true));
80 static TypeStrings getStrings(irdl::TypeOp type) {
82 strings.typeName = type.getSymName();
83 strings.typeCppName = typeToCppName(type);
88 static OpStrings getStrings(irdl::OperationOp op) {
89 auto operandOp = op.getOp<irdl::OperandsOp>();
91 auto resultOp = op.getOp<irdl::ResultsOp>();
94 strings.opName = op.getSymName();
95 strings.opCppName = opToCppName(op);
99 llvm::map_range(operandOp->getNames(), [](
Attribute attr) {
100 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
106 llvm::map_range(resultOp->getNames(), [](
Attribute attr) {
107 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
116 const TypeStrings &strings) {
117 dict[
"TYPE_NAME"] = strings.typeName;
118 dict[
"TYPE_CPP_NAME"] = strings.typeCppName;
123 const auto operandCount = strings.opOperandNames.size();
124 const auto resultCount = strings.opResultNames.size();
126 dict[
"OP_NAME"] = strings.opName;
127 dict[
"OP_CPP_NAME"] = strings.opCppName;
128 dict[
"OP_OPERAND_COUNT"] = std::to_string(strings.opOperandNames.size());
129 dict[
"OP_RESULT_COUNT"] = std::to_string(strings.opResultNames.size());
130 dict[
"OP_OPERAND_INITIALIZER_LIST"] =
131 operandCount ? joinNameList(strings.opOperandNames) :
"{\"\"}";
132 dict[
"OP_RESULT_INITIALIZER_LIST"] =
133 resultCount ? joinNameList(strings.opResultNames) :
"{\"\"}";
138 const DialectStrings &strings) {
139 dict[
"DIALECT_NAME"] = strings.dialectName;
140 dict[
"DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName;
141 dict[
"DIALECT_CPP_NAME"] = strings.dialectCppName;
142 dict[
"DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName;
143 dict[
"NAMESPACE_OPEN"] = strings.namespaceOpen;
144 dict[
"NAMESPACE_CLOSE"] = strings.namespaceClose;
145 dict[
"NAMESPACE_PATH"] = strings.namespacePath;
148 static LogicalResult generateTypedefList(irdl::DialectOp &dialect,
150 auto typeOps = dialect.getOps<irdl::TypeOp>();
151 auto range = llvm::map_range(typeOps, typeToCppName);
156 static LogicalResult generateOpList(irdl::DialectOp &dialect,
158 auto operationOps = dialect.getOps<irdl::OperationOp>();
159 auto range = llvm::map_range(operationOps, opToCppName);
169 #include
"Templates/TypeDecl.txt"
172 fillDict(dict, getStrings(type));
173 typeDeclTemplate.render(output, dict);
179 const OpStrings &opStrings) {
180 auto opGetters = std::string{};
181 auto resGetters = std::string{};
183 for (
size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
185 llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i],
true);
186 opGetters += llvm::formatv(
"::mlir::Value get{0}() { return "
187 "getStructuredOperands({1}).front(); }\n ",
190 for (
size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) {
192 llvm::convertToCamelFromSnakeCase(opStrings.opResultNames[i],
true);
193 resGetters += llvm::formatv(
194 R
"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); }
199 dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
200 dict[
"OP_RESULT_GETTER_DECLS"] = resGetters;
204 const OpStrings &opStrings) {
205 std::string buildDecls;
206 llvm::raw_string_ostream stream{buildDecls};
209 llvm::join(llvm::map_range(opStrings.opResultNames,
210 [](StringRef name) -> std::string {
211 return llvm::formatv(
212 "::mlir::Type {0}, ",
213 llvm::convertToCamelFromSnakeCase(name));
218 llvm::join(llvm::map_range(opStrings.opOperandNames,
219 [](StringRef name) -> std::string {
220 return llvm::formatv(
221 "::mlir::Value {0}, ",
222 llvm::convertToCamelFromSnakeCase(name));
226 stream << llvm::formatv(
227 R
"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
228 resultParams, operandParams);
229 dict["OP_BUILD_DECLS"] = buildDecls;
236 #include
"Templates/PerOperationDecl.txt"
238 const auto opStrings = getStrings(op);
239 fillDict(dict, opStrings);
244 perOpDeclTemplate.render(output, dict);
250 DialectStrings &dialectStrings) {
252 #include
"Templates/DialectDecl.txt"
255 #include
"Templates/TypeHeaderDecl.txt"
259 fillDict(dict, dialectStrings);
261 dialectDeclTemplate.render(output, dict);
262 typeHeaderDeclTemplate.render(output, dict);
264 auto typeOps = dialect.getOps<irdl::TypeOp>();
265 auto operationOps = dialect.getOps<irdl::OperationOp>();
267 for (
auto &&typeOp : typeOps) {
273 if (failed(generateOpList(dialect, opNames)))
276 auto classDeclarations =
277 llvm::join(llvm::map_range(opNames,
278 [](llvm::StringRef name) -> std::string {
279 return llvm::formatv(
"class {0};", name);
282 const auto forwardDeclarations = llvm::formatv(
283 "{1}\n{0}\n{2}", std::move(classDeclarations),
284 dialectStrings.namespaceOpen, dialectStrings.namespaceClose);
286 output << forwardDeclarations;
287 for (
auto &&operationOp : operationOps) {
296 irdl::OperationOp op) {
298 #include "Templates/PerOperationDef.txt"
301 auto opStrings = getStrings(op);
302 fillDict(dict, opStrings);
304 const auto operandCount = opStrings.opOperandNames.size();
305 const auto operandNames =
306 operandCount ? joinNameList(opStrings.opOperandNames) :
"{\"\"}";
308 const auto resultNames = joinNameList(opStrings.opResultNames);
310 auto resultTypes = llvm::join(
311 llvm::map_range(opStrings.opResultNames,
312 [](StringRef attr) -> std::string {
313 return llvm::formatv(
"::mlir::Type {0}, ", attr);
316 auto operandTypes = llvm::join(
317 llvm::map_range(opStrings.opOperandNames,
318 [](StringRef attr) -> std::string {
319 return llvm::formatv(
"::mlir::Value {0}, ", attr);
323 llvm::join(llvm::map_range(opStrings.opOperandNames,
324 [](StringRef attr) -> std::string {
325 return llvm::formatv(
326 " opState.addOperands({0});", attr);
329 auto resultAdder = llvm::join(
330 llvm::map_range(opStrings.opResultNames,
331 [](StringRef attr) -> std::string {
332 return llvm::formatv(
" opState.addTypes({0});", attr);
336 const auto buildDefinition = llvm::formatv(
338 void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
343 opStrings.opCppName, std::move(resultTypes), std::move(operandTypes),
344 std::move(operandAdder), std::move(resultAdder));
346 dict["OP_BUILD_DEFS"] = buildDefinition;
349 llvm::raw_string_ostream stream{str};
350 perOpDefTemplate.render(stream, dict);
356 return llvm::formatv(
357 R
"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
358 value = {1}::{0}::get(parser.getContext());
359 return ::mlir::success(!!value);
361 name, dialectStrings.namespacePath);
364 static LogicalResult
generateLib(irdl::DialectOp dialect, raw_ostream &output,
365 DialectStrings &dialectStrings) {
368 #include "Templates/TypeHeaderDef.txt"
371 #include "Templates/TypeDef.txt"
374 #include "Templates/DialectDef.txt"
378 fillDict(dict, dialectStrings);
380 typeHeaderDefTemplate.render(output, dict);
383 if (failed(generateTypedefList(dialect, typeNames)))
386 dict[
"TYPE_LIST"] = llvm::join(
387 llvm::map_range(typeNames,
388 [&dialectStrings](llvm::StringRef name) -> std::string {
389 return llvm::formatv(
390 "{0}::{1}", dialectStrings.namespacePath, name);
394 auto typeVerifierGenerator =
395 [&dialectStrings](llvm::StringRef name) -> std::string {
400 llvm::join(llvm::map_range(typeNames, typeVerifierGenerator),
"\n");
402 dict[
"TYPE_PARSER"] = llvm::formatv(
403 R
"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
404 return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
406 .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
411 std::move(typeCase));
414 llvm::join(llvm::map_range(typeNames,
415 [&](llvm::StringRef name) -> std::string {
416 return llvm::formatv(
417 R
"(.Case<{1}::{0}>([&](auto t) {
418 printer << {1}::{0}::getMnemonic();
419 return ::mlir::success();
421 name, dialectStrings.namespacePath);
424 dict[
"TYPE_PRINTER"] = llvm::formatv(
425 R
"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
426 return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def)
428 .Default([](auto) {{ return ::mlir::failure(); });
430 std::move(typePrintCase));
432 dict["TYPE_DEFINES"] =
433 join(map_range(typeNames,
434 [&](StringRef name) -> std::string {
435 return formatv(
"MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
436 name, dialectStrings.namespacePath);
440 typeDefTemplate.render(output, dict);
442 auto operations = dialect.getOps<irdl::OperationOp>();
444 if (failed(generateOpList(dialect, opNames)))
447 const auto commaSeparatedOpList = llvm::join(
449 [&dialectStrings](llvm::StringRef name) -> std::string {
450 return llvm::formatv(
"{0}::{1}", dialectStrings.namespacePath,
455 const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) {
459 const auto perOpDefinitions =
460 llvm::join(llvm::map_range(operations, opDefinitionGenerator),
"\n");
462 dict[
"OP_LIST"] = commaSeparatedOpList;
463 dict[
"OP_CLASSES"] = perOpDefinitions;
464 output << perOpDefinitions;
465 dialectDefTemplate.render(output, dict);
471 LogicalResult res = success();
475 .Case<irdl::DialectOp>(([](irdl::DialectOp) {
return success(); }))
476 .Case<irdl::OperationOp>(
477 ([](irdl::OperationOp) {
return success(); }))
478 .Case<irdl::TypeOp>(([](irdl::TypeOp) {
return success(); }))
479 .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
481 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
482 return attr.getValue() == irdl::Variadicity::single;
485 return op.emitError(
"IRDL C++ translation does not yet support "
486 "variadic operations");
488 .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
490 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
491 return attr.getValue() == irdl::Variadicity::single;
495 "IRDL C++ translation does not yet support variadic results");
497 .Case<irdl::AnyOp>(([](irdl::AnyOp) {
return success(); }))
499 return op->
emitError(
"IRDL C++ translation does not yet support "
501 << op->
getName() <<
" operation";
515 raw_ostream &output) {
516 static const auto typeDefTempl = detail::Template(
517 #include
"Templates/TypeDef.txt"
520 llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable;
522 for (
auto dialect : dialects) {
526 StringRef dialectName = dialect.getSymName();
529 std::string namespaceOpen;
530 std::string namespaceClose;
531 std::string namespacePath;
532 llvm::raw_string_ostream namespaceOpenStream(namespaceOpen);
533 llvm::raw_string_ostream namespaceCloseStream(namespaceClose);
534 llvm::raw_string_ostream namespacePathStream(namespacePath);
535 for (
auto &pathElement : namespaceAbsolutePath) {
536 namespaceOpenStream <<
"namespace " << pathElement <<
" {\n";
537 namespaceCloseStream <<
"} // namespace " << pathElement <<
"\n";
538 namespacePathStream <<
"::" << pathElement;
541 std::string cppShortName =
542 llvm::convertToCamelFromSnakeCase(dialectName,
true);
543 std::string dialectBaseTypeName = llvm::formatv(
"{0}Type", cppShortName);
544 std::string cppName = llvm::formatv(
"{0}Dialect", cppShortName);
546 DialectStrings dialectStrings;
547 dialectStrings.dialectName = dialectName;
548 dialectStrings.dialectBaseTypeName = dialectBaseTypeName;
549 dialectStrings.dialectCppName = cppName;
550 dialectStrings.dialectCppShortName = cppShortName;
551 dialectStrings.namespaceOpen = namespaceOpen;
552 dialectStrings.namespaceClose = namespaceClose;
553 dialectStrings.namespacePath = namespacePath;
555 dialectStringTable[dialect] = std::move(dialectStrings);
562 for (
auto dialect : dialects) {
564 auto &dialectStrings = dialectStringTable[dialect];
565 auto &dialectName = dialectStrings.dialectName;
568 return dialect->emitError(
"Error in Dialect " + dialectName +
569 " while generating headers");
574 for (
auto &dialect : dialects) {
575 auto &dialectStrings = dialectStringTable[dialect];
576 auto &dialectName = dialectStrings.dialectName;
578 if (failed(
generateLib(dialect, output, dialectStrings)))
579 return dialect->emitError(
"Error in Dialect " + dialectName +
580 " while generating library");
static LogicalResult verifySupported(irdl::DialectOp dialect)
static LogicalResult generateInclude(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict)
static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
constexpr char declarationMacroFlag[]
constexpr char headerTemplateText[]
static std::string generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings)
static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, irdl::detail::dictionary &dict)
static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
constexpr char definitionMacroFlag[]
static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
static WalkResult advance()
static WalkResult interrupt()
Template Code as used by IRDL-to-Cpp.
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
LogicalResult translateIRDLDialectToCpp(llvm::ArrayRef< irdl::DialectOp > dialects, raw_ostream &output)
Translates an IRDL dialect definition to a C++ definition that can be used with MLIR.
Include the generated interface declarations.