MLIR  20.0.0git
IRModule.cpp
Go to the documentation of this file.
1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include <optional>
12 #include <vector>
13 
14 #include "Globals.h"
15 #include "NanobindUtils.h"
16 #include "mlir-c/Support.h"
18 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
19 
20 namespace nb = nanobind;
21 using namespace mlir;
22 using namespace mlir::python;
23 
24 // -----------------------------------------------------------------------------
25 // PyGlobals
26 // -----------------------------------------------------------------------------
27 
28 PyGlobals *PyGlobals::instance = nullptr;
29 
31  assert(!instance && "PyGlobals already constructed");
32  instance = this;
33  // The default search path include {mlir.}dialects, where {mlir.} is the
34  // package prefix configured at compile time.
35  dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
36 }
37 
38 PyGlobals::~PyGlobals() { instance = nullptr; }
39 
40 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41  if (loadedDialectModules.contains(dialectNamespace))
42  return true;
43  // Since re-entrancy is possible, make a copy of the search prefixes.
44  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
45  nb::object loaded = nb::none();
46  for (std::string moduleName : localSearchPrefixes) {
47  moduleName.push_back('.');
48  moduleName.append(dialectNamespace.data(), dialectNamespace.size());
49 
50  try {
51  loaded = nb::module_::import_(moduleName.c_str());
52  } catch (nb::python_error &e) {
53  if (e.matches(PyExc_ModuleNotFoundError)) {
54  continue;
55  }
56  throw;
57  }
58  break;
59  }
60 
61  if (loaded.is_none())
62  return false;
63  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
64  // may have occurred, which may do anything.
65  loadedDialectModules.insert(dialectNamespace);
66  return true;
67 }
68 
69 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
70  nb::callable pyFunc, bool replace) {
71  nb::object &found = attributeBuilderMap[attributeKind];
72  if (found && !replace) {
73  throw std::runtime_error((llvm::Twine("Attribute builder for '") +
74  attributeKind +
75  "' is already registered with func: " +
76  nb::cast<std::string>(nb::str(found)))
77  .str());
78  }
79  found = std::move(pyFunc);
80 }
81 
82 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
83  nb::callable typeCaster, bool replace) {
84  nb::object &found = typeCasterMap[mlirTypeID];
85  if (found && !replace)
86  throw std::runtime_error("Type caster is already registered with caster: " +
87  nb::cast<std::string>(nb::str(found)));
88  found = std::move(typeCaster);
89 }
90 
91 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
92  nb::callable valueCaster, bool replace) {
93  nb::object &found = valueCasterMap[mlirTypeID];
94  if (found && !replace)
95  throw std::runtime_error("Value caster is already registered: " +
96  nb::cast<std::string>(nb::repr(found)));
97  found = std::move(valueCaster);
98 }
99 
100 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
101  nb::object pyClass) {
102  nb::object &found = dialectClassMap[dialectNamespace];
103  if (found) {
104  throw std::runtime_error((llvm::Twine("Dialect namespace '") +
105  dialectNamespace + "' is already registered.")
106  .str());
107  }
108  found = std::move(pyClass);
109 }
110 
111 void PyGlobals::registerOperationImpl(const std::string &operationName,
112  nb::object pyClass, bool replace) {
113  nb::object &found = operationClassMap[operationName];
114  if (found && !replace) {
115  throw std::runtime_error((llvm::Twine("Operation '") + operationName +
116  "' is already registered.")
117  .str());
118  }
119  found = std::move(pyClass);
120 }
121 
122 std::optional<nb::callable>
123 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
124  const auto foundIt = attributeBuilderMap.find(attributeKind);
125  if (foundIt != attributeBuilderMap.end()) {
126  assert(foundIt->second && "attribute builder is defined");
127  return foundIt->second;
128  }
129  return std::nullopt;
130 }
131 
132 std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
133  MlirDialect dialect) {
134  // Try to load dialect module.
136  const auto foundIt = typeCasterMap.find(mlirTypeID);
137  if (foundIt != typeCasterMap.end()) {
138  assert(foundIt->second && "type caster is defined");
139  return foundIt->second;
140  }
141  return std::nullopt;
142 }
143 
144 std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
145  MlirDialect dialect) {
146  // Try to load dialect module.
148  const auto foundIt = valueCasterMap.find(mlirTypeID);
149  if (foundIt != valueCasterMap.end()) {
150  assert(foundIt->second && "value caster is defined");
151  return foundIt->second;
152  }
153  return std::nullopt;
154 }
155 
156 std::optional<nb::object>
157 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
158  // Make sure dialect module is loaded.
159  if (!loadDialectModule(dialectNamespace))
160  return std::nullopt;
161  const auto foundIt = dialectClassMap.find(dialectNamespace);
162  if (foundIt != dialectClassMap.end()) {
163  assert(foundIt->second && "dialect class is defined");
164  return foundIt->second;
165  }
166  // Not found and loading did not yield a registration.
167  return std::nullopt;
168 }
169 
170 std::optional<nb::object>
171 PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
172  // Make sure dialect module is loaded.
173  auto split = operationName.split('.');
174  llvm::StringRef dialectNamespace = split.first;
175  if (!loadDialectModule(dialectNamespace))
176  return std::nullopt;
177 
178  auto foundIt = operationClassMap.find(operationName);
179  if (foundIt != operationClassMap.end()) {
180  assert(foundIt->second && "OpView is defined");
181  return foundIt->second;
182  }
183  // Not found and loading did not yield a registration.
184  return std::nullopt;
185 }
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
Globals that are always accessible once the extension has been initialized.
Definition: Globals.h:27
bool loadDialectModule(llvm::StringRef dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition: IRModule.cpp:40
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition: IRModule.cpp:82
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition: IRModule.cpp:111
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition: IRModule.cpp:69
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition: IRModule.cpp:91
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:144
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition: IRModule.cpp:157
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:171
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition: IRModule.cpp:123
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition: IRModule.cpp:100
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:132
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:128
Include the generated interface declarations.