MLIR  22.0.0git
IRDLToCpp.cpp
Go to the documentation of this file.
1 //===- IRDLToCpp.cpp - Converts IRDL definitions to C++ -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 #include "TemplatingUtils.h"
21 
22 using namespace mlir;
23 
/// Boilerplate text written verbatim at the very beginning of the generated
/// file. The literal is spliced in from Templates/Header.txt, so the #include
/// must stay between the `=` and the terminating `;`.
constexpr char headerTemplateText[] =
#include "Templates/Header.txt"
    ;

/// Macro guarding the declarations part of the generated output; the emitted
/// code `#ifdef`s and then `#undef`s it.
constexpr char declarationMacroFlag[] = "GEN_DIALECT_DECL_HEADER";
/// Macro guarding the definitions part of the generated output; the emitted
/// code `#ifdef`s and then `#undef`s it.
constexpr char definitionMacroFlag[] = "GEN_DIALECT_DEF";
30 
31 namespace {
32 
/// The set of strings that can be generated from a Dialect declaration.
struct DialectStrings {
  /// Symbol name of the irdl.dialect op, e.g. `cmath`.
  std::string dialectName;
  /// C++ class name of the dialect (`<ShortName>Dialect`).
  std::string dialectCppName;
  /// UpperCamelCase form of the dialect name, e.g. `Cmath`.
  std::string dialectCppShortName;
  /// Name of the dialect's base type class (`<ShortName>Type`).
  std::string dialectBaseTypeName;

  /// `namespace ... {` lines opening the dialect's C++ namespace.
  std::string namespaceOpen;
  /// Matching `} // namespace ...` lines closing that namespace.
  std::string namespaceClose;
  /// Fully qualified namespace of the dialect, e.g. `::mlir::cmath`.
  std::string namespacePath;
};
44 
45 /// The set of strings that can be generated from a Type declaraiton
46 struct TypeStrings {
47  StringRef typeName;
48  std::string typeCppName;
49 };
50 
51 /// The set of strings that can be generated from an Operation declaraiton
52 struct OpStrings {
53  StringRef opName;
54  std::string opCppName;
55  SmallVector<std::string> opResultNames;
56  SmallVector<std::string> opOperandNames;
57  SmallVector<std::string> opRegionNames;
58 };
59 
60 static std::string joinNameList(llvm::ArrayRef<std::string> names) {
61  std::string nameArray;
62  llvm::raw_string_ostream nameArrayStream(nameArray);
63  nameArrayStream << "{\"" << llvm::join(names, "\", \"") << "\"}";
64 
65  return nameArray;
66 }
67 
68 /// Generates the C++ type name for a TypeOp
69 static std::string typeToCppName(irdl::TypeOp type) {
70  return llvm::formatv("{0}Type",
71  convertToCamelFromSnakeCase(type.getSymName(), true));
72 }
73 
74 /// Generates the C++ class name for an OperationOp
75 static std::string opToCppName(irdl::OperationOp op) {
76  return llvm::formatv("{0}Op",
77  convertToCamelFromSnakeCase(op.getSymName(), true));
78 }
79 
80 /// Generates TypeStrings from a TypeOp
81 static TypeStrings getStrings(irdl::TypeOp type) {
82  TypeStrings strings;
83  strings.typeName = type.getSymName();
84  strings.typeCppName = typeToCppName(type);
85  return strings;
86 }
87 
88 /// Generates OpStrings from an OperatioOp
89 static OpStrings getStrings(irdl::OperationOp op) {
90  auto operandOp = op.getOp<irdl::OperandsOp>();
91  auto resultOp = op.getOp<irdl::ResultsOp>();
92  auto regionsOp = op.getOp<irdl::RegionsOp>();
93 
94  OpStrings strings;
95  strings.opName = op.getSymName();
96  strings.opCppName = opToCppName(op);
97 
98  if (operandOp) {
99  strings.opOperandNames = SmallVector<std::string>(
100  llvm::map_range(operandOp->getNames(), [](Attribute attr) {
101  return llvm::formatv("{0}", cast<StringAttr>(attr));
102  }));
103  }
104 
105  if (resultOp) {
106  strings.opResultNames = SmallVector<std::string>(
107  llvm::map_range(resultOp->getNames(), [](Attribute attr) {
108  return llvm::formatv("{0}", cast<StringAttr>(attr));
109  }));
110  }
111 
112  if (regionsOp) {
113  strings.opRegionNames = SmallVector<std::string>(
114  llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
115  return llvm::formatv("{0}", cast<StringAttr>(attr));
116  }));
117  }
118 
119  return strings;
120 }
121 
122 /// Fills a dictionary with values from TypeStrings
123 static void fillDict(irdl::detail::dictionary &dict,
124  const TypeStrings &strings) {
125  dict["TYPE_NAME"] = strings.typeName;
126  dict["TYPE_CPP_NAME"] = strings.typeCppName;
127 }
128 
129 /// Fills a dictionary with values from OpStrings
130 static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131  const auto operandCount = strings.opOperandNames.size();
132  const auto resultCount = strings.opResultNames.size();
133  const auto regionCount = strings.opRegionNames.size();
134 
135  dict["OP_NAME"] = strings.opName;
136  dict["OP_CPP_NAME"] = strings.opCppName;
137  dict["OP_OPERAND_COUNT"] = std::to_string(strings.opOperandNames.size());
138  dict["OP_RESULT_COUNT"] = std::to_string(strings.opResultNames.size());
139  dict["OP_OPERAND_INITIALIZER_LIST"] =
140  operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
141  dict["OP_RESULT_INITIALIZER_LIST"] =
142  resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
143  dict["OP_REGION_COUNT"] = std::to_string(regionCount);
144 }
145 
146 /// Fills a dictionary with values from DialectStrings
147 static void fillDict(irdl::detail::dictionary &dict,
148  const DialectStrings &strings) {
149  dict["DIALECT_NAME"] = strings.dialectName;
150  dict["DIALECT_BASE_TYPE_NAME"] = strings.dialectBaseTypeName;
151  dict["DIALECT_CPP_NAME"] = strings.dialectCppName;
152  dict["DIALECT_CPP_SHORT_NAME"] = strings.dialectCppShortName;
153  dict["NAMESPACE_OPEN"] = strings.namespaceOpen;
154  dict["NAMESPACE_CLOSE"] = strings.namespaceClose;
155  dict["NAMESPACE_PATH"] = strings.namespacePath;
156 }
157 
158 static LogicalResult generateTypedefList(irdl::DialectOp &dialect,
159  SmallVector<std::string> &typeNames) {
160  auto typeOps = dialect.getOps<irdl::TypeOp>();
161  auto range = llvm::map_range(typeOps, typeToCppName);
162  typeNames = SmallVector<std::string>(range);
163  return success();
164 }
165 
166 static LogicalResult generateOpList(irdl::DialectOp &dialect,
167  SmallVector<std::string> &opNames) {
168  auto operationOps = dialect.getOps<irdl::OperationOp>();
169  auto range = llvm::map_range(operationOps, opToCppName);
170  opNames = SmallVector<std::string>(range);
171  return success();
172 }
173 
174 } // namespace
175 
176 static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output,
177  irdl::detail::dictionary &dict) {
178  static const auto typeDeclTemplate = irdl::detail::Template(
179 #include "Templates/TypeDecl.txt"
180  );
181 
182  fillDict(dict, getStrings(type));
183  typeDeclTemplate.render(output, dict);
184 
185  return success();
186 }
187 
189  const OpStrings &opStrings) {
190  auto opGetters = std::string{};
191  auto resGetters = std::string{};
192  auto regionGetters = std::string{};
193  auto regionAdaptorGetters = std::string{};
194 
195  for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
196  const auto op =
197  llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i], true);
198  opGetters += llvm::formatv("::mlir::Value get{0}() { return "
199  "getStructuredOperands({1}).front(); }\n ",
200  op, i);
201  }
202  for (size_t i = 0, end = opStrings.opResultNames.size(); i < end; ++i) {
203  const auto op =
204  llvm::convertToCamelFromSnakeCase(opStrings.opResultNames[i], true);
205  resGetters += llvm::formatv(
206  R"(::mlir::Value get{0}() { return ::llvm::cast<::mlir::Value>(getStructuredResults({1}).front()); }
207  )",
208  op, i);
209  }
210 
211  for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
212  const auto op =
213  llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
214  regionAdaptorGetters += llvm::formatv(
215  R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
216  )",
217  op, i);
218  regionGetters += llvm::formatv(
219  R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
220  )",
221  op, i);
222  }
223 
224  dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
225  dict["OP_RESULT_GETTER_DECLS"] = resGetters;
226  dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
227  dict["OP_REGION_GETTER_DECLS"] = regionGetters;
228 }
229 
231  const OpStrings &opStrings) {
232  std::string buildDecls;
233  llvm::raw_string_ostream stream{buildDecls};
234 
235  auto resultParams =
236  llvm::join(llvm::map_range(opStrings.opResultNames,
237  [](StringRef name) -> std::string {
238  return llvm::formatv(
239  "::mlir::Type {0}, ",
240  llvm::convertToCamelFromSnakeCase(name));
241  }),
242  "");
243 
244  auto operandParams =
245  llvm::join(llvm::map_range(opStrings.opOperandNames,
246  [](StringRef name) -> std::string {
247  return llvm::formatv(
248  "::mlir::Value {0}, ",
249  llvm::convertToCamelFromSnakeCase(name));
250  }),
251  "");
252 
253  stream << llvm::formatv(
254  R"(static void build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {0} {1} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
255  resultParams, operandParams);
256  stream << "\n";
257  stream << llvm::formatv(
258  R"(static {0} create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
259  opStrings.opCppName, resultParams, operandParams);
260  stream << "\n";
261  stream << llvm::formatv(
262  R"(static {0} create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {{});)",
263  opStrings.opCppName, resultParams, operandParams);
264  stream << "\n";
265  dict["OP_BUILD_DECLS"] = buildDecls;
266 }
267 
268 // add traits to the dictionary, return true if any were added
269 static SmallVector<std::string> generateTraits(irdl::OperationOp op,
270  const OpStrings &strings) {
271  SmallVector<std::string> cppTraitNames;
272  if (!strings.opRegionNames.empty()) {
273  cppTraitNames.push_back(
274  llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
275  strings.opRegionNames.size())
276  .str());
277 
278  // Requires verifyInvariantsImpl is implemented on the op
279  cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
280  }
281  return cppTraitNames;
282 }
283 
284 static LogicalResult generateOperationInclude(irdl::OperationOp op,
285  raw_ostream &output,
286  irdl::detail::dictionary &dict) {
287  static const auto perOpDeclTemplate = irdl::detail::Template(
288 #include "Templates/PerOperationDecl.txt"
289  );
290  const auto opStrings = getStrings(op);
291  fillDict(dict, opStrings);
292 
293  SmallVector<std::string> traitNames = generateTraits(op, opStrings);
294  if (traitNames.empty())
295  dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
296  else
297  dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
298  llvm::join(traitNames, ", "));
299 
300  generateOpGetterDeclarations(dict, opStrings);
301  generateOpBuilderDeclarations(dict, opStrings);
302 
303  perOpDeclTemplate.render(output, dict);
304  return success();
305 }
306 
307 static LogicalResult generateInclude(irdl::DialectOp dialect,
308  raw_ostream &output,
309  DialectStrings &dialectStrings) {
310  static const auto dialectDeclTemplate = irdl::detail::Template(
311 #include "Templates/DialectDecl.txt"
312  );
313  static const auto typeHeaderDeclTemplate = irdl::detail::Template(
314 #include "Templates/TypeHeaderDecl.txt"
315  );
316 
318  fillDict(dict, dialectStrings);
319 
320  dialectDeclTemplate.render(output, dict);
321  typeHeaderDeclTemplate.render(output, dict);
322 
323  auto typeOps = dialect.getOps<irdl::TypeOp>();
324  auto operationOps = dialect.getOps<irdl::OperationOp>();
325 
326  for (auto &&typeOp : typeOps) {
327  if (failed(generateTypeInclude(typeOp, output, dict)))
328  return failure();
329  }
330 
331  SmallVector<std::string> opNames;
332  if (failed(generateOpList(dialect, opNames)))
333  return failure();
334 
335  auto classDeclarations =
336  llvm::join(llvm::map_range(opNames,
337  [](llvm::StringRef name) -> std::string {
338  return llvm::formatv("class {0};", name);
339  }),
340  "\n");
341  const auto forwardDeclarations = llvm::formatv(
342  "{1}\n{0}\n{2}", std::move(classDeclarations),
343  dialectStrings.namespaceOpen, dialectStrings.namespaceClose);
344 
345  output << forwardDeclarations;
346  for (auto &&operationOp : operationOps) {
347  if (failed(generateOperationInclude(operationOp, output, dict)))
348  return failure();
349  }
350 
351  return success();
352 }
353 
355  irdl::detail::dictionary &dict, irdl::OperationOp op,
356  const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
357  SmallVectorImpl<std::string> &verifierCalls) {
358  auto regionsOp = op.getOp<irdl::RegionsOp>();
359  if (strings.opRegionNames.empty() || !regionsOp)
360  return;
361 
362  for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
363  std::string regionName = strings.opRegionNames[i];
364  std::string helperFnName =
365  llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
366  strings.opCppName, regionName)
367  .str();
368 
369  // Extract the actual region constraint from the IRDL RegionOp
370  std::string condition = "true";
371  std::string textualConditionName = "any region";
372 
373  if (auto regionDefOp =
374  dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
375  // Generate constraint condition based on RegionOp attributes
376  SmallVector<std::string> conditionParts;
377  SmallVector<std::string> descriptionParts;
378 
379  // Check number of blocks constraint
380  if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
381  conditionParts.push_back(
382  llvm::formatv("region.getBlocks().size() == {0}",
383  blockCount.value())
384  .str());
385  descriptionParts.push_back(
386  llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
387  }
388 
389  // Check entry block arguments constraint
390  if (regionDefOp.getConstrainedArguments()) {
391  size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
392  conditionParts.push_back(
393  llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
394  .str());
395  descriptionParts.push_back(
396  llvm::formatv("{0} entry block argument(s)", expectedArgCount)
397  .str());
398  }
399 
400  // Combine conditions
401  if (!conditionParts.empty()) {
402  condition = llvm::join(conditionParts, " && ");
403  }
404 
405  // Generate descriptive error message
406  if (!descriptionParts.empty()) {
407  textualConditionName =
408  llvm::formatv("region with {0}",
409  llvm::join(descriptionParts, " and "))
410  .str();
411  }
412  }
413 
414  verifierHelpers.push_back(llvm::formatv(
415  R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
416  if (!({1})) {{
417  return op->emitOpError("region #") << regionIndex
418  << (regionName.empty() ? " " : " ('" + regionName + "') ")
419  << "failed to verify constraint: {2}";
420  }
421  return ::mlir::success();
422 })",
423  helperFnName, condition, textualConditionName));
424 
425  verifierCalls.push_back(llvm::formatv(R"(
426  if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
427  return ::mlir::failure();)",
428  helperFnName, i, regionName)
429  .str());
430  }
431 }
432 
434  irdl::OperationOp op, const OpStrings &strings) {
435  SmallVector<std::string> verifierHelpers;
436  SmallVector<std::string> verifierCalls;
437 
438  generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
439  verifierCalls);
440 
441  // Add an overall verifier that sequences the helper calls
442  std::string verifierDef =
443  llvm::formatv(R"(
444 ::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
445  if(::mlir::failed(verify()))
446  return ::mlir::failure();
447 
448  {1}
449 
450  return ::mlir::success();
451 })",
452  strings.opCppName, llvm::join(verifierCalls, "\n"));
453 
454  dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
455  dict["OP_VERIFIER"] = verifierDef;
456 }
457 
458 static std::string generateOpDefinition(irdl::detail::dictionary &dict,
459  irdl::OperationOp op) {
460  static const auto perOpDefTemplate = mlir::irdl::detail::Template{
461 #include "Templates/PerOperationDef.txt"
462  };
463 
464  auto opStrings = getStrings(op);
465  fillDict(dict, opStrings);
466 
467  auto resultTypes = llvm::join(
468  llvm::map_range(opStrings.opResultNames,
469  [](StringRef attr) -> std::string {
470  return llvm::formatv("::mlir::Type {0}, ", attr);
471  }),
472  "");
473  auto operandTypes = llvm::join(
474  llvm::map_range(opStrings.opOperandNames,
475  [](StringRef attr) -> std::string {
476  return llvm::formatv("::mlir::Value {0}, ", attr);
477  }),
478  "");
479  auto operandAdder =
480  llvm::join(llvm::map_range(opStrings.opOperandNames,
481  [](StringRef attr) -> std::string {
482  return llvm::formatv(
483  " opState.addOperands({0});", attr);
484  }),
485  "\n");
486  auto resultAdder = llvm::join(
487  llvm::map_range(opStrings.opResultNames,
488  [](StringRef attr) -> std::string {
489  return llvm::formatv(" opState.addTypes({0});", attr);
490  }),
491  "\n");
492 
493  const auto buildDefinition = llvm::formatv(
494  R"(
495 void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
496 {3}
497 {4}
498 }
499 
500 {0} {0}::create(::mlir::OpBuilder &opBuilder, ::mlir::Location location, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
501  ::mlir::OperationState __state__(location, getOperationName());
502  build(opBuilder, __state__, {5} {6} attributes);
503  auto __res__ = ::llvm::dyn_cast<{0}>(opBuilder.create(__state__));
504  assert(__res__ && "builder didn't return the right type");
505  return __res__;
506 }
507 
508 {0} {0}::create(::mlir::ImplicitLocOpBuilder &opBuilder, {1} {2} ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {{
509  return create(opBuilder, opBuilder.getLoc(), {5} {6} attributes);
510 }
511 )",
512  opStrings.opCppName, std::move(resultTypes), std::move(operandTypes),
513  std::move(operandAdder), std::move(resultAdder),
514  llvm::join(opStrings.opResultNames, ",") +
515  (!opStrings.opResultNames.empty() ? "," : ""),
516  llvm::join(opStrings.opOperandNames, ",") +
517  (!opStrings.opOperandNames.empty() ? "," : ""));
518 
519  dict["OP_BUILD_DEFS"] = buildDefinition;
520 
521  generateVerifiers(dict, op, opStrings);
522 
523  std::string str;
524  llvm::raw_string_ostream stream{str};
525  perOpDefTemplate.render(stream, dict);
526  return str;
527 }
528 
529 static std::string
530 generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings) {
531  return llvm::formatv(
532  R"(.Case({1}::{0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
533 value = {1}::{0}::get(parser.getContext());
534 return ::mlir::success(!!value);
535 }))",
536  name, dialectStrings.namespacePath);
537 }
538 
539 static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
540  DialectStrings &dialectStrings) {
541 
542  static const auto typeHeaderDefTemplate = mlir::irdl::detail::Template{
543 #include "Templates/TypeHeaderDef.txt"
544  };
545  static const auto typeDefTemplate = mlir::irdl::detail::Template{
546 #include "Templates/TypeDef.txt"
547  };
548  static const auto dialectDefTemplate = mlir::irdl::detail::Template{
549 #include "Templates/DialectDef.txt"
550  };
551 
553  fillDict(dict, dialectStrings);
554 
555  typeHeaderDefTemplate.render(output, dict);
556 
557  SmallVector<std::string> typeNames;
558  if (failed(generateTypedefList(dialect, typeNames)))
559  return failure();
560 
561  dict["TYPE_LIST"] = llvm::join(
562  llvm::map_range(typeNames,
563  [&dialectStrings](llvm::StringRef name) -> std::string {
564  return llvm::formatv(
565  "{0}::{1}", dialectStrings.namespacePath, name);
566  }),
567  ",\n");
568 
569  auto typeVerifierGenerator =
570  [&dialectStrings](llvm::StringRef name) -> std::string {
571  return generateTypeVerifierCase(name, dialectStrings);
572  };
573 
574  auto typeCase =
575  llvm::join(llvm::map_range(typeNames, typeVerifierGenerator), "\n");
576 
577  dict["TYPE_PARSER"] = llvm::formatv(
578  R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
579  return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
580  {0}
581  .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
582  *mnemonic = keyword;
583  return std::nullopt;
584  });
585 })",
586  std::move(typeCase));
587 
588  auto typePrintCase =
589  llvm::join(llvm::map_range(typeNames,
590  [&](llvm::StringRef name) -> std::string {
591  return llvm::formatv(
592  R"(.Case<{1}::{0}>([&](auto t) {
593  printer << {1}::{0}::getMnemonic();
594  return ::mlir::success();
595  }))",
596  name, dialectStrings.namespacePath);
597  }),
598  "\n");
599  dict["TYPE_PRINTER"] = llvm::formatv(
600  R"(static ::llvm::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
601  return ::llvm::TypeSwitch<::mlir::Type, ::llvm::LogicalResult>(def)
602  {0}
603  .Default([](auto) {{ return ::mlir::failure(); });
604 })",
605  std::move(typePrintCase));
606 
607  dict["TYPE_DEFINES"] =
608  join(map_range(typeNames,
609  [&](StringRef name) -> std::string {
610  return formatv("MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
611  name, dialectStrings.namespacePath);
612  }),
613  "\n");
614 
615  typeDefTemplate.render(output, dict);
616 
617  auto operations = dialect.getOps<irdl::OperationOp>();
618  SmallVector<std::string> opNames;
619  if (failed(generateOpList(dialect, opNames)))
620  return failure();
621 
622  const auto commaSeparatedOpList = llvm::join(
623  map_range(opNames,
624  [&dialectStrings](llvm::StringRef name) -> std::string {
625  return llvm::formatv("{0}::{1}", dialectStrings.namespacePath,
626  name);
627  }),
628  ",\n");
629 
630  const auto opDefinitionGenerator = [&dict](irdl::OperationOp op) {
631  return generateOpDefinition(dict, op);
632  };
633 
634  const auto perOpDefinitions =
635  llvm::join(llvm::map_range(operations, opDefinitionGenerator), "\n");
636 
637  dict["OP_LIST"] = commaSeparatedOpList;
638  dict["OP_CLASSES"] = perOpDefinitions;
639  output << perOpDefinitions;
640  dialectDefTemplate.render(output, dict);
641 
642  return success();
643 }
644 
645 static LogicalResult verifySupported(irdl::DialectOp dialect) {
646  LogicalResult res = success();
647  dialect.walk([&](mlir::Operation *op) {
648  res =
650  .Case<irdl::DialectOp>(([](irdl::DialectOp) { return success(); }))
651  .Case<irdl::OperationOp>(
652  ([](irdl::OperationOp) { return success(); }))
653  .Case<irdl::TypeOp>(([](irdl::TypeOp) { return success(); }))
654  .Case<irdl::OperandsOp>(([](irdl::OperandsOp op) -> LogicalResult {
655  if (llvm::all_of(
656  op.getVariadicity(), [](irdl::VariadicityAttr attr) {
657  return attr.getValue() == irdl::Variadicity::single;
658  }))
659  return success();
660  return op.emitError("IRDL C++ translation does not yet support "
661  "variadic operations");
662  }))
663  .Case<irdl::ResultsOp>(([](irdl::ResultsOp op) -> LogicalResult {
664  if (llvm::all_of(
665  op.getVariadicity(), [](irdl::VariadicityAttr attr) {
666  return attr.getValue() == irdl::Variadicity::single;
667  }))
668  return success();
669  return op.emitError(
670  "IRDL C++ translation does not yet support variadic results");
671  }))
672  .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
673  .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
674  .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
675  .Default([](mlir::Operation *op) -> LogicalResult {
676  return op->emitError("IRDL C++ translation does not yet support "
677  "translation of ")
678  << op->getName() << " operation";
679  });
680 
681  if (failed(res))
682  return WalkResult::interrupt();
683 
684  return WalkResult::advance();
685  });
686 
687  return res;
688 }
689 
690 LogicalResult
692  raw_ostream &output) {
693  static const auto typeDefTempl = detail::Template(
694 #include "Templates/TypeDef.txt"
695  );
696 
697  llvm::SmallMapVector<DialectOp, DialectStrings, 2> dialectStringTable;
698 
699  for (auto dialect : dialects) {
700  if (failed(verifySupported(dialect)))
701  return failure();
702 
703  StringRef dialectName = dialect.getSymName();
704 
705  SmallVector<SmallString<8>> namespaceAbsolutePath{{"mlir"}, dialectName};
706  std::string namespaceOpen;
707  std::string namespaceClose;
708  std::string namespacePath;
709  llvm::raw_string_ostream namespaceOpenStream(namespaceOpen);
710  llvm::raw_string_ostream namespaceCloseStream(namespaceClose);
711  llvm::raw_string_ostream namespacePathStream(namespacePath);
712  for (auto &pathElement : namespaceAbsolutePath) {
713  namespaceOpenStream << "namespace " << pathElement << " {\n";
714  namespaceCloseStream << "} // namespace " << pathElement << "\n";
715  namespacePathStream << "::" << pathElement;
716  }
717 
718  std::string cppShortName =
719  llvm::convertToCamelFromSnakeCase(dialectName, true);
720  std::string dialectBaseTypeName = llvm::formatv("{0}Type", cppShortName);
721  std::string cppName = llvm::formatv("{0}Dialect", cppShortName);
722 
723  DialectStrings dialectStrings;
724  dialectStrings.dialectName = dialectName;
725  dialectStrings.dialectBaseTypeName = dialectBaseTypeName;
726  dialectStrings.dialectCppName = cppName;
727  dialectStrings.dialectCppShortName = cppShortName;
728  dialectStrings.namespaceOpen = namespaceOpen;
729  dialectStrings.namespaceClose = namespaceClose;
730  dialectStrings.namespacePath = namespacePath;
731 
732  dialectStringTable[dialect] = std::move(dialectStrings);
733  }
734 
735  // generate the actual header
736  output << headerTemplateText;
737 
738  output << llvm::formatv("#ifdef {0}\n#undef {0}\n", declarationMacroFlag);
739  for (auto dialect : dialects) {
740 
741  auto &dialectStrings = dialectStringTable[dialect];
742  auto &dialectName = dialectStrings.dialectName;
743 
744  if (failed(generateInclude(dialect, output, dialectStrings)))
745  return dialect->emitError("Error in Dialect " + dialectName +
746  " while generating headers");
747  }
748  output << llvm::formatv("#endif // #ifdef {}\n", declarationMacroFlag);
749 
750  output << llvm::formatv("#ifdef {0}\n#undef {0}\n ", definitionMacroFlag);
751  for (auto &dialect : dialects) {
752  auto &dialectStrings = dialectStringTable[dialect];
753  auto &dialectName = dialectStrings.dialectName;
754 
755  if (failed(generateLib(dialect, output, dialectStrings)))
756  return dialect->emitError("Error in Dialect " + dialectName +
757  " while generating library");
758  }
759  output << llvm::formatv("#endif // #ifdef {}\n", definitionMacroFlag);
760 
761  return success();
762 }
static LogicalResult verifySupported(irdl::DialectOp dialect)
Definition: IRDLToCpp.cpp:591
static LogicalResult generateInclude(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
Definition: IRDLToCpp.cpp:304
static SmallVector< std::string > generateTraits(irdl::OperationOp op, const OpStrings &strings)
Definition: IRDLToCpp.cpp:266
static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict)
Definition: IRDLToCpp.cpp:281
static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
Definition: IRDLToCpp.cpp:188
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings)
Definition: IRDLToCpp.cpp:227
constexpr char declarationMacroFlag[]
Definition: IRDLToCpp.cpp:28
constexpr char headerTemplateText[]
Definition: IRDLToCpp.cpp:24
static std::string generateTypeVerifierCase(StringRef name, const DialectStrings &dialectStrings)
Definition: IRDLToCpp.cpp:493
static LogicalResult generateTypeInclude(irdl::TypeOp type, raw_ostream &output, irdl::detail::dictionary &dict)
Definition: IRDLToCpp.cpp:176
static void generateRegionConstraintVerifiers(irdl::detail::dictionary &dict, irdl::OperationOp op, const OpStrings &strings, SmallVectorImpl< std::string > &verifierHelpers, SmallVectorImpl< std::string > &verifierCalls)
Definition: IRDLToCpp.cpp:351
static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, DialectStrings &dialectStrings)
Definition: IRDLToCpp.cpp:499
constexpr char definitionMacroFlag[]
Definition: IRDLToCpp.cpp:29
static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op)
Definition: IRDLToCpp.cpp:448
static void generateVerifiers(irdl::detail::dictionary &dict, irdl::OperationOp op, const OpStrings &strings)
Definition: IRDLToCpp.cpp:423
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Template Code as used by IRDL-to-Cpp.
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
LogicalResult translateIRDLDialectToCpp(llvm::ArrayRef< irdl::DialectOp > dialects, raw_ostream &output)
Translates an IRDL dialect definition to a C++ definition that can be used with MLIR.
Definition: IRDLToCpp.cpp:637
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.