MLIR 23.0.0git
Globals.cpp
Go to the documentation of this file.
//===- Globals.cpp - MLIR Python bindings globals -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include <cstring>
12#include <optional>
13#include <sstream>
14#include <string_view>
15#include <vector>
16
18// clang-format off
21// clang-format on
22#include "mlir-c/Support.h"
24
25namespace nb = nanobind;
26using namespace mlir;
27
/// Escapes every regular-expression metacharacter in `String` with a
/// backslash so that the result matches `String` literally when used as (part
/// of) a regex pattern. Local helper adapted from llvm::Regex::escape.
///
/// The metacharacter set is searched as a length-bounded std::string_view
/// rather than with std::strchr: strchr also "finds" the table's terminating
/// NUL, which would make a '\0' byte in the input get a spurious backslash.
static std::string escapeRegex(std::string_view String) {
  static constexpr std::string_view RegexMetachars = "()^$|*+?.[]\\{}";
  std::string RegexStr;
  // At least one output character per input character.
  RegexStr.reserve(String.size());
  for (char C : String) {
    if (RegexMetachars.find(C) != std::string_view::npos)
      RegexStr += '\\';
    RegexStr += C;
  }
  return RegexStr;
}
39
40// -----------------------------------------------------------------------------
41// PyGlobals
42// -----------------------------------------------------------------------------
43
44namespace mlir {
45namespace python {
47PyGlobals *PyGlobals::instance = nullptr;
48
50 assert(!instance && "PyGlobals already constructed");
51 instance = this;
  // The default search path includes {mlir.}dialects, where {mlir.} is the
53 // package prefix configured at compile time.
54 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
55}
56
57PyGlobals::~PyGlobals() { instance = nullptr; }
58
60 assert(instance && "PyGlobals is null");
61 return *instance;
62}
63
64bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) {
65 {
66 nb::ft_lock_guard lock(mutex);
67 std::string dialectNamespaceStr(dialectNamespace);
68 if (loadedDialectModules.find(dialectNamespaceStr) !=
69 loadedDialectModules.end())
70 return true;
71 }
72 // Since re-entrancy is possible, make a copy of the search prefixes.
73 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
74 nb::object loaded = nb::none();
75 for (std::string moduleName : localSearchPrefixes) {
76 moduleName.push_back('.');
77 moduleName.append(dialectNamespace.data(), dialectNamespace.size());
78
79 try {
80 loaded = nb::module_::import_(moduleName.c_str());
81 } catch (nb::python_error &e) {
82 if (e.matches(PyExc_ModuleNotFoundError)) {
83 continue;
84 }
85 throw;
86 }
87 break;
88 }
89
90 if (loaded.is_none())
91 return false;
92 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
93 // may have occurred, which may do anything.
94 nb::ft_lock_guard lock(mutex);
95 loadedDialectModules.insert(std::string(dialectNamespace));
96 return true;
97}
98
99void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
100 nb::callable pyFunc, bool replace,
101 bool allowExisting) {
102 nb::ft_lock_guard lock(mutex);
103 nb::object &found = attributeBuilderMap[attributeKind];
104 if (found) {
105 std::string msg =
106 nanobind::detail::join("Attribute builder for '", attributeKind,
107 "' is already registered with func: ",
108 nb::cast<std::string>(nb::str(found)));
109 if (allowExisting) {
110#ifndef NDEBUG
111 if (PyErr_WarnEx(PyExc_RuntimeWarning, msg.c_str(), 1) < 0) {
112 // If the user has set warnings to errors (e.g., via -Werror),
113 // PyErr_WarnEx returns -1 and sets a Python exception.
114 throw nb::python_error();
115 }
116#endif
117 return;
118 }
119 if (!replace)
120 throw std::runtime_error(msg);
121 }
122 found = std::move(pyFunc);
123}
124
125void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
126 nb::callable typeCaster, bool replace) {
127 nb::ft_lock_guard lock(mutex);
128 nb::object &found = typeCasterMap[mlirTypeID];
129 if (found && !replace)
130 throw std::runtime_error("Type caster is already registered with caster: " +
131 nb::cast<std::string>(nb::str(found)));
132 found = std::move(typeCaster);
133}
134
135void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
136 nb::callable valueCaster, bool replace) {
137 nb::ft_lock_guard lock(mutex);
138 nb::object &found = valueCasterMap[mlirTypeID];
139 if (found && !replace)
140 throw std::runtime_error("Value caster is already registered: " +
141 nb::cast<std::string>(nb::repr(found)));
142 found = std::move(valueCaster);
143}
144
145void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
146 nb::object pyClass, bool replace) {
147 nb::ft_lock_guard lock(mutex);
148 nb::object &found = dialectClassMap[dialectNamespace];
149 if (found && !replace) {
150 throw std::runtime_error(nanobind::detail::join(
151 "Dialect namespace '", dialectNamespace, "' is already registered."));
152 }
153 found = std::move(pyClass);
154}
155
156void PyGlobals::registerOperationImpl(const std::string &operationName,
157 nb::object pyClass, bool replace) {
158 nb::ft_lock_guard lock(mutex);
159 nb::object &found = operationClassMap[operationName];
160 if (found && !replace) {
161 throw std::runtime_error(nanobind::detail::join(
162 "Operation '", operationName, "' is already registered."));
163 }
164 found = std::move(pyClass);
165}
166
167void PyGlobals::registerOpAdaptorImpl(const std::string &operationName,
168 nb::object pyClass, bool replace) {
169 nb::ft_lock_guard lock(mutex);
170 nb::object &found = opAdaptorClassMap[operationName];
171 if (found && !replace) {
172 throw std::runtime_error(nanobind::detail::join(
173 "Operation adaptor of '", operationName, "' is already registered."));
174 }
175 found = std::move(pyClass);
176}
177
178std::optional<nb::callable>
179PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
180 nb::ft_lock_guard lock(mutex);
181 const auto foundIt = attributeBuilderMap.find(attributeKind);
182 if (foundIt != attributeBuilderMap.end()) {
183 assert(foundIt->second && "attribute builder is defined");
184 return foundIt->second;
185 }
186 return std::nullopt;
187}
188
189std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
190 MlirDialect dialect) {
191 // Try to load dialect module.
193 (void)loadDialectModule(std::string_view(ns.data, ns.length));
194 nb::ft_lock_guard lock(mutex);
195 const auto foundIt = typeCasterMap.find(mlirTypeID);
196 if (foundIt != typeCasterMap.end()) {
197 assert(foundIt->second && "type caster is defined");
198 return foundIt->second;
199 }
200 return std::nullopt;
201}
202
203std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
204 MlirDialect dialect) {
205 // Try to load dialect module.
207 (void)loadDialectModule(std::string_view(ns.data, ns.length));
208 nb::ft_lock_guard lock(mutex);
209 const auto foundIt = valueCasterMap.find(mlirTypeID);
210 if (foundIt != valueCasterMap.end()) {
211 assert(foundIt->second && "value caster is defined");
212 return foundIt->second;
213 }
214 return std::nullopt;
215}
216
217std::optional<nb::object>
218PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
219 // Make sure dialect module is loaded.
220 (void)loadDialectModule(dialectNamespace);
221
222 nb::ft_lock_guard lock(mutex);
223 const auto foundIt = dialectClassMap.find(dialectNamespace);
224 if (foundIt != dialectClassMap.end()) {
225 assert(foundIt->second && "dialect class is defined");
226 return foundIt->second;
227 }
228 // Not found and loading did not yield a registration.
229 return std::nullopt;
230}
231
232std::optional<nb::object>
233PyGlobals::lookupOperationClass(std::string_view operationName) {
234 // Make sure dialect module is loaded.
235 std::string_view dialectNamespace =
236 operationName.substr(0, operationName.find('.'));
237 (void)loadDialectModule(dialectNamespace);
238
239 nb::ft_lock_guard lock(mutex);
240 std::string operationNameStr(operationName);
241 auto foundIt = operationClassMap.find(operationNameStr);
242 if (foundIt != operationClassMap.end()) {
243 assert(foundIt->second && "OpView is defined");
244 return foundIt->second;
245 }
246 // Not found and loading did not yield a registration.
247 return std::nullopt;
248}
249
250std::optional<nb::object>
251PyGlobals::lookupOpAdaptorClass(std::string_view operationName) {
252 // Make sure dialect module is loaded.
253 std::string_view dialectNamespace =
254 operationName.substr(0, operationName.find('.'));
255 (void)loadDialectModule(dialectNamespace);
256
257 nb::ft_lock_guard lock(mutex);
258 std::string operationNameStr(operationName);
259 auto foundIt = opAdaptorClassMap.find(operationNameStr);
260 if (foundIt != opAdaptorClassMap.end()) {
261 assert(foundIt->second && "OpAdaptor is defined");
262 return foundIt->second;
263 }
264 // Not found and loading did not yield a registration.
265 return std::nullopt;
266}
267
269 nanobind::ft_lock_guard lock(mutex);
270 return locTracebackEnabled_;
271}
272
274 nanobind::ft_lock_guard lock(mutex);
275 locTracebackEnabled_ = value;
276}
277
279 nanobind::ft_lock_guard lock(mutex);
280 return locTracebackFramesLimit_;
281}
282
284 nanobind::ft_lock_guard lock(mutex);
285 locTracebackFramesLimit_ = std::min(value, kMaxFrames);
286}
287
289 const std::string &file) {
290 nanobind::ft_lock_guard lock(mutex);
291 auto reg = "^" + escapeRegex(file);
292 if (userTracebackIncludeFiles.insert(reg).second)
293 rebuildUserTracebackIncludeRegex = true;
294 if (userTracebackExcludeFiles.count(reg)) {
295 if (userTracebackExcludeFiles.erase(reg))
296 rebuildUserTracebackExcludeRegex = true;
297 }
298}
299
301 const std::string &file) {
302 nanobind::ft_lock_guard lock(mutex);
303 auto reg = "^" + escapeRegex(file);
304 if (userTracebackExcludeFiles.insert(reg).second)
305 rebuildUserTracebackExcludeRegex = true;
306 if (userTracebackIncludeFiles.count(reg)) {
307 if (userTracebackIncludeFiles.erase(reg))
308 rebuildUserTracebackIncludeRegex = true;
309 }
310}
311
313 const std::string_view file) {
314 nanobind::ft_lock_guard lock(mutex);
315 auto joinWithPipe = [](const std::unordered_set<std::string> &set) {
316 std::ostringstream os;
317 for (auto it = set.begin(); it != set.end(); ++it) {
318 if (it != set.begin())
319 os << "|";
320 os << *it;
321 }
322 return os.str();
323 };
324 if (rebuildUserTracebackIncludeRegex) {
325 userTracebackIncludeRegex.assign(joinWithPipe(userTracebackIncludeFiles));
326 rebuildUserTracebackIncludeRegex = false;
327 isUserTracebackFilenameCache.clear();
328 }
329 if (rebuildUserTracebackExcludeRegex) {
330 userTracebackExcludeRegex.assign(joinWithPipe(userTracebackExcludeFiles));
331 rebuildUserTracebackExcludeRegex = false;
332 isUserTracebackFilenameCache.clear();
333 }
334 std::string fileStr(file);
335 const auto foundIt = isUserTracebackFilenameCache.find(fileStr);
336 if (foundIt == isUserTracebackFilenameCache.end()) {
337 bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
338 bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
339 isUserTracebackFilenameCache[fileStr] = include || !exclude;
340 }
341 return isUserTracebackFilenameCache[fileStr];
342}
343} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
344} // namespace python
345} // namespace mlir
static std::string escapeRegex(std::string_view String)
Local helper adapted from llvm::Regex::escape.
Definition Globals.cpp:29
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:29
void registerOpAdaptorImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds an operation adaptor class.
Definition Globals.cpp:167
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:203
bool loadDialectModule(std::string_view dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:64
std::optional< nanobind::object > lookupOpAdaptorClass(std::string_view operationName)
Looks up a registered operation adaptor class by operation name.
Definition Globals.cpp:251
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation dialect class.
Definition Globals.cpp:145
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
std::optional< nanobind::object > lookupOperationClass(std::string_view operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:233
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:125
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:156
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:189
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:135
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:179
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:218
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false, bool allow_existing=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:99
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:136
Include the generated interface declarations.
@ String
A string value.
Definition AsmState.h:286
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80