MLIR 22.0.0git
Globals.cpp
Go to the documentation of this file.
1//===- Globals.cpp - MLIR Python bindings globals -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include <optional>
12#include <vector>
13
15// clang-format off
18// clang-format on
19#include "mlir-c/Support.h"
21
22namespace nb = nanobind;
23using namespace mlir;
24
25// -----------------------------------------------------------------------------
26// PyGlobals
27// -----------------------------------------------------------------------------
28
29namespace mlir {
30namespace python {
32PyGlobals *PyGlobals::instance = nullptr;
33
// Registers this object as the singleton; constructing a second PyGlobals
// while one is still alive trips the assertion below.
35 assert(!instance && "PyGlobals already constructed");
36 instance = this;
37 // The default search path includes {mlir.}dialects, where {mlir.} is the
38 // package prefix configured at compile time.
39 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
40}
41
42PyGlobals::~PyGlobals() { instance = nullptr; }
43
// Returns the singleton instance; asserts that a PyGlobals has been
// constructed and not yet destroyed (the destructor resets `instance`).
45 assert(instance && "PyGlobals is null");
46 return *instance;
47}
48
49bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
50 {
51 nb::ft_lock_guard lock(mutex);
52 if (loadedDialectModules.contains(dialectNamespace))
53 return true;
54 }
55 // Since re-entrancy is possible, make a copy of the search prefixes.
56 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
57 nb::object loaded = nb::none();
58 for (std::string moduleName : localSearchPrefixes) {
59 moduleName.push_back('.');
60 moduleName.append(dialectNamespace.data(), dialectNamespace.size());
61
62 try {
63 loaded = nb::module_::import_(moduleName.c_str());
64 } catch (nb::python_error &e) {
65 if (e.matches(PyExc_ModuleNotFoundError)) {
66 continue;
67 }
68 throw;
69 }
70 break;
71 }
72
73 if (loaded.is_none())
74 return false;
75 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
76 // may have occurred, which may do anything.
77 nb::ft_lock_guard lock(mutex);
78 loadedDialectModules.insert(dialectNamespace);
79 return true;
80}
81
82void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
83 nb::callable pyFunc, bool replace) {
84 nb::ft_lock_guard lock(mutex);
85 nb::object &found = attributeBuilderMap[attributeKind];
86 if (found && !replace) {
87 throw std::runtime_error((llvm::Twine("Attribute builder for '") +
88 attributeKind +
89 "' is already registered with func: " +
90 nb::cast<std::string>(nb::str(found)))
91 .str());
92 }
93 found = std::move(pyFunc);
94}
95
96void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
97 nb::callable typeCaster, bool replace) {
98 nb::ft_lock_guard lock(mutex);
99 nb::object &found = typeCasterMap[mlirTypeID];
100 if (found && !replace)
101 throw std::runtime_error("Type caster is already registered with caster: " +
102 nb::cast<std::string>(nb::str(found)));
103 found = std::move(typeCaster);
104}
105
106void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
107 nb::callable valueCaster, bool replace) {
108 nb::ft_lock_guard lock(mutex);
109 nb::object &found = valueCasterMap[mlirTypeID];
110 if (found && !replace)
111 throw std::runtime_error("Value caster is already registered: " +
112 nb::cast<std::string>(nb::repr(found)));
113 found = std::move(valueCaster);
114}
115
116void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
117 nb::object pyClass) {
118 nb::ft_lock_guard lock(mutex);
119 nb::object &found = dialectClassMap[dialectNamespace];
120 if (found) {
121 throw std::runtime_error((llvm::Twine("Dialect namespace '") +
122 dialectNamespace + "' is already registered.")
123 .str());
124 }
125 found = std::move(pyClass);
126}
127
128void PyGlobals::registerOperationImpl(const std::string &operationName,
129 nb::object pyClass, bool replace) {
130 nb::ft_lock_guard lock(mutex);
131 nb::object &found = operationClassMap[operationName];
132 if (found && !replace) {
133 throw std::runtime_error((llvm::Twine("Operation '") + operationName +
134 "' is already registered.")
135 .str());
136 }
137 found = std::move(pyClass);
138}
139
140std::optional<nb::callable>
141PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
142 nb::ft_lock_guard lock(mutex);
143 const auto foundIt = attributeBuilderMap.find(attributeKind);
144 if (foundIt != attributeBuilderMap.end()) {
145 assert(foundIt->second && "attribute builder is defined");
146 return foundIt->second;
147 }
148 return std::nullopt;
149}
150
// Returns the type caster registered for `mlirTypeID`, or std::nullopt if
// none is registered (map lookup performed under `mutex`).
// NOTE(review): this listing omits the statement that follows "Try to load
// dialect module." It presumably loads the module for `dialect`'s namespace;
// the page's cross-references mention mlirDialectGetNamespace and unwrap.
// Confirm against the actual source.
151std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
152 MlirDialect dialect) {
153 // Try to load dialect module.
155 nb::ft_lock_guard lock(mutex);
156 const auto foundIt = typeCasterMap.find(mlirTypeID);
157 if (foundIt != typeCasterMap.end()) {
158 assert(foundIt->second && "type caster is defined");
159 return foundIt->second;
160 }
161 return std::nullopt;
162}
163
// Returns the value caster registered for `mlirTypeID`, or std::nullopt if
// none is registered (map lookup performed under `mutex`).
// NOTE(review): this listing omits the statement that follows "Try to load
// dialect module." It presumably loads the module for `dialect`'s namespace.
// Confirm against the actual source.
164std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
165 MlirDialect dialect) {
166 // Try to load dialect module.
168 nb::ft_lock_guard lock(mutex);
169 const auto foundIt = valueCasterMap.find(mlirTypeID);
170 if (foundIt != valueCasterMap.end()) {
171 assert(foundIt->second && "value caster is defined");
172 return foundIt->second;
173 }
174 return std::nullopt;
175}
176
177std::optional<nb::object>
178PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
179 // Make sure dialect module is loaded.
180 if (!loadDialectModule(dialectNamespace))
181 return std::nullopt;
182 nb::ft_lock_guard lock(mutex);
183 const auto foundIt = dialectClassMap.find(dialectNamespace);
184 if (foundIt != dialectClassMap.end()) {
185 assert(foundIt->second && "dialect class is defined");
186 return foundIt->second;
187 }
188 // Not found and loading did not yield a registration.
189 return std::nullopt;
190}
191
192std::optional<nb::object>
193PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
194 // Make sure dialect module is loaded.
195 auto split = operationName.split('.');
196 llvm::StringRef dialectNamespace = split.first;
197 if (!loadDialectModule(dialectNamespace))
198 return std::nullopt;
199
200 nb::ft_lock_guard lock(mutex);
201 auto foundIt = operationClassMap.find(operationName);
202 if (foundIt != operationClassMap.end()) {
203 assert(foundIt->second && "OpView is defined");
204 return foundIt->second;
205 }
206 // Not found and loading did not yield a registration.
207 return std::nullopt;
208}
209
// Returns locTracebackEnabled_, read under `mutex`.
211 nanobind::ft_lock_guard lock(mutex);
212 return locTracebackEnabled_;
213}
214
// Stores `value` into locTracebackEnabled_ under `mutex`.
216 nanobind::ft_lock_guard lock(mutex);
217 locTracebackEnabled_ = value;
218}
219
// Returns locTracebackFramesLimit_, read under `mutex`.
221 nanobind::ft_lock_guard lock(mutex);
222 return locTracebackFramesLimit_;
223}
224
// Stores `value`, clamped to at most kMaxFrames, into
// locTracebackFramesLimit_ under `mutex`.
226 nanobind::ft_lock_guard lock(mutex);
227 locTracebackFramesLimit_ = std::min(value, kMaxFrames);
228}
229
231 const std::string &file) {
// Adds a "^"-anchored, regex-escaped pattern for `file` to the include set.
// Because the pattern is matched with std::regex_search, it matches any path
// that starts with `file`. The same pattern is removed from the exclude set,
// and whichever combined regex changed is flagged for a lazy rebuild.
232 nanobind::ft_lock_guard lock(mutex);
233 auto reg = "^" + llvm::Regex::escape(file);
234 if (userTracebackIncludeFiles.insert(reg).second)
235 rebuildUserTracebackIncludeRegex = true;
// NOTE(review): the count() check is redundant; erase() already returns how
// many elements were removed.
236 if (userTracebackExcludeFiles.count(reg)) {
237 if (userTracebackExcludeFiles.erase(reg))
238 rebuildUserTracebackExcludeRegex = true;
239 }
240}
241
243 const std::string &file) {
// Adds a "^"-anchored, regex-escaped pattern for `file` to the exclude set.
// Because the pattern is matched with std::regex_search, it matches any path
// that starts with `file`. The same pattern is removed from the include set,
// and whichever combined regex changed is flagged for a lazy rebuild.
244 nanobind::ft_lock_guard lock(mutex);
245 auto reg = "^" + llvm::Regex::escape(file);
246 if (userTracebackExcludeFiles.insert(reg).second)
247 rebuildUserTracebackExcludeRegex = true;
// NOTE(review): the count() check is redundant; erase() already returns how
// many elements were removed.
248 if (userTracebackIncludeFiles.count(reg)) {
249 if (userTracebackIncludeFiles.erase(reg))
250 rebuildUserTracebackIncludeRegex = true;
251 }
252}
253
255 const llvm::StringRef file) {
// Returns whether `file` should be treated as a user frame: true when it
// matches an include pattern, or when it matches no exclude pattern.
// Results are memoized per filename. The memo is cleared whenever either
// combined regex is rebuilt. All state is accessed under `mutex`.
256 nanobind::ft_lock_guard lock(mutex);
// Lazily rebuild each combined regex (an alternation of the stored
// "^<escaped>" patterns) if its pattern set changed since the last query.
// NOTE(review): if a pattern set has become empty, llvm::join yields "", and
// an empty std::regex matches every string under std::regex_search. An empty
// include set would then classify every file as user code, and an empty
// exclude set would exclude every file not explicitly included. This can
// happen when the only pattern in one set is moved to the other set by the
// inclusion/exclusion registration functions. Confirm whether an empty set
// should instead match nothing.
257 if (rebuildUserTracebackIncludeRegex) {
258 userTracebackIncludeRegex.assign(
259 llvm::join(userTracebackIncludeFiles, "|"));
260 rebuildUserTracebackIncludeRegex = false;
261 isUserTracebackFilenameCache.clear();
262 }
263 if (rebuildUserTracebackExcludeRegex) {
264 userTracebackExcludeRegex.assign(
265 llvm::join(userTracebackExcludeFiles, "|"));
266 rebuildUserTracebackExcludeRegex = false;
267 isUserTracebackFilenameCache.clear();
268 }
// Classify and memoize the filename the first time it is seen.
269 if (!isUserTracebackFilenameCache.contains(file)) {
270 std::string fileStr = file.str();
271 bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
272 bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
273 isUserTracebackFilenameCache[file] = include || !exclude;
274 }
275 return isUserTracebackFilenameCache[file];
276}
277} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
278} // namespace python
279} // namespace mlir
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:34
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:164
bool loadDialectModule(llvm::StringRef dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:49
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:44
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:96
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:82
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:128
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:151
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:106
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:193
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:141
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition Globals.cpp:116
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:178
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:137
Include the generated interface declarations.