12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
25 #include "Templates/Header.txt"
34 struct DialectStrings {
35 std::string dialectName;
36 std::string dialectCppName;
37 std::string dialectCppShortName;
38 std::string dialectBaseTypeName;
40 std::string namespaceOpen;
41 std::string namespaceClose;
42 std::string namespacePath;
48 std::string typeCppName;
54 std::string opCppName;
61 std::string nameArray;
62 llvm::raw_string_ostream nameArrayStream(nameArray);
63 nameArrayStream <<
"{\"" << llvm::join(names,
"\", \"") <<
"\"}";
69 static std::string typeToCppName(irdl::TypeOp type) {
70 return llvm::formatv(
"{0}Type",
71 convertToCamelFromSnakeCase(type.getSymName(),
true));
75 static std::string opToCppName(irdl::OperationOp op) {
76 return llvm::formatv(
"{0}Op",
77 convertToCamelFromSnakeCase(op.getSymName(),
true));
81 static TypeStrings getStrings(irdl::TypeOp type) {
83 strings.typeName = type.getSymName();
84 strings.typeCppName = typeToCppName(type);
89 static OpStrings getStrings(irdl::OperationOp op) {
90 auto operandOp = op.getOp<irdl::OperandsOp>();
91 auto resultOp = op.getOp<irdl::ResultsOp>();
92 auto regionsOp = op.getOp<irdl::RegionsOp>();
95 strings.opName = op.getSymName();
96 strings.opCppName = opToCppName(op);
100 llvm::map_range(operandOp->getNames(), [](
Attribute attr) {
101 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
107 llvm::map_range(resultOp->getNames(), [](
Attribute attr) {
108 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
114 llvm::map_range(regionsOp->getNames(), [](
Attribute attr) {
115 return llvm::formatv(
"{0}", cast<StringAttr>(attr));
124 const TypeStrings &strings) {
125 dict[
"TYPE_NAME"] = strings.typeName;
126 dict[
"TYPE_CPP_NAME"] = strings.typeCppName;
131 const auto operandCount = strings.opOperandNames.size();
132 const auto resultCount = strings.opResultNames.size();
133 const auto regionCount = strings.opRegionNames.size();
135 dict[
"OP_NAME"] = strings.opName;
136 dict[
"OP_CPP_NAME"] = strings.opCppName;
137 dict[
"OP_OPERAND_COUNT"] = std::to_string(strings.opOperandNames.size());
138 dict[
"OP_RESULT_COUNT"] = std::to_string(strings.opResultNames.size());
139 dict[
"OP_OPERAND_INITIALIZER_LIST"] =
140 operandCount ? joinNameList(strings.opOperandNames) :
"{\"\"}";
141 dict[
"OP_RESULT_INITIALIZER_LIST"] =
142 resultCount ? joinNameList(strings.opResultNames) :
"{\"\"}";
143 dict[
"OP_REGION_COUNT"] = std::to_string(regionCount);
148 const DialectStrings &strings) {
149 dict[
"DIALECT_NAME"] = strings.dialectName;
150 dict[
"DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName;
151 dict[
"DIALECT_CPP_NAME"] = strings.dialectCppName;
152 dict[
"DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName;
153 dict[
"NAMESPACE_OPEN"] = strings.namespaceOpen;
154 dict[
"NAMESPACE_CLOSE"] = strings.namespaceClose;
155 dict[
"NAMESPACE_PATH"] = strings.namespacePath;
158 static LogicalResult generateTypedefList(irdl::DialectOp &dialect,
160 auto typeOps = dialect.getOps<irdl::TypeOp>();
161 auto range = llvm::map_range(typeOps, typeToCppName);
166 static LogicalResult generateOpList(irdl::DialectOp &dialect,
168 auto operationOps = dialect.getOps<irdl::OperationOp>();
169 auto range = llvm::map_range(operationOps, opToCppName);
179 #include
"Templates/TypeDecl.txt"
182 fillDict(dict, getStrings(type));
183 typeDeclTemplate.render(output, dict);
189 const OpStrings &opStrings) {
190 auto opGetters = std::string{};
191 auto resGetters = std::string{};
192 auto regionGetters = std::string{};
193 auto regionAdaptorGetters = std::string{};
195 for (
size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
197 llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i],
true);
198 opGetters += llvm::formatv(
"::mlir::Value get{0}() { return "
199 "getStructuredOperands({1}).front(); }\n ",
202 for (
size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) {
204 llvm::convertToCamelFromSnakeCase(opStrings.opResultNames[i],
true);
205 resGetters += llvm::formatv(
206 R
"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); }
211 for (
size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
213 llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i],
true);
214 regionAdaptorGetters += llvm::formatv(
215 R
"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
218 regionGetters += llvm::formatv(
219 R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
224 dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
225 dict[
"OP_RESULT_GETTER_DECLS"] = resGetters;
226 dict[
"OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
227 dict[
"OP_REGION_GETTER_DECLS"] = regionGetters;
231 const OpStrings &opStrings) {
232 std::string buildDecls;
233 llvm::raw_string_ostream stream{buildDecls};
236 llvm::join(llvm::map_range(opStrings.opResultNames,
237 [](StringRef name) -> std::string {
238 return llvm::formatv(
239 "::mlir::Type {0}, ",
240 llvm::convertToCamelFromSnakeCase(name));
245 llvm::join(llvm::map_range(opStrings.opOperandNames,
246 [](StringRef name) -> std::string {
247 return llvm::formatv(
248 "::mlir::Value {0}, ",
249 llvm::convertToCamelFromSnakeCase(name));
253 stream << llvm::formatv(
254 R
"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
255 resultParams, operandParams);
257 stream << llvm::formatv(
258 R
"(static {0} create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
259 opStrings.opCppName, resultParams, operandParams);
261 stream << llvm::formatv(
262 R
"(static {0} create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
263 opStrings.opCppName, resultParams, operandParams);
265 dict[
"OP_BUILD_DECLS"] = buildDecls;
270 const OpStrings &strings) {
272 if (!strings.opRegionNames.empty()) {
273 cppTraitNames.push_back(
274 llvm::formatv(
"::mlir::OpTrait::NRegions<{0}>::Impl",
275 strings.opRegionNames.size())
279 cppTraitNames.emplace_back(
"::mlir::OpTrait::OpInvariants");
281 return cppTraitNames;
288 #include
"Templates/PerOperationDecl.txt"
290 const auto opStrings = getStrings(op);
291 fillDict(dict, opStrings);
294 if (traitNames.empty())
295 dict[
"OP_TEMPLATE_ARGS"] = opStrings.opCppName;
297 dict[
"OP_TEMPLATE_ARGS"] = llvm::formatv(
"{0}, {1}", opStrings.opCppName,
298 llvm::join(traitNames,
", "));
303 perOpDeclTemplate.render(output, dict);
309 DialectStrings &dialectStrings) {
311 #include
"Templates/DialectDecl.txt"
314 #include
"Templates/TypeHeaderDecl.txt"
318 fillDict(dict, dialectStrings);
320 dialectDeclTemplate.render(output, dict);
321 typeHeaderDeclTemplate.render(output, dict);
323 auto typeOps = dialect.getOps<irdl::TypeOp>();
324 auto operationOps = dialect.getOps<irdl::OperationOp>();
326 for (
auto &&typeOp : typeOps) {
332 if (
failed(generateOpList(dialect, opNames)))
335 auto classDeclarations =
336 llvm::join(llvm::map_range(opNames,
337 [](llvm::StringRef name) -> std::string {
338 return llvm::formatv(
"class {0};", name);
341 const auto forwardDeclarations = llvm::formatv(
342 "{1}\n{0}\n{2}", std::move(classDeclarations),
343 dialectStrings.namespaceOpen, dialectStrings.namespaceClose);
345 output << forwardDeclarations;
346 for (
auto &&operationOp : operationOps) {
358 auto regionsOp = op.getOp<irdl::RegionsOp>();
359 if (strings.opRegionNames.empty() || !regionsOp)
362 for (
size_t i = 0; i < strings.opRegionNames.size(); ++i) {
363 std::string regionName = strings.opRegionNames[i];
364 std::string helperFnName =
365 llvm::formatv(
"__mlir_irdl_local_region_constraint_{0}_{1}",
366 strings.opCppName, regionName)
370 std::string condition =
"true";
371 std::string textualConditionName =
"any region";
373 if (
auto regionDefOp =
374 dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
380 if (
auto blockCount = regionDefOp.getNumberOfBlocks()) {
381 conditionParts.push_back(
382 llvm::formatv(
"region.getBlocks().size() == {0}",
385 descriptionParts.push_back(
386 llvm::formatv(
"exactly {0} block(s)", blockCount.value()).str());
390 if (regionDefOp.getConstrainedArguments()) {
391 size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
392 conditionParts.push_back(
393 llvm::formatv(
"region.getNumArguments() == {0}", expectedArgCount)
395 descriptionParts.push_back(
396 llvm::formatv(
"{0} entry block argument(s)", expectedArgCount)
401 if (!conditionParts.empty()) {
402 condition = llvm::join(conditionParts,
" && ");
406 if (!descriptionParts.empty()) {
407 textualConditionName =
408 llvm::formatv(
"region with {0}",
409 llvm::join(descriptionParts,
" and "))
414 verifierHelpers.push_back(llvm::formatv(
415 R
"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{
417 return op->emitOpError("region #") << regionIndex
418 << (regionName.empty() ? " " : " ('" + regionName + "') ")
419 << "failed to verify constraint: {2}";
421 return ::mlir::success();
423 helperFnName, condition, textualConditionName));
425 verifierCalls.push_back(llvm::formatv(R"(
426 if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
427 return ::mlir::failure();)",
428 helperFnName, i, regionName)
434 irdl::OperationOp op,
const OpStrings &strings) {
442 std::string verifierDef =
444 ::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
445 if(::mlir::failed(verify()))
446 return ::mlir::failure();
450 return ::mlir::success();
452 strings.opCppName, llvm::join(verifierCalls, "\n"));
454 dict[
"OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers,
"\n");
455 dict[
"OP_VERIFIER"] = verifierDef;
459 irdl::OperationOp op) {
461 #include "Templates/PerOperationDef.txt"
464 auto opStrings = getStrings(op);
465 fillDict(dict, opStrings);
467 auto resultTypes = llvm::join(
468 llvm::map_range(opStrings.opResultNames,
469 [](StringRef attr) -> std::string {
470 return llvm::formatv(
"::mlir::Type {0}, ", attr);
473 auto operandTypes = llvm::join(
474 llvm::map_range(opStrings.opOperandNames,
475 [](StringRef attr) -> std::string {
476 return llvm::formatv(
"::mlir::Value {0}, ", attr);
480 llvm::join(llvm::map_range(opStrings.opOperandNames,
481 [](StringRef attr) -> std::string {
482 return llvm::formatv(
483 " opState.addOperands({0});", attr);
486 auto resultAdder = llvm::join(
487 llvm::map_range(opStrings.opResultNames,
488 [](StringRef attr) -> std::string {
489 return llvm::formatv(
" opState.addTypes({0});", attr);
493 const auto buildDefinition = llvm::formatv(
495 void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
500 {0} {0}::create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
501 ::mlir::OperationState __state__(location, getOperationName());
502 build(opBuilder, __state__, {5} {6} attributes);
503 auto __res__ = ::llvm::dyn_cast<{0}>(opBuilder.create(__state__));
504 assert(__res__ && "builder didn't return the right type");
508 {0} {0}::create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
509 return create(opBuilder, opBuilder.getLoc(), {5} {6} attributes);
512 opStrings.opCppName, std::move(resultTypes), std::move(operandTypes),
513 std::move(operandAdder), std::move(resultAdder),
514 llvm::join(opStrings.opResultNames, ",") +
515 (!opStrings.opResultNames.empty() ?
"," :
""),
516 llvm::join(opStrings.opOperandNames,
",") +
517 (!opStrings.opOperandNames.empty() ?
"," :
""));
519 dict[
"OP_BUILD_DEFS"] = buildDefinition;
524 llvm::raw_string_ostream stream{str};
525 perOpDefTemplate.render(stream, dict);
531 return llvm::formatv(
532 R
"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
533 value = {1}::{0}::get(parser.getContext());
534 return ::mlir::success(!!value);
536 name, dialectStrings.namespacePath);
539 static LogicalResult
generateLib(irdl::DialectOp dialect, raw_ostream &output,
540 DialectStrings &dialectStrings) {
543 #include "Templates/TypeHeaderDef.txt"
546 #include "Templates/TypeDef.txt"
549 #include "Templates/DialectDef.txt"
553 fillDict(dict, dialectStrings);
555 typeHeaderDefTemplate.render(output, dict);
558 if (
failed(generateTypedefList(dialect, typeNames)))
561 dict[
"TYPE_LIST"] = llvm::join(
562 llvm::map_range(typeNames,
563 [&dialectStrings](llvm::StringRef name) -> std::string {
564 return llvm::formatv(
565 "{0}::{1}", dialectStrings.namespacePath, name);
569 auto typeVerifierGenerator =
570 [&dialectStrings](llvm::StringRef name) -> std::string {
575 llvm::join(llvm::map_range(typeNames, typeVerifierGenerator),
"\n");
577 dict[
"TYPE_PARSER"] = llvm::formatv(
578 R
"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
579 return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
581 .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
586 std::move(typeCase));
589 llvm::join(llvm::map_range(typeNames,
590 [&](llvm::StringRef name) -> std::string {
591 return llvm::formatv(
592 R
"(.Case<{1}::{0}>([&](auto t) {
593 printer << {1}::{0}::getMnemonic();
594 return ::mlir::success();
596 name, dialectStrings.namespacePath);
599 dict[
"TYPE_PRINTER"] = llvm::formatv(
600 R
"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
601 return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def)
603 .Default([](auto) {{ return ::mlir::failure(); });
605 std::move(typePrintCase));
607 dict["TYPE_DEFINES"] =
608 join(map_range(typeNames,
609 [&](StringRef name) -> std::string {
610 return formatv(
"MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
611 name, dialectStrings.namespacePath);
615 typeDefTemplate.render(output, dict);
617 auto operations = dialect.getOps<irdl::OperationOp>();
619 if (
failed(generateOpList(dialect, opNames)))
622 const auto commaSeparatedOpList = llvm::join(
624 [&dialectStrings](llvm::StringRef name) -> std::string {
625 return llvm::formatv(
"{0}::{1}", dialectStrings.namespacePath,
630 const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) {
634 const auto perOpDefinitions =
635 llvm::join(llvm::map_range(operations, opDefinitionGenerator),
"\n");
637 dict[
"OP_LIST"] = commaSeparatedOpList;
638 dict[
"OP_CLASSES"] = perOpDefinitions;
639 output << perOpDefinitions;
640 dialectDefTemplate.render(output, dict);
646 LogicalResult res = success();
650 .Case<irdl::DialectOp>(([](irdl::DialectOp) {
return success(); }))
651 .Case<irdl::OperationOp>(
652 ([](irdl::OperationOp) {
return success(); }))
653 .Case<irdl::TypeOp>(([](irdl::TypeOp) {
return success(); }))
654 .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
656 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
657 return attr.getValue() == irdl::Variadicity::single;
660 return op.emitError(
"IRDL C++ translation does not yet support "
661 "variadic operations");
663 .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
665 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
666 return attr.getValue() == irdl::Variadicity::single;
670 "IRDL C++ translation does not yet support variadic results");
672 .Case<irdl::AnyOp>(([](irdl::AnyOp) {
return success(); }))
673 .Case<irdl::RegionOp>(([](irdl::RegionOp) {
return success(); }))
674 .Case<irdl::RegionsOp>(([](irdl::RegionsOp) {
return success(); }))
676 return op->
emitError(
"IRDL C++ translation does not yet support "
678 << op->
getName() <<
" operation";
692 raw_ostream &output) {
693 static const auto typeDefTempl = detail::Template(
694 #include
"Templates/TypeDef.txt"
697 llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable;
699 for (
auto dialect : dialects) {
703 StringRef dialectName = dialect.getSymName();
706 std::string namespaceOpen;
707 std::string namespaceClose;
708 std::string namespacePath;
709 llvm::raw_string_ostream namespaceOpenStream(namespaceOpen);
710 llvm::raw_string_ostream namespaceCloseStream(namespaceClose);
711 llvm::raw_string_ostream namespacePathStream(namespacePath);
712 for (
auto &pathElement : namespaceAbsolutePath) {
713 namespaceOpenStream <<
"namespace " << pathElement <<
" {\n";
714 namespaceCloseStream <<
"} // namespace " << pathElement <<
"\n";
715 namespacePathStream <<
"::" << pathElement;
718 std::string cppShortName =
719 llvm::convertToCamelFromSnakeCase(dialectName,
true);
720 std::string dialectBaseTypeName = llvm::formatv(
"{0}Type", cppShortName);
721 std::string cppName = llvm::formatv(
"{0}Dialect", cppShortName);
723 DialectStrings dialectStrings;
724 dialectStrings.dialectName = dialectName;
725 dialectStrings.dialectBaseTypeName = dialectBaseTypeName;
726 dialectStrings.dialectCppName = cppName;
727 dialectStrings.dialectCppShortName = cppShortName;
728 dialectStrings.namespaceOpen = namespaceOpen;
729 dialectStrings.namespaceClose = namespaceClose;
730 dialectStrings.namespacePath = namespacePath;
732 dialectStringTable[dialect] = std::move(dialectStrings);
739 for (
auto dialect : dialects) {
741 auto &dialectStrings = dialectStringTable[dialect];
742 auto &dialectName = dialectStrings.dialectName;
745 return dialect->emitError(
"Error in Dialect " + dialectName +
746 " while generating headers");
751 for (
auto &dialect : dialects) {
752 auto &dialectStrings = dialectStringTable[dialect];
753 auto &dialectName = dialectStrings.dialectName;
756 return dialect->emitError(
"Error in Dialect " + dialectName +
757 " while generating library");
static LogicalResult verifySupported(irdl::DialectOp dialect)
static LogicalResult generateInclude(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
static SmallVector< std::string > generateTraits(irdl::OperationOp op, const OpStrings &strings)
static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict)
static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
constexpr char declarationMacroFlag[]
constexpr char headerTemplateText[]
static std::string generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings)
static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, irdl::detail::dictionary &dict)
static void generateRegionConstraintVerifiers(irdl::detail::dictionary &dict, irdl::OperationOp op, const OpStrings &strings, SmallVectorImpl< std::string > &verifierHelpers, SmallVectorImpl< std::string > &verifierCalls)
static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
constexpr char definitionMacroFlag[]
static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op)
static void generateVerifiers(irdl::detail::dictionary &dict, irdl::OperationOp op, const OpStrings &strings)
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
static WalkResult advance()
static WalkResult interrupt()
Template Code as used by IRDL-to-Cpp.
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
LogicalResult translateIRDLDialectToCpp(llvm::ArrayRef< irdl::DialectOp > dialects, raw_ostream &output)
Translates an IRDL dialect definition to a C++ definition that can be used with MLIR.
Include the generated interface declarations.