MLIR  20.0.0git
IRModule.cpp
Go to the documentation of this file.
1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include <optional>
12 #include <vector>
13 
14 #include "Globals.h"
15 #include "NanobindUtils.h"
16 #include "mlir-c/Support.h"
18 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
19 
20 namespace nb = nanobind;
21 using namespace mlir;
22 using namespace mlir::python;
23 
24 // -----------------------------------------------------------------------------
25 // PyGlobals
26 // -----------------------------------------------------------------------------
27 
28 PyGlobals *PyGlobals::instance = nullptr;
29 
31  assert(!instance && "PyGlobals already constructed");
32  instance = this;
33  // The default search path include {mlir.}dialects, where {mlir.} is the
34  // package prefix configured at compile time.
35  dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
36 }
37 
38 PyGlobals::~PyGlobals() { instance = nullptr; }
39 
40 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41  {
42  nb::ft_lock_guard lock(mutex);
43  if (loadedDialectModules.contains(dialectNamespace))
44  return true;
45  }
46  // Since re-entrancy is possible, make a copy of the search prefixes.
47  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
48  nb::object loaded = nb::none();
49  for (std::string moduleName : localSearchPrefixes) {
50  moduleName.push_back('.');
51  moduleName.append(dialectNamespace.data(), dialectNamespace.size());
52 
53  try {
54  loaded = nb::module_::import_(moduleName.c_str());
55  } catch (nb::python_error &e) {
56  if (e.matches(PyExc_ModuleNotFoundError)) {
57  continue;
58  }
59  throw;
60  }
61  break;
62  }
63 
64  if (loaded.is_none())
65  return false;
66  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
67  // may have occurred, which may do anything.
68  nb::ft_lock_guard lock(mutex);
69  loadedDialectModules.insert(dialectNamespace);
70  return true;
71 }
72 
73 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
74  nb::callable pyFunc, bool replace) {
75  nb::ft_lock_guard lock(mutex);
76  nb::object &found = attributeBuilderMap[attributeKind];
77  if (found && !replace) {
78  throw std::runtime_error((llvm::Twine("Attribute builder for '") +
79  attributeKind +
80  "' is already registered with func: " +
81  nb::cast<std::string>(nb::str(found)))
82  .str());
83  }
84  found = std::move(pyFunc);
85 }
86 
87 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
88  nb::callable typeCaster, bool replace) {
89  nb::ft_lock_guard lock(mutex);
90  nb::object &found = typeCasterMap[mlirTypeID];
91  if (found && !replace)
92  throw std::runtime_error("Type caster is already registered with caster: " +
93  nb::cast<std::string>(nb::str(found)));
94  found = std::move(typeCaster);
95 }
96 
97 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
98  nb::callable valueCaster, bool replace) {
99  nb::ft_lock_guard lock(mutex);
100  nb::object &found = valueCasterMap[mlirTypeID];
101  if (found && !replace)
102  throw std::runtime_error("Value caster is already registered: " +
103  nb::cast<std::string>(nb::repr(found)));
104  found = std::move(valueCaster);
105 }
106 
107 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
108  nb::object pyClass) {
109  nb::ft_lock_guard lock(mutex);
110  nb::object &found = dialectClassMap[dialectNamespace];
111  if (found) {
112  throw std::runtime_error((llvm::Twine("Dialect namespace '") +
113  dialectNamespace + "' is already registered.")
114  .str());
115  }
116  found = std::move(pyClass);
117 }
118 
119 void PyGlobals::registerOperationImpl(const std::string &operationName,
120  nb::object pyClass, bool replace) {
121  nb::ft_lock_guard lock(mutex);
122  nb::object &found = operationClassMap[operationName];
123  if (found && !replace) {
124  throw std::runtime_error((llvm::Twine("Operation '") + operationName +
125  "' is already registered.")
126  .str());
127  }
128  found = std::move(pyClass);
129 }
130 
131 std::optional<nb::callable>
132 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
133  nb::ft_lock_guard lock(mutex);
134  const auto foundIt = attributeBuilderMap.find(attributeKind);
135  if (foundIt != attributeBuilderMap.end()) {
136  assert(foundIt->second && "attribute builder is defined");
137  return foundIt->second;
138  }
139  return std::nullopt;
140 }
141 
142 std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
143  MlirDialect dialect) {
144  // Try to load dialect module.
146  nb::ft_lock_guard lock(mutex);
147  const auto foundIt = typeCasterMap.find(mlirTypeID);
148  if (foundIt != typeCasterMap.end()) {
149  assert(foundIt->second && "type caster is defined");
150  return foundIt->second;
151  }
152  return std::nullopt;
153 }
154 
155 std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
156  MlirDialect dialect) {
157  // Try to load dialect module.
159  nb::ft_lock_guard lock(mutex);
160  const auto foundIt = valueCasterMap.find(mlirTypeID);
161  if (foundIt != valueCasterMap.end()) {
162  assert(foundIt->second && "value caster is defined");
163  return foundIt->second;
164  }
165  return std::nullopt;
166 }
167 
168 std::optional<nb::object>
169 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
170  // Make sure dialect module is loaded.
171  if (!loadDialectModule(dialectNamespace))
172  return std::nullopt;
173  nb::ft_lock_guard lock(mutex);
174  const auto foundIt = dialectClassMap.find(dialectNamespace);
175  if (foundIt != dialectClassMap.end()) {
176  assert(foundIt->second && "dialect class is defined");
177  return foundIt->second;
178  }
179  // Not found and loading did not yield a registration.
180  return std::nullopt;
181 }
182 
183 std::optional<nb::object>
184 PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
185  // Make sure dialect module is loaded.
186  auto split = operationName.split('.');
187  llvm::StringRef dialectNamespace = split.first;
188  if (!loadDialectModule(dialectNamespace))
189  return std::nullopt;
190 
191  nb::ft_lock_guard lock(mutex);
192  auto foundIt = operationClassMap.find(operationName);
193  if (foundIt != operationClassMap.end()) {
194  assert(foundIt->second && "OpView is defined");
195  return foundIt->second;
196  }
197  // Not found and loading did not yield a registration.
198  return std::nullopt;
199 }
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
Globals that are always accessible once the extension has been initialized.
Definition: Globals.h:28
bool loadDialectModule(llvm::StringRef dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition: IRModule.cpp:40
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition: IRModule.cpp:87
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition: IRModule.cpp:119
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition: IRModule.cpp:73
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition: IRModule.cpp:97
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:155
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition: IRModule.cpp:169
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition: IRModule.cpp:184
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition: IRModule.cpp:132
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition: IRModule.cpp:107
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition: IRModule.cpp:142
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition: IR.cpp:128
Include the generated interface declarations.