12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
25 #include "Templates/Header.txt"
34 struct DialectStrings {
35 std::string dialectName;
36 std::string dialectCppName;
37 std::string dialectCppShortName;
38 std::string dialectBaseTypeName;
40 std::string namespaceOpen;
41 std::string namespaceClose;
42 std::string namespacePath;
48 std::string typeCppName;
54 std::string opCppName;
60 std::string nameArray;
61 llvm::raw_string_ostream nameArrayStream(nameArray);
62 nameArrayStream <<
"{\"" << llvm::join(names,
"\", \"") <<
"\"}";
68 static std::string typeToCppName(irdl::TypeOp type) {
69 return llvm::formatv(
"{0}Type",
70 convertToCamelFromSnakeCase(type.getSymName(),
true));
74 static std::string opToCppName(irdl::OperationOp op) {
75 return llvm::formatv(
"{0}Op",
76 convertToCamelFromSnakeCase(op.getSymName(),
true));
80 static TypeStrings getStrings(irdl::TypeOp type) {
82 strings.typeName = type.getSymName();
83 strings.typeCppName = typeToCppName(type);
88 static OpStrings getStrings(irdl::OperationOp op) {
89 auto operandOp = op.getOp<irdl::OperandsOp>();
91 auto resultOp = op.getOp<irdl::ResultsOp>();
94 strings.opName = op.getSymName();
95 strings.opCppName = opToCppName(op);
99 llvm::map_range(operandOp->getNames(), [](
Attribute attr) {
100 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
106 llvm::map_range(resultOp->getNames(), [](
Attribute attr) {
107 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
116 const TypeStrings &strings) {
117 dict[
"TYPE_NAME"] = strings.typeName;
118 dict[
"TYPE_CPP_NAME"] = strings.typeCppName;
123 const auto operandCount = strings.opOperandNames.size();
124 const auto resultCount = strings.opResultNames.size();
126 dict[
"OP_NAME"] = strings.opName;
127 dict[
"OP_CPP_NAME"] = strings.opCppName;
128 dict[
"OP_OPERAND_COUNT"] = std::to_string(strings.opOperandNames.size());
129 dict[
"OP_RESULT_COUNT"] = std::to_string(strings.opResultNames.size());
130 dict[
"OP_OPERAND_INITIALIZER_LIST"] =
131 operandCount ? joinNameList(strings.opOperandNames) :
"{\"\"}";
132 dict[
"OP_RESULT_INITIALIZER_LIST"] =
133 resultCount ? joinNameList(strings.opResultNames) :
"{\"\"}";
138 const DialectStrings &strings) {
139 dict[
"DIALECT_NAME"] = strings.dialectName;
140 dict[
"DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName;
141 dict[
"DIALECT_CPP_NAME"] = strings.dialectCppName;
142 dict[
"DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName;
143 dict[
"NAMESPACE_OPEN"] = strings.namespaceOpen;
144 dict[
"NAMESPACE_CLOSE"] = strings.namespaceClose;
145 dict[
"NAMESPACE_PATH"] = strings.namespacePath;
148 static LogicalResult generateTypedefList(irdl::DialectOp &dialect,
150 auto typeOps = dialect.getOps<irdl::TypeOp>();
151 auto range = llvm::map_range(typeOps, typeToCppName);
156 static LogicalResult generateOpList(irdl::DialectOp &dialect,
158 auto operationOps = dialect.getOps<irdl::OperationOp>();
159 auto range = llvm::map_range(operationOps, opToCppName);
169 #include
"Templates/TypeDecl.txt"
172 fillDict(dict, getStrings(type));
173 typeDeclTemplate.render(output, dict);
179 const OpStrings &opStrings) {
180 auto opGetters = std::string{};
181 auto resGetters = std::string{};
183 for (
size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
185 llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i],
true);
186 opGetters += llvm::formatv(
"::mlir::Value get{0}() { return "
187 "getStructuredOperands({1}).front(); }\n ",
190 for (
size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) {
192 llvm::convertToCamelFromSnakeCase(opStrings.opResultNames[i],
true);
193 resGetters += llvm::formatv(
194 R
"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); }
199 dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
200 dict[
"OP_RESULT_GETTER_DECLS"] = resGetters;
204 const OpStrings &opStrings) {
205 std::string buildDecls;
206 llvm::raw_string_ostream stream{buildDecls};
209 llvm::join(llvm::map_range(opStrings.opResultNames,
210 [](StringRef name) -> std::string {
211 return llvm::formatv(
212 "::mlir::Type {0}, ",
213 llvm::convertToCamelFromSnakeCase(name));
218 llvm::join(llvm::map_range(opStrings.opOperandNames,
219 [](StringRef name) -> std::string {
220 return llvm::formatv(
221 "::mlir::Value {0}, ",
222 llvm::convertToCamelFromSnakeCase(name));
226 stream << llvm::formatv(
227 R
"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
228 resultParams, operandParams);
230 stream << llvm::formatv(
231 R
"(static {0} create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
232 opStrings.opCppName, resultParams, operandParams);
234 stream << llvm::formatv(
235 R
"(static {0} create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
236 opStrings.opCppName, resultParams, operandParams);
238 dict[
"OP_BUILD_DECLS"] = buildDecls;
245 #include
"Templates/PerOperationDecl.txt"
247 const auto opStrings = getStrings(op);
248 fillDict(dict, opStrings);
253 perOpDeclTemplate.render(output, dict);
259 DialectStrings &dialectStrings) {
261 #include
"Templates/DialectDecl.txt"
264 #include
"Templates/TypeHeaderDecl.txt"
268 fillDict(dict, dialectStrings);
270 dialectDeclTemplate.render(output, dict);
271 typeHeaderDeclTemplate.render(output, dict);
273 auto typeOps = dialect.getOps<irdl::TypeOp>();
274 auto operationOps = dialect.getOps<irdl::OperationOp>();
276 for (
auto &&typeOp : typeOps) {
282 if (
failed(generateOpList(dialect, opNames)))
285 auto classDeclarations =
286 llvm::join(llvm::map_range(opNames,
287 [](llvm::StringRef name) -> std::string {
288 return llvm::formatv(
"class {0};", name);
291 const auto forwardDeclarations = llvm::formatv(
292 "{1}\n{0}\n{2}", std::move(classDeclarations),
293 dialectStrings.namespaceOpen, dialectStrings.namespaceClose);
295 output << forwardDeclarations;
296 for (
auto &&operationOp : operationOps) {
305 irdl::OperationOp op) {
307 #include "Templates/PerOperationDef.txt"
310 auto opStrings = getStrings(op);
311 fillDict(dict, opStrings);
313 const auto operandCount = opStrings.opOperandNames.size();
314 const auto operandNames =
315 operandCount ? joinNameList(opStrings.opOperandNames) :
"{\"\"}";
317 const auto resultNames = joinNameList(opStrings.opResultNames);
319 auto resultTypes = llvm::join(
320 llvm::map_range(opStrings.opResultNames,
321 [](StringRef attr) -> std::string {
322 return llvm::formatv(
"::mlir::Type {0}, ", attr);
325 auto operandTypes = llvm::join(
326 llvm::map_range(opStrings.opOperandNames,
327 [](StringRef attr) -> std::string {
328 return llvm::formatv(
"::mlir::Value {0}, ", attr);
332 llvm::join(llvm::map_range(opStrings.opOperandNames,
333 [](StringRef attr) -> std::string {
334 return llvm::formatv(
335 " opState.addOperands({0});", attr);
338 auto resultAdder = llvm::join(
339 llvm::map_range(opStrings.opResultNames,
340 [](StringRef attr) -> std::string {
341 return llvm::formatv(
" opState.addTypes({0});", attr);
345 const auto buildDefinition = llvm::formatv(
347 void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
352 {0} {0}::create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
353 ::mlir::OperationState __state__(location, getOperationName());
354 build(opBuilder, __state__, {5} {6} attributes);
355 auto __res__ = ::llvm::dyn_cast<{0}>(opBuilder.create(__state__));
356 assert(__res__ && "builder didn't return the right type");
360 {0} {0}::create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
361 return create(opBuilder, opBuilder.getLoc(), {5} {6} attributes);
364 opStrings.opCppName, std::move(resultTypes), std::move(operandTypes),
365 std::move(operandAdder), std::move(resultAdder),
366 llvm::join(opStrings.opResultNames, ",") +
367 (!opStrings.opResultNames.empty() ?
"," :
""),
368 llvm::join(opStrings.opOperandNames,
",") +
369 (!opStrings.opOperandNames.empty() ?
"," :
""));
371 dict[
"OP_BUILD_DEFS"] = buildDefinition;
374 llvm::raw_string_ostream stream{str};
375 perOpDefTemplate.render(stream, dict);
381 return llvm::formatv(
382 R
"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
383 value = {1}::{0}::get(parser.getContext());
384 return ::mlir::success(!!value);
386 name, dialectStrings.namespacePath);
389 static LogicalResult
generateLib(irdl::DialectOp dialect, raw_ostream &output,
390 DialectStrings &dialectStrings) {
393 #include "Templates/TypeHeaderDef.txt"
396 #include "Templates/TypeDef.txt"
399 #include "Templates/DialectDef.txt"
403 fillDict(dict, dialectStrings);
405 typeHeaderDefTemplate.render(output, dict);
408 if (
failed(generateTypedefList(dialect, typeNames)))
411 dict[
"TYPE_LIST"] = llvm::join(
412 llvm::map_range(typeNames,
413 [&dialectStrings](llvm::StringRef name) -> std::string {
414 return llvm::formatv(
415 "{0}::{1}", dialectStrings.namespacePath, name);
419 auto typeVerifierGenerator =
420 [&dialectStrings](llvm::StringRef name) -> std::string {
425 llvm::join(llvm::map_range(typeNames, typeVerifierGenerator),
"\n");
427 dict[
"TYPE_PARSER"] = llvm::formatv(
428 R
"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
429 return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
431 .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
436 std::move(typeCase));
439 llvm::join(llvm::map_range(typeNames,
440 [&](llvm::StringRef name) -> std::string {
441 return llvm::formatv(
442 R
"(.Case<{1}::{0}>([&](auto t) {
443 printer << {1}::{0}::getMnemonic();
444 return ::mlir::success();
446 name, dialectStrings.namespacePath);
449 dict[
"TYPE_PRINTER"] = llvm::formatv(
450 R
"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
451 return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def)
453 .Default([](auto) {{ return ::mlir::failure(); });
455 std::move(typePrintCase));
457 dict["TYPE_DEFINES"] =
458 join(map_range(typeNames,
459 [&](StringRef name) -> std::string {
460 return formatv(
"MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
461 name, dialectStrings.namespacePath);
465 typeDefTemplate.render(output, dict);
467 auto operations = dialect.getOps<irdl::OperationOp>();
469 if (
failed(generateOpList(dialect, opNames)))
472 const auto commaSeparatedOpList = llvm::join(
474 [&dialectStrings](llvm::StringRef name) -> std::string {
475 return llvm::formatv(
"{0}::{1}", dialectStrings.namespacePath,
480 const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) {
484 const auto perOpDefinitions =
485 llvm::join(llvm::map_range(operations, opDefinitionGenerator),
"\n");
487 dict[
"OP_LIST"] = commaSeparatedOpList;
488 dict[
"OP_CLASSES"] = perOpDefinitions;
489 output << perOpDefinitions;
490 dialectDefTemplate.render(output, dict);
496 LogicalResult res = success();
500 .Case<irdl::DialectOp>(([](irdl::DialectOp) {
return success(); }))
501 .Case<irdl::OperationOp>(
502 ([](irdl::OperationOp) {
return success(); }))
503 .Case<irdl::TypeOp>(([](irdl::TypeOp) {
return success(); }))
504 .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
506 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
507 return attr.getValue() == irdl::Variadicity::single;
510 return op.emitError(
"IRDL C++ translation does not yet support "
511 "variadic operations");
513 .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
515 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
516 return attr.getValue() == irdl::Variadicity::single;
520 "IRDL C++ translation does not yet support variadic results");
522 .Case<irdl::AnyOp>(([](irdl::AnyOp) {
return success(); }))
524 return op->
emitError(
"IRDL C++ translation does not yet support "
526 << op->
getName() <<
" operation";
540 raw_ostream &output) {
541 static const auto typeDefTempl = detail::Template(
542 #include
"Templates/TypeDef.txt"
545 llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable;
547 for (
auto dialect : dialects) {
551 StringRef dialectName = dialect.getSymName();
554 std::string namespaceOpen;
555 std::string namespaceClose;
556 std::string namespacePath;
557 llvm::raw_string_ostream namespaceOpenStream(namespaceOpen);
558 llvm::raw_string_ostream namespaceCloseStream(namespaceClose);
559 llvm::raw_string_ostream namespacePathStream(namespacePath);
560 for (
auto &pathElement : namespaceAbsolutePath) {
561 namespaceOpenStream <<
"namespace " << pathElement <<
" {\n";
562 namespaceCloseStream <<
"} // namespace " << pathElement <<
"\n";
563 namespacePathStream <<
"::" << pathElement;
566 std::string cppShortName =
567 llvm::convertToCamelFromSnakeCase(dialectName,
true);
568 std::string dialectBaseTypeName = llvm::formatv(
"{0}Type", cppShortName);
569 std::string cppName = llvm::formatv(
"{0}Dialect", cppShortName);
571 DialectStrings dialectStrings;
572 dialectStrings.dialectName = dialectName;
573 dialectStrings.dialectBaseTypeName = dialectBaseTypeName;
574 dialectStrings.dialectCppName = cppName;
575 dialectStrings.dialectCppShortName = cppShortName;
576 dialectStrings.namespaceOpen = namespaceOpen;
577 dialectStrings.namespaceClose = namespaceClose;
578 dialectStrings.namespacePath = namespacePath;
580 dialectStringTable[dialect] = std::move(dialectStrings);
587 for (
auto dialect : dialects) {
589 auto &dialectStrings = dialectStringTable[dialect];
590 auto &dialectName = dialectStrings.dialectName;
593 return dialect->emitError(
"Error in Dialect " + dialectName +
594 " while generating headers");
599 for (
auto &dialect : dialects) {
600 auto &dialectStrings = dialectStringTable[dialect];
601 auto &dialectName = dialectStrings.dialectName;
604 return dialect->emitError(
"Error in Dialect " + dialectName +
605 " while generating library");
static LogicalResult verifySupported(irdl::DialectOp dialect)
static LogicalResult generateInclude(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict)
static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
constexpr char declarationMacroFlag[]
constexpr char headerTemplateText[]
static std::string generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings)
static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, irdl::detail::dictionary &dict)
static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
constexpr char definitionMacroFlag[]
static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
static WalkResult advance()
static WalkResult interrupt()
Template Code as used by IRDL-to-Cpp.
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
LogicalResult translateIRDLDialectToCpp(llvm::ArrayRef< irdl::DialectOp > dialects, raw_ostream &output)
Translates an IRDL dialect definition to a C++ definition that can be used with MLIR.
Include the generated interface declarations.