MLIR 22.0.0git
IRDLToCpp.cpp
Go to the documentation of this file.
1//===- IRDLToCpp.cpp - Converts IRDL definitions to C++ -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Support/LLVM.h"
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/ADT/SmallVector.h"
15#include "llvm/ADT/StringExtras.h"
16#include "llvm/ADT/TypeSwitch.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/raw_ostream.h"
19
20#include "TemplatingUtils.h"
21
22using namespace mlir;
23
24constexpr char headerTemplateText[] =
25#include "Templates/Header.txt"
26 ;
27
28constexpr char declarationMacroFlag[] = "GEN_DIALECT_DECL_HEADER";
29constexpr char definitionMacroFlag[] = "GEN_DIALECT_DEF";
30
31namespace {
32
/// Bundle of every string that is derived from one IRDL dialect declaration.
struct DialectStrings {
  /// Dialect name exactly as it is spelled in the IRDL input.
  std::string dialectName;
  /// Name of the generated dialect class (short name + "Dialect").
  std::string dialectCppName;
  /// Dialect name converted to upper camel case.
  std::string dialectCppShortName;
  /// Name of the generated base type class (short name + "Type").
  std::string dialectBaseTypeName;

  /// Text that opens the enclosing C++ namespaces.
  std::string namespaceOpen;
  /// Text that closes the enclosing C++ namespaces.
  std::string namespaceClose;
  /// `::`-prefixed, fully qualified namespace path of the generated code.
  std::string namespacePath;
};
44
45/// The set of strings that can be generated from a Type declaraiton
46struct TypeStrings {
47 StringRef typeName;
48 std::string typeCppName;
49};
50
51/// The set of strings that can be generated from an Operation declaraiton
52struct OpStrings {
53 StringRef opName;
54 std::string opCppName;
55 SmallVector<std::string> opResultNames;
56 SmallVector<std::string> opOperandNames;
57 SmallVector<std::string> opRegionNames;
58};
59
60static std::string joinNameList(llvm::ArrayRef<std::string> names) {
61 std::string nameArray;
62 llvm::raw_string_ostream nameArrayStream(nameArray);
63 nameArrayStream << "{\"" << llvm::join(names, "\", \"") << "\"}";
64
65 return nameArray;
66}
67
68/// Generates the C++ type name for a TypeOp
69static std::string typeToCppName(irdl::TypeOp type) {
70 return llvm::formatv("{0}Type",
71 convertToCamelFromSnakeCase(type.getSymName(), true));
72}
73
74/// Generates the C++ class name for an OperationOp
75static std::string opToCppName(irdl::OperationOp op) {
76 return llvm::formatv("{0}Op",
77 convertToCamelFromSnakeCase(op.getSymName(), true));
78}
79
80/// Generates TypeStrings from a TypeOp
81static TypeStrings getStrings(irdl::TypeOp type) {
82 TypeStrings strings;
83 strings.typeName = type.getSymName();
84 strings.typeCppName = typeToCppName(type);
85 return strings;
86}
87
88/// Generates OpStrings from an OperatioOp
89static OpStrings getStrings(irdl::OperationOp op) {
90 auto operandOp = op.getOp<irdl::OperandsOp>();
91 auto resultOp = op.getOp<irdl::ResultsOp>();
92 auto regionsOp = op.getOp<irdl::RegionsOp>();
93
94 OpStrings strings;
95 strings.opName = op.getSymName();
96 strings.opCppName = opToCppName(op);
97
98 if (operandOp) {
99 strings.opOperandNames = SmallVector<std::string>(
100 llvm::map_range(operandOp->getNames(), [](Attribute attr) {
101 return llvm::formatv("{0}", cast<StringAttr>(attr));
102 }));
103 }
104
105 if (resultOp) {
106 strings.opResultNames = SmallVector<std::string>(
107 llvm::map_range(resultOp->getNames(), [](Attribute attr) {
108 return llvm::formatv("{0}", cast<StringAttr>(attr));
109 }));
110 }
111
112 if (regionsOp) {
113 strings.opRegionNames = SmallVector<std::string>(
114 llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
115 return llvm::formatv("{0}", cast<StringAttr>(attr));
116 }));
117 }
118
119 return strings;
120}
121
122/// Fills a dictionary with values from TypeStrings
123static void fillDict(irdl::detail::dictionary &dict,
124 const TypeStrings &strings) {
125 dict["TYPE_NAME"] = strings.typeName;
126 dict["TYPE_CPP_NAME"] = strings.typeCppName;
127}
128
129/// Fills a dictionary with values from OpStrings
130static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131 const auto operandCount = strings.opOperandNames.size();
132 const auto resultCount = strings.opResultNames.size();
133 const auto regionCount = strings.opRegionNames.size();
134
135 dict["OP_NAME"] = strings.opName;
136 dict["OP_CPP_NAME"] = strings.opCppName;
137 dict["OP_OPERAND_COUNT"] = std::to_string(strings.opOperandNames.size());
138 dict["OP_RESULT_COUNT"] = std::to_string(strings.opResultNames.size());
139 dict["OP_OPERAND_INITIALIZER_LIST"] =
140 operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
141 dict["OP_RESULT_INITIALIZER_LIST"] =
142 resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
143 dict["OP_REGION_COUNT"] = std::to_string(regionCount);
144}
145
146/// Fills a dictionary with values from DialectStrings
147static void fillDict(irdl::detail::dictionary &dict,
148 const DialectStrings &strings) {
149 dict["DIALECT_NAME"] = strings.dialectName;
150 dict["DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName;
151 dict["DIALECT_CPP_NAME"] = strings.dialectCppName;
152 dict["DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName;
153 dict["NAMESPACE_OPEN"] = strings.namespaceOpen;
154 dict["NAMESPACE_CLOSE"] = strings.namespaceClose;
155 dict["NAMESPACE_PATH"] = strings.namespacePath;
156}
157
158static LogicalResult generateTypedefList(irdl::DialectOp &dialect,
159 SmallVector<std::string> &typeNames) {
160 auto typeOps = dialect.getOps<irdl::TypeOp>();
161 auto range = llvm::map_range(typeOps, typeToCppName);
162 typeNames = SmallVector<std::string>(range);
163 return success();
164}
165
166static LogicalResult generateOpList(irdl::DialectOp &dialect,
167 SmallVector<std::string> &opNames) {
168 auto operationOps = dialect.getOps<irdl::OperationOp>();
169 auto range = llvm::map_range(operationOps, opToCppName);
170 opNames = SmallVector<std::string>(range);
171 return success();
172}
173
174} // namespace
175
176static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output,
178 static const auto typeDeclTemplate = irdl::detail::Template(
179#include "Templates/TypeDecl.txt"
180 );
181
182 fillDict(dict, getStrings(type));
183 typeDeclTemplate.render(output, dict);
184
185 return success();
186}
187
189 const OpStrings &opStrings) {
190 auto opGetters = std::string{};
191 auto resGetters = std::string{};
192 auto regionGetters = std::string{};
193 auto regionAdaptorGetters = std::string{};
194
195 for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
196 const auto op =
197 llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i], true);
198 opGetters += llvm::formatv("::mlir::Value get{0}() { return "
199 "getStructuredOperands({1}).front(); }\n ",
200 op, i);
201 }
202 for (size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) {
203 const auto op =
204 llvm::convertToCamelFromSnakeCase(opStrings.opResultNames[i], true);
205 resGetters += llvm::formatv(
206 R"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); }
207 )",
208 op, i);
209 }
210
211 for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
212 const auto op =
213 llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
214 regionAdaptorGetters += llvm::formatv(
215 R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
216 )",
217 op, i);
218 regionGetters += llvm::formatv(
219 R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
220 )",
221 op, i);
222 }
223
224 dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
225 dict["OP_RESULT_GETTER_DECLS"] = resGetters;
226 dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
227 dict["OP_REGION_GETTER_DECLS"] = regionGetters;
228}
229
231 const OpStrings &opStrings) {
232 std::string buildDecls;
233 llvm::raw_string_ostream stream{buildDecls};
234
235 auto resultParams =
236 llvm::join(llvm::map_range(opStrings.opResultNames,
237 [](StringRef name) -> std::string {
238 return llvm::formatv(
239 "::mlir::Type {0}, ",
240 llvm::convertToCamelFromSnakeCase(name));
241 }),
242 "");
243
244 auto operandParams =
245 llvm::join(llvm::map_range(opStrings.opOperandNames,
246 [](StringRef name) -> std::string {
247 return llvm::formatv(
248 "::mlir::Value {0}, ",
249 llvm::convertToCamelFromSnakeCase(name));
250 }),
251 "");
252
253 stream << llvm::formatv(
254 R"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
255 resultParams, operandParams);
256 stream << "\n";
257 stream << llvm::formatv(
258 R"(static {0} create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
259 opStrings.opCppName, resultParams, operandParams);
260 stream << "\n";
261 stream << llvm::formatv(
262 R"(static {0} create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
263 opStrings.opCppName, resultParams, operandParams);
264 stream << "\n";
265 dict["OP_BUILD_DECLS"] = buildDecls;
267
268// add traits to the dictionary, return true if any were added
269static SmallVector<std::string> generateTraits(irdl::OperationOp op,
270 const OpStrings &strings) {
271 SmallVector<std::string> cppTraitNames;
272 if (!strings.opRegionNames.empty()) {
273 cppTraitNames.push_back(
274 llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
275 strings.opRegionNames.size())
276 .str());
277
278 // Requires verifyInvariantsImpl is implemented on the op
279 cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
280 }
281 return cppTraitNames;
282}
283
284static LogicalResult generateOperationInclude(irdl::OperationOp op,
285 raw_ostream &output,
287 static const auto perOpDeclTemplate = irdl::detail::Template(
288#include "Templates/PerOperationDecl.txt"
289 );
290 const auto opStrings = getStrings(op);
291 fillDict(dict, opStrings);
292
293 SmallVector<std::string> traitNames = generateTraits(op, opStrings);
294 if (traitNames.empty())
295 dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
296 else
297 dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
298 llvm::join(traitNames, ", "));
299
300 generateOpGetterDeclarations(dict, opStrings);
301 generateOpBuilderDeclarations(dict, opStrings);
302
303 perOpDeclTemplate.render(output, dict);
304 return success();
305}
306
307static LogicalResult generateInclude(irdl::DialectOp dialect,
308 raw_ostream &output,
309 DialectStrings &dialectStrings) {
310 static const auto dialectDeclTemplate = irdl::detail::Template(
311#include "Templates/DialectDecl.txt"
312 );
313 static const auto typeHeaderDeclTemplate = irdl::detail::Template(
314#include "Templates/TypeHeaderDecl.txt"
315 );
316
318 fillDict(dict, dialectStrings);
319
320 dialectDeclTemplate.render(output, dict);
321 typeHeaderDeclTemplate.render(output, dict);
322
323 auto typeOps = dialect.getOps<irdl::TypeOp>();
324 auto operationOps = dialect.getOps<irdl::OperationOp>();
325
326 for (auto &&typeOp : typeOps) {
327 if (failed(generateTypeInclude(typeOp, output, dict)))
328 return failure();
329 }
330
332 if (failed(generateOpList(dialect, opNames)))
333 return failure();
334
335 auto classDeclarations =
336 llvm::join(llvm::map_range(opNames,
337 [](llvm::StringRef name) -> std::string {
338 return llvm::formatv("class {0};", name);
339 }),
340 "\n");
341 const auto forwardDeclarations = llvm::formatv(
342 "{1}\n{0}\n{2}", std::move(classDeclarations),
343 dialectStrings.namespaceOpen, dialectStrings.namespaceClose);
344
345 output << forwardDeclarations;
346 for (auto &&operationOp : operationOps) {
347 if (failed(generateOperationInclude(operationOp, output, dict)))
348 return failure();
349 }
350
351 return success();
352}
353
355 irdl::detail::dictionary &dict, irdl::OperationOp op,
356 const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
357 SmallVectorImpl<std::string> &verifierCalls) {
358 auto regionsOp = op.getOp<irdl::RegionsOp>();
359 if (strings.opRegionNames.empty() || !regionsOp)
360 return;
361
362 for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
363 std::string regionName = strings.opRegionNames[i];
364 std::string helperFnName =
365 llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
366 strings.opCppName, regionName)
367 .str();
368
369 // Extract the actual region constraint from the IRDL RegionOp
370 std::string condition = "true";
371 std::string textualConditionName = "any region";
372
373 if (auto regionDefOp =
374 dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
375 // Generate constraint condition based on RegionOp attributes
376 SmallVector<std::string> conditionParts;
377 SmallVector<std::string> descriptionParts;
378
379 // Check number of blocks constraint
380 if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
381 conditionParts.push_back(
382 llvm::formatv("region.getBlocks().size() == {0}",
383 blockCount.value())
384 .str());
385 descriptionParts.push_back(
386 llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
387 }
388
389 // Check entry block arguments constraint
390 if (regionDefOp.getConstrainedArguments()) {
391 size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
392 conditionParts.push_back(
393 llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
394 .str());
395 descriptionParts.push_back(
396 llvm::formatv("{0} entry block argument(s)", expectedArgCount)
397 .str());
398 }
399
400 // Combine conditions
401 if (!conditionParts.empty()) {
402 condition = llvm::join(conditionParts, " && ");
403 }
404
405 // Generate descriptive error message
406 if (!descriptionParts.empty()) {
407 textualConditionName =
408 llvm::formatv("region with {0}",
409 llvm::join(descriptionParts, " and "))
410 .str();
411 }
412 }
413
414 verifierHelpers.push_back(llvm::formatv(
415 R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
416 if (!({1})) {{
417 return op->emitOpError("region #") << regionIndex
418 << (regionName.empty() ? " " : " ('" + regionName + "') ")
419 << "failed to verify constraint: {2}";
420 }
421 return ::mlir::success();
422})",
423 helperFnName, condition, textualConditionName));
424
425 verifierCalls.push_back(llvm::formatv(R"(
426 if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
427 return ::mlir::failure();)",
428 helperFnName, i, regionName)
429 .str());
430 }
431}
432
434 irdl::OperationOp op, const OpStrings &strings) {
435 SmallVector<std::string> verifierHelpers;
436 SmallVector<std::string> verifierCalls;
437
438 generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
439 verifierCalls);
440
441 // Add an overall verifier that sequences the helper calls
442 std::string verifierDef =
443 llvm::formatv(R"(
444::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
445 if(::mlir::failed(verify()))
446 return ::mlir::failure();
447
448 {1}
449
450 return ::mlir::success();
451})",
452 strings.opCppName, llvm::join(verifierCalls, "\n"));
453
454 dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
455 dict["OP_VERIFIER"] = verifierDef;
456}
457
458static std::string generateOpDefinition(irdl::detail::dictionary &dict,
459 irdl::OperationOp op) {
460 static const auto perOpDefTemplate = mlir::irdl::detail::Template{
461#include "Templates/PerOperationDef.txt"
462 };
463
464 auto opStrings = getStrings(op);
465 fillDict(dict, opStrings);
466
467 auto resultTypes = llvm::join(
468 llvm::map_range(opStrings.opResultNames,
469 [](StringRef attr) -> std::string {
470 return llvm::formatv("::mlir::Type {0}, ", attr);
471 }),
472 "");
473 auto operandTypes = llvm::join(
474 llvm::map_range(opStrings.opOperandNames,
475 [](StringRef attr) -> std::string {
476 return llvm::formatv("::mlir::Value {0}, ", attr);
477 }),
478 "");
479 auto operandAdder =
480 llvm::join(llvm::map_range(opStrings.opOperandNames,
481 [](StringRef attr) -> std::string {
482 return llvm::formatv(
483 " opState.addOperands({0});", attr);
484 }),
485 "\n");
486 auto resultAdder = llvm::join(
487 llvm::map_range(opStrings.opResultNames,
488 [](StringRef attr) -> std::string {
489 return llvm::formatv(" opState.addTypes({0});", attr);
490 }),
491 "\n");
492
493 const auto buildDefinition = llvm::formatv(
494 R"(
495void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
496{3}
497{4}
498}
500{0} {0}::create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
501 ::mlir::OperationState __state__(location, getOperationName());
502 build(opBuilder, __state__, {5} {6} attributes);
503 auto __res__ = ::llvm::dyn_cast<{0}>(opBuilder.create(__state__));
504 assert(__res__ && "builder didn't return the right type");
505 return __res__;
506}
507
508{0} {0}::create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
509 return create(opBuilder, opBuilder.getLoc(), {5} {6} attributes);
510}
511)",
512 opStrings.opCppName, std::move(resultTypes), std::move(operandTypes),
513 std::move(operandAdder), std::move(resultAdder),
514 llvm::join(opStrings.opResultNames, ",") +
515 (!opStrings.opResultNames.empty() ? "," : ""),
516 llvm::join(opStrings.opOperandNames, ",") +
517 (!opStrings.opOperandNames.empty() ? "," : ""));
518
519 dict["OP_BUILD_DEFS"] = buildDefinition;
520
521 generateVerifiers(dict, op, opStrings);
522
523 std::string str;
524 llvm::raw_string_ostream stream{str};
525 perOpDefTemplate.render(stream, dict);
526 return str;
527}
528
529static std::string
530generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings) {
531 return llvm::formatv(
532 R"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
533value = {1}::{0}::get(parser.getContext());
534return ::mlir::success(!!value);
535}))",
536 name, dialectStrings.namespacePath);
537}
538
539static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
540 DialectStrings &dialectStrings) {
541
542 static const auto typeHeaderDefTemplate = mlir::irdl::detail::Template{
543#include "Templates/TypeHeaderDef.txt"
544 };
545 static const auto typeDefTemplate = mlir::irdl::detail::Template{
546#include "Templates/TypeDef.txt"
547 };
548 static const auto dialectDefTemplate = mlir::irdl::detail::Template{
549#include "Templates/DialectDef.txt"
550 };
551
553 fillDict(dict, dialectStrings);
554
555 typeHeaderDefTemplate.render(output, dict);
556
557 SmallVector<std::string> typeNames;
558 if (failed(generateTypedefList(dialect, typeNames)))
559 return failure();
560
561 dict["TYPE_LIST"] = llvm::join(
562 llvm::map_range(typeNames,
563 [&dialectStrings](llvm::StringRef name) -> std::string {
564 return llvm::formatv(
565 "{0}::{1}", dialectStrings.namespacePath, name);
566 }),
567 ",\n");
568
569 auto typeVerifierGenerator =
570 [&dialectStrings](llvm::StringRef name) -> std::string {
571 return generateTypeVerifierCase(name, dialectStrings);
572 };
573
574 auto typeCase =
575 llvm::join(llvm::map_range(typeNames, typeVerifierGenerator), "\n");
576
577 dict["TYPE_PARSER"] = llvm::formatv(
578 R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
579 return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
580 {0}
581 .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
582 *mnemonic = keyword;
583 return std::nullopt;
584 });
585})",
586 std::move(typeCase));
587
588 auto typePrintCase =
589 llvm::join(llvm::map_range(typeNames,
590 [&](llvm::StringRef name) -> std::string {
591 return llvm::formatv(
592 R"(.Case<{1}::{0}>([&](auto t) {
593 printer << {1}::{0}::getMnemonic();
594 return ::mlir::success();
595 }))",
596 name, dialectStrings.namespacePath);
597 }),
598 "\n");
599 dict["TYPE_PRINTER"] = llvm::formatv(
600 R"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
601 return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def)
602 {0}
603 .Default([](auto) {{ return ::mlir::failure(); });
604})",
605 std::move(typePrintCase));
606
607 dict["TYPE_DEFINES"] =
608 join(map_range(typeNames,
609 [&](StringRef name) -> std::string {
610 return formatv("MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
611 name, dialectStrings.namespacePath);
612 }),
613 "\n");
614
615 typeDefTemplate.render(output, dict);
616
617 auto operations = dialect.getOps<irdl::OperationOp>();
619 if (failed(generateOpList(dialect, opNames)))
620 return failure();
621
622 const auto commaSeparatedOpList = llvm::join(
623 map_range(opNames,
624 [&dialectStrings](llvm::StringRef name) -> std::string {
625 return llvm::formatv("{0}::{1}", dialectStrings.namespacePath,
626 name);
627 }),
628 ",\n");
629
630 const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) {
631 return generateOpDefinition(dict, op);
632 };
633
634 const auto perOpDefinitions =
635 llvm::join(llvm::map_range(operations, opDefinitionGenerator), "\n");
636
637 dict["OP_LIST"] = commaSeparatedOpList;
638 dict["OP_CLASSES"] = perOpDefinitions;
639 output << perOpDefinitions;
640 dialectDefTemplate.render(output, dict);
641
642 return success();
643}
644
645static LogicalResult verifySupported(irdl::DialectOp dialect) {
646 LogicalResult res = success();
647 dialect.walk([&](mlir::Operation *op) {
648 res =
650 .Case<irdl::DialectOp>(([](irdl::DialectOp) { return success(); }))
651 .Case<irdl::OperationOp>(
652 ([](irdl::OperationOp) { return success(); }))
653 .Case<irdl::TypeOp>(([](irdl::TypeOp) { return success(); }))
654 .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
655 if (llvm::all_of(
656 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
657 return attr.getValue() == irdl::Variadicity::single;
658 }))
659 return success();
660 return op.emitError("IRDL C++ translation does not yet support "
661 "variadic operations");
662 }))
663 .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
664 if (llvm::all_of(
665 op.getVariadicity(), [](irdl::VariadicityAttr attr) {
666 return attr.getValue() == irdl::Variadicity::single;
667 }))
668 return success();
669 return op.emitError(
670 "IRDL C++ translation does not yet support variadic results");
671 }))
672 .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
673 .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
674 .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
675 .Default([](mlir::Operation *op) -> LogicalResult {
676 return op->emitError("IRDL C++ translation does not yet support "
677 "translation of ")
678 << op->getName() << " operation";
679 });
680
681 if (failed(res))
682 return WalkResult::interrupt();
683
684 return WalkResult::advance();
685 });
686
687 return res;
688}
689
690LogicalResult
692 raw_ostream &output) {
693 static const auto typeDefTempl = detail::Template(
694#include "Templates/TypeDef.txt"
695 );
696
697 llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable;
698
699 for (auto dialect : dialects) {
700 if (failed(verifySupported(dialect)))
701 return failure();
702
703 StringRef dialectName = dialect.getSymName();
704
705 SmallVector<SmallString<8>> namespaceAbsolutePath{{"mlir"}, dialectName};
706 std::string namespaceOpen;
707 std::string namespaceClose;
708 std::string namespacePath;
709 llvm::raw_string_ostream namespaceOpenStream(namespaceOpen);
710 llvm::raw_string_ostream namespaceCloseStream(namespaceClose);
711 llvm::raw_string_ostream namespacePathStream(namespacePath);
712 for (auto &pathElement : namespaceAbsolutePath) {
713 namespaceOpenStream << "namespace " << pathElement << " {\n";
714 namespaceCloseStream << "} // namespace " << pathElement << "\n";
715 namespacePathStream << "::" << pathElement;
716 }
717
718 std::string cppShortName =
719 llvm::convertToCamelFromSnakeCase(dialectName, true);
720 std::string dialectBaseTypeName = llvm::formatv("{0}Type", cppShortName);
721 std::string cppName = llvm::formatv("{0}Dialect", cppShortName);
722
723 DialectStrings dialectStrings;
724 dialectStrings.dialectName = dialectName;
725 dialectStrings.dialectBaseTypeName = dialectBaseTypeName;
726 dialectStrings.dialectCppName = cppName;
727 dialectStrings.dialectCppShortName = cppShortName;
728 dialectStrings.namespaceOpen = namespaceOpen;
729 dialectStrings.namespaceClose = namespaceClose;
730 dialectStrings.namespacePath = namespacePath;
731
732 dialectStringTable[dialect] = std::move(dialectStrings);
733 }
734
735 // generate the actual header
736 output << headerTemplateText;
737
738 output << llvm::formatv("#ifdef {0}\n#undef {0}\n", declarationMacroFlag);
739 for (auto dialect : dialects) {
740
741 auto &dialectStrings = dialectStringTable[dialect];
742 auto &dialectName = dialectStrings.dialectName;
743
744 if (failed(generateInclude(dialect, output, dialectStrings)))
745 return dialect->emitError("Error in Dialect " + dialectName +
746 " while generating headers");
747 }
748 output << llvm::formatv("#endif // #ifdef {}\n", declarationMacroFlag);
749
750 output << llvm::formatv("#ifdef {0}\n#undef {0}\n ", definitionMacroFlag);
751 for (auto &dialect : dialects) {
752 auto &dialectStrings = dialectStringTable[dialect];
753 auto &dialectName = dialectStrings.dialectName;
754
755 if (failed(generateLib(dialect, output, dialectStrings)))
756 return dialect->emitError("Error in Dialect " + dialectName +
757 " while generating library");
758 }
759 output << llvm::formatv("#endif // #ifdef {}\n", definitionMacroFlag);
760
761 return success();
762}
return success()
static LogicalResult verifySupported(irdl::DialectOp dialect)
static LogicalResult generateInclude(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict)
static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
static SmallVector< std::string > generateTraits(irdl::OperationOp op, const OpStrings &strings)
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
constexpr char declarationMacroFlag[]
Definition IRDLToCpp.cpp:28
constexpr char headerTemplateText[]
Definition IRDLToCpp.cpp:24
static std::string generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings)
static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, irdl::detail::dictionary &dict)
static void generateRegionConstraintVerifiers(irdl::detail::dictionary &dict, irdl::OperationOp op, const OpStrings &strings, SmallVectorImpl< std::string > &verifierHelpers, SmallVectorImpl< std::string > &verifierCalls)
static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
constexpr char definitionMacroFlag[]
Definition IRDLToCpp.cpp:29
static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op)
static void generateVerifiers(irdl::detail::dictionary &dict, irdl::OperationOp op, const OpStrings &strings)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Template Code as used by IRDL-to-Cpp.
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
LogicalResult translateIRDLDialectToCpp(llvm::ArrayRef< irdl::DialectOp > dialects, raw_ostream &output)
Translates an IRDL dialect definition to a C++ definition that can be used with MLIR.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.