MLIR  22.0.0git
MainModule.cpp
Go to the documentation of this file.
1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Globals.h"
10 #include "IRModule.h"
11 #include "NanobindUtils.h"
12 #include "Pass.h"
13 #include "Rewrite.h"
15 
16 namespace nb = nanobind;
17 using namespace mlir;
18 using namespace nb::literals;
19 using namespace mlir::python;
20 
21 // -----------------------------------------------------------------------------
22 // Module initialization.
23 // -----------------------------------------------------------------------------
24 
25 NB_MODULE(_mlir, m) {
26  m.doc() = "MLIR Python Native Extension";
27 
28  nb::class_<PyGlobals>(m, "_Globals")
29  .def_prop_rw("dialect_search_modules",
30  &PyGlobals::getDialectSearchPrefixes,
31  &PyGlobals::setDialectSearchPrefixes)
32  .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
33  "module_name"_a)
34  .def(
35  "_check_dialect_module_loaded",
36  [](PyGlobals &self, const std::string &dialectNamespace) {
37  return self.loadDialectModule(dialectNamespace);
38  },
39  "dialect_namespace"_a)
40  .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
41  "dialect_namespace"_a, "dialect_class"_a,
42  "Testing hook for directly registering a dialect")
43  .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
44  "operation_name"_a, "operation_class"_a, nb::kw_only(),
45  "replace"_a = false,
46  "Testing hook for directly registering an operation")
47  .def("loc_tracebacks_enabled",
48  [](PyGlobals &self) {
49  return self.getTracebackLoc().locTracebacksEnabled();
50  })
51  .def("set_loc_tracebacks_enabled",
52  [](PyGlobals &self, bool enabled) {
53  self.getTracebackLoc().setLocTracebacksEnabled(enabled);
54  })
55  .def("loc_tracebacks_frame_limit",
56  [](PyGlobals &self) {
57  return self.getTracebackLoc().locTracebackFramesLimit();
58  })
59  .def("set_loc_tracebacks_frame_limit",
60  [](PyGlobals &self, std::optional<int> n) {
61  self.getTracebackLoc().setLocTracebackFramesLimit(
62  n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
63  })
64  .def("register_traceback_file_inclusion",
65  [](PyGlobals &self, const std::string &filename) {
66  self.getTracebackLoc().registerTracebackFileInclusion(filename);
67  })
68  .def("register_traceback_file_exclusion",
69  [](PyGlobals &self, const std::string &filename) {
70  self.getTracebackLoc().registerTracebackFileExclusion(filename);
71  });
72 
73  // Aside from making the globals accessible to python, having python manage
74  // it is necessary to make sure it is destroyed (and releases its python
75  // resources) properly.
76  m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
77 
78  // Registration decorators.
79  m.def(
80  "register_dialect",
81  [](nb::type_object pyClass) {
82  std::string dialectNamespace =
83  nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
84  PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
85  return pyClass;
86  },
87  "dialect_class"_a,
88  "Class decorator for registering a custom Dialect wrapper");
89  m.def(
90  "register_operation",
91  [](const nb::type_object &dialectClass, bool replace) -> nb::object {
92  return nb::cpp_function(
93  [dialectClass,
94  replace](nb::type_object opClass) -> nb::type_object {
95  std::string operationName =
96  nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
97  PyGlobals::get().registerOperationImpl(operationName, opClass,
98  replace);
99  // Dict-stuff the new opClass by name onto the dialect class.
100  nb::object opClassName = opClass.attr("__name__");
101  dialectClass.attr(opClassName) = opClass;
102  return opClass;
103  });
104  },
105  "dialect_class"_a, nb::kw_only(), "replace"_a = false,
106  "Produce a class decorator for registering an Operation class as part of "
107  "a dialect");
108  m.def(
110  [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
111  return nb::cpp_function([mlirTypeID, replace](
112  nb::callable typeCaster) -> nb::object {
113  PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
114  return typeCaster;
115  });
116  },
117  "typeid"_a, nb::kw_only(), "replace"_a = false,
118  "Register a type caster for casting MLIR types to custom user types.");
119  m.def(
121  [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
122  return nb::cpp_function(
123  [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
124  PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
125  replace);
126  return valueCaster;
127  });
128  },
129  "typeid"_a, nb::kw_only(), "replace"_a = false,
130  "Register a value caster for casting MLIR values to custom user values.");
131 
132  // Define and populate IR submodule.
133  auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
134  populateIRCore(irModule);
135  populateIRAffine(irModule);
136  populateIRAttributes(irModule);
137  populateIRInterfaces(irModule);
138  populateIRTypes(irModule);
139 
140  auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
141  populateRewriteSubmodule(rewriteModule);
142 
143  // Define and populate PassManager submodule.
144  auto passManagerModule =
145  m.def_submodule("passmanager", "MLIR Pass Management Bindings");
146  populatePassManagerSubmodule(passManagerModule);
147 }
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the value caster registration binding.
Definition: Interop.h:142
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding.
Definition: Interop.h:130
NB_MODULE(_mlir, m)
Definition: MainModule.cpp:25
Globals that are always accessible once the extension has been initialized.
Definition: Globals.h:33
void populateIRAttributes(nanobind::module_ &m)
void populateIRInterfaces(nb::module_ &m)
void populatePassManagerSubmodule(nanobind::module_ &m)
void populateIRAffine(nanobind::module_ &m)
void populateRewriteSubmodule(nanobind::module_ &m)
void populateIRTypes(nanobind::module_ &m)
void populateIRCore(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.