MLIR 23.0.0git
Globals.cpp
Go to the documentation of this file.
//===- Globals.cpp - MLIR Python bindings global state --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include <cstring>
12#include <optional>
13#include <sstream>
14#include <string_view>
15#include <vector>
16
18// clang-format off
21// clang-format on
22#include "mlir-c/Support.h"
24
25namespace nb = nanobind;
26using namespace mlir;
27
/// Local helper adapted from llvm::Regex::escape.
///
/// Returns a copy of `String` in which every regex metacharacter
/// (`()^$|*+?.[]\{}`) is preceded by a backslash, so that the metacharacters
/// are taken literally when the result is used as a regex pattern.
static std::string escapeRegex(std::string_view String) {
  // A std::string_view over the literal stops before its terminating NUL, so
  // (unlike std::strchr, which also matches the terminator) an embedded '\0'
  // in the input is never mistaken for a metacharacter.
  static constexpr std::string_view RegexMetachars = "()^$|*+?.[]\\{}";
  std::string RegexStr;
  RegexStr.reserve(String.size());
  for (char C : String) {
    if (RegexMetachars.find(C) != std::string_view::npos)
      RegexStr += '\\';
    RegexStr += C;
  }
  return RegexStr;
}
39
40// -----------------------------------------------------------------------------
41// PyGlobals
42// -----------------------------------------------------------------------------
43
44namespace mlir {
45namespace python {
47PyGlobals *PyGlobals::instance = nullptr;
48
50 assert(!instance && "PyGlobals already constructed");
51 instance = this;
52 // The default search path include {mlir.}dialects, where {mlir.} is the
53 // package prefix configured at compile time.
54 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
55}
56
57PyGlobals::~PyGlobals() { instance = nullptr; }
58
60 assert(instance && "PyGlobals is null");
61 return *instance;
62}
63
64bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) {
65 {
66 nb::ft_lock_guard lock(mutex);
67 std::string dialectNamespaceStr(dialectNamespace);
68 if (loadedDialectModules.find(dialectNamespaceStr) !=
69 loadedDialectModules.end())
70 return true;
71 }
72 // Since re-entrancy is possible, make a copy of the search prefixes.
73 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
74 nb::object loaded = nb::none();
75 for (std::string moduleName : localSearchPrefixes) {
76 moduleName.push_back('.');
77 moduleName.append(dialectNamespace.data(), dialectNamespace.size());
78
79 try {
80 loaded = nb::module_::import_(moduleName.c_str());
81 } catch (nb::python_error &e) {
82 if (e.matches(PyExc_ModuleNotFoundError)) {
83 continue;
84 }
85 throw;
86 }
87 break;
88 }
89
90 if (loaded.is_none())
91 return false;
92 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
93 // may have occurred, which may do anything.
94 nb::ft_lock_guard lock(mutex);
95 loadedDialectModules.insert(std::string(dialectNamespace));
96 return true;
97}
98
99void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
100 nb::callable pyFunc, bool replace) {
101 nb::ft_lock_guard lock(mutex);
102 nb::object &found = attributeBuilderMap[attributeKind];
103 if (found && !replace) {
104 throw std::runtime_error(
105 nanobind::detail::join("Attribute builder for '", attributeKind,
106 "' is already registered with func: ",
107 nb::cast<std::string>(nb::str(found))));
108 }
109 found = std::move(pyFunc);
110}
111
112void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
113 nb::callable typeCaster, bool replace) {
114 nb::ft_lock_guard lock(mutex);
115 nb::object &found = typeCasterMap[mlirTypeID];
116 if (found && !replace)
117 throw std::runtime_error("Type caster is already registered with caster: " +
118 nb::cast<std::string>(nb::str(found)));
119 found = std::move(typeCaster);
120}
121
122void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
123 nb::callable valueCaster, bool replace) {
124 nb::ft_lock_guard lock(mutex);
125 nb::object &found = valueCasterMap[mlirTypeID];
126 if (found && !replace)
127 throw std::runtime_error("Value caster is already registered: " +
128 nb::cast<std::string>(nb::repr(found)));
129 found = std::move(valueCaster);
130}
131
132void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
133 nb::object pyClass) {
134 nb::ft_lock_guard lock(mutex);
135 nb::object &found = dialectClassMap[dialectNamespace];
136 if (found) {
137 throw std::runtime_error(nanobind::detail::join(
138 "Dialect namespace '", dialectNamespace, "' is already registered."));
139 }
140 found = std::move(pyClass);
141}
142
143void PyGlobals::registerOperationImpl(const std::string &operationName,
144 nb::object pyClass, bool replace) {
145 nb::ft_lock_guard lock(mutex);
146 nb::object &found = operationClassMap[operationName];
147 if (found && !replace) {
148 throw std::runtime_error(nanobind::detail::join(
149 "Operation '", operationName, "' is already registered."));
150 }
151 found = std::move(pyClass);
152}
153
154void PyGlobals::registerOpAdaptorImpl(const std::string &operationName,
155 nb::object pyClass, bool replace) {
156 nb::ft_lock_guard lock(mutex);
157 nb::object &found = opAdaptorClassMap[operationName];
158 if (found && !replace) {
159 throw std::runtime_error(nanobind::detail::join(
160 "Operation adaptor of '", operationName, "' is already registered."));
161 }
162 found = std::move(pyClass);
163}
164
165std::optional<nb::callable>
166PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
167 nb::ft_lock_guard lock(mutex);
168 const auto foundIt = attributeBuilderMap.find(attributeKind);
169 if (foundIt != attributeBuilderMap.end()) {
170 assert(foundIt->second && "attribute builder is defined");
171 return foundIt->second;
172 }
173 return std::nullopt;
174}
175
176std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
177 MlirDialect dialect) {
178 // Try to load dialect module.
180 (void)loadDialectModule(std::string_view(ns.data, ns.length));
181 nb::ft_lock_guard lock(mutex);
182 const auto foundIt = typeCasterMap.find(mlirTypeID);
183 if (foundIt != typeCasterMap.end()) {
184 assert(foundIt->second && "type caster is defined");
185 return foundIt->second;
186 }
187 return std::nullopt;
188}
189
190std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
191 MlirDialect dialect) {
192 // Try to load dialect module.
194 (void)loadDialectModule(std::string_view(ns.data, ns.length));
195 nb::ft_lock_guard lock(mutex);
196 const auto foundIt = valueCasterMap.find(mlirTypeID);
197 if (foundIt != valueCasterMap.end()) {
198 assert(foundIt->second && "value caster is defined");
199 return foundIt->second;
200 }
201 return std::nullopt;
202}
203
204std::optional<nb::object>
205PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
206 // Make sure dialect module is loaded.
207 (void)loadDialectModule(dialectNamespace);
208
209 nb::ft_lock_guard lock(mutex);
210 const auto foundIt = dialectClassMap.find(dialectNamespace);
211 if (foundIt != dialectClassMap.end()) {
212 assert(foundIt->second && "dialect class is defined");
213 return foundIt->second;
214 }
215 // Not found and loading did not yield a registration.
216 return std::nullopt;
217}
218
219std::optional<nb::object>
220PyGlobals::lookupOperationClass(std::string_view operationName) {
221 // Make sure dialect module is loaded.
222 std::string_view dialectNamespace =
223 operationName.substr(0, operationName.find('.'));
224 (void)loadDialectModule(dialectNamespace);
225
226 nb::ft_lock_guard lock(mutex);
227 std::string operationNameStr(operationName);
228 auto foundIt = operationClassMap.find(operationNameStr);
229 if (foundIt != operationClassMap.end()) {
230 assert(foundIt->second && "OpView is defined");
231 return foundIt->second;
232 }
233 // Not found and loading did not yield a registration.
234 return std::nullopt;
235}
236
237std::optional<nb::object>
238PyGlobals::lookupOpAdaptorClass(std::string_view operationName) {
239 // Make sure dialect module is loaded.
240 std::string_view dialectNamespace =
241 operationName.substr(0, operationName.find('.'));
242 (void)loadDialectModule(dialectNamespace);
243
244 nb::ft_lock_guard lock(mutex);
245 std::string operationNameStr(operationName);
246 auto foundIt = opAdaptorClassMap.find(operationNameStr);
247 if (foundIt != opAdaptorClassMap.end()) {
248 assert(foundIt->second && "OpAdaptor is defined");
249 return foundIt->second;
250 }
251 // Not found and loading did not yield a registration.
252 return std::nullopt;
253}
254
256 nanobind::ft_lock_guard lock(mutex);
257 return locTracebackEnabled_;
258}
259
261 nanobind::ft_lock_guard lock(mutex);
262 locTracebackEnabled_ = value;
263}
264
266 nanobind::ft_lock_guard lock(mutex);
267 return locTracebackFramesLimit_;
268}
269
271 nanobind::ft_lock_guard lock(mutex);
272 locTracebackFramesLimit_ = std::min(value, kMaxFrames);
273}
274
276 const std::string &file) {
277 nanobind::ft_lock_guard lock(mutex);
278 auto reg = "^" + escapeRegex(file);
279 if (userTracebackIncludeFiles.insert(reg).second)
280 rebuildUserTracebackIncludeRegex = true;
281 if (userTracebackExcludeFiles.count(reg)) {
282 if (userTracebackExcludeFiles.erase(reg))
283 rebuildUserTracebackExcludeRegex = true;
284 }
285}
286
288 const std::string &file) {
289 nanobind::ft_lock_guard lock(mutex);
290 auto reg = "^" + escapeRegex(file);
291 if (userTracebackExcludeFiles.insert(reg).second)
292 rebuildUserTracebackExcludeRegex = true;
293 if (userTracebackIncludeFiles.count(reg)) {
294 if (userTracebackIncludeFiles.erase(reg))
295 rebuildUserTracebackIncludeRegex = true;
296 }
297}
298
300 const std::string_view file) {
301 nanobind::ft_lock_guard lock(mutex);
302 auto joinWithPipe = [](const std::unordered_set<std::string> &set) {
303 std::ostringstream os;
304 for (auto it = set.begin(); it != set.end(); ++it) {
305 if (it != set.begin())
306 os << "|";
307 os << *it;
308 }
309 return os.str();
310 };
311 if (rebuildUserTracebackIncludeRegex) {
312 userTracebackIncludeRegex.assign(joinWithPipe(userTracebackIncludeFiles));
313 rebuildUserTracebackIncludeRegex = false;
314 isUserTracebackFilenameCache.clear();
315 }
316 if (rebuildUserTracebackExcludeRegex) {
317 userTracebackExcludeRegex.assign(joinWithPipe(userTracebackExcludeFiles));
318 rebuildUserTracebackExcludeRegex = false;
319 isUserTracebackFilenameCache.clear();
320 }
321 std::string fileStr(file);
322 const auto foundIt = isUserTracebackFilenameCache.find(fileStr);
323 if (foundIt == isUserTracebackFilenameCache.end()) {
324 bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
325 bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
326 isUserTracebackFilenameCache[fileStr] = include || !exclude;
327 }
328 return isUserTracebackFilenameCache[fileStr];
329}
330} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
331} // namespace python
332} // namespace mlir
static std::string escapeRegex(std::string_view String)
Local helper adapted from llvm::Regex::escape.
Definition Globals.cpp:29
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:29
void registerOpAdaptorImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds an operation adaptor class.
Definition Globals.cpp:154
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:190
bool loadDialectModule(std::string_view dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition Globals.cpp:64
std::optional< nanobind::object > lookupOpAdaptorClass(std::string_view operationName)
Looks up a registered operation adaptor class by operation name.
Definition Globals.cpp:238
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
std::optional< nanobind::object > lookupOperationClass(std::string_view operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition Globals.cpp:220
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition Globals.cpp:112
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition Globals.cpp:99
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition Globals.cpp:143
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition Globals.cpp:176
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition Globals.cpp:122
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition Globals.cpp:166
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition Globals.cpp:132
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition Globals.cpp:205
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:136
Include the generated interface declarations.
@ String
A string value.
Definition AsmState.h:286
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80