MLIR 22.0.0git
IRModule.cpp
Go to the documentation of this file.
1//===- IRModule.cpp - IR pybind module ------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IRModule.h"
10
11#include <optional>
12#include <vector>
13
14#include "Globals.h"
15#include "NanobindUtils.h"
17#include "mlir-c/Support.h"
19
20namespace nb = nanobind;
21using namespace mlir;
22using namespace mlir::python;
23
24// -----------------------------------------------------------------------------
25// PyGlobals
26// -----------------------------------------------------------------------------
27
// Pointer to the live PyGlobals object, or nullptr when none exists. The
// constructor body below stores `this` here and the destructor resets it.
28PyGlobals *PyGlobals::instance = nullptr;
29
// NOTE(review): the line that opens this definition is absent from this
// listing; the statements below read as the PyGlobals constructor body --
// confirm against the full file.
// Enforces that only one PyGlobals exists at a time, then publishes `this`.
31 assert(!instance && "PyGlobals already constructed");
32 instance = this;
33 // The default search path includes {mlir.}dialects, where {mlir.} is the
34 // package prefix configured at compile time.
35 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
36}
37
38PyGlobals::~PyGlobals() { instance = nullptr; }
39
40bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41 {
42 nb::ft_lock_guard lock(mutex);
43 if (loadedDialectModules.contains(dialectNamespace))
44 return true;
45 }
46 // Since re-entrancy is possible, make a copy of the search prefixes.
47 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
48 nb::object loaded = nb::none();
49 for (std::string moduleName : localSearchPrefixes) {
50 moduleName.push_back('.');
51 moduleName.append(dialectNamespace.data(), dialectNamespace.size());
52
53 try {
54 loaded = nb::module_::import_(moduleName.c_str());
55 } catch (nb::python_error &e) {
56 if (e.matches(PyExc_ModuleNotFoundError)) {
57 continue;
58 }
59 throw;
60 }
61 break;
62 }
63
64 if (loaded.is_none())
65 return false;
66 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
67 // may have occurred, which may do anything.
68 nb::ft_lock_guard lock(mutex);
69 loadedDialectModules.insert(dialectNamespace);
70 return true;
71}
72
73void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
74 nb::callable pyFunc, bool replace) {
75 nb::ft_lock_guard lock(mutex);
76 nb::object &found = attributeBuilderMap[attributeKind];
77 if (found && !replace) {
78 throw std::runtime_error((llvm::Twine("Attribute builder for '") +
79 attributeKind +
80 "' is already registered with func: " +
81 nb::cast<std::string>(nb::str(found)))
82 .str());
83 }
84 found = std::move(pyFunc);
85}
86
87void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
88 nb::callable typeCaster, bool replace) {
89 nb::ft_lock_guard lock(mutex);
90 nb::object &found = typeCasterMap[mlirTypeID];
91 if (found && !replace)
92 throw std::runtime_error("Type caster is already registered with caster: " +
93 nb::cast<std::string>(nb::str(found)));
94 found = std::move(typeCaster);
95}
96
97void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
98 nb::callable valueCaster, bool replace) {
99 nb::ft_lock_guard lock(mutex);
100 nb::object &found = valueCasterMap[mlirTypeID];
101 if (found && !replace)
102 throw std::runtime_error("Value caster is already registered: " +
103 nb::cast<std::string>(nb::repr(found)));
104 found = std::move(valueCaster);
105}
106
107void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
108 nb::object pyClass) {
109 nb::ft_lock_guard lock(mutex);
110 nb::object &found = dialectClassMap[dialectNamespace];
111 if (found) {
112 throw std::runtime_error((llvm::Twine("Dialect namespace '") +
113 dialectNamespace + "' is already registered.")
114 .str());
115 }
116 found = std::move(pyClass);
117}
118
119void PyGlobals::registerOperationImpl(const std::string &operationName,
120 nb::object pyClass, bool replace) {
121 nb::ft_lock_guard lock(mutex);
122 nb::object &found = operationClassMap[operationName];
123 if (found && !replace) {
124 throw std::runtime_error((llvm::Twine("Operation '") + operationName +
125 "' is already registered.")
126 .str());
127 }
128 found = std::move(pyClass);
129}
130
131std::optional<nb::callable>
132PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
133 nb::ft_lock_guard lock(mutex);
134 const auto foundIt = attributeBuilderMap.find(attributeKind);
135 if (foundIt != attributeBuilderMap.end()) {
136 assert(foundIt->second && "attribute builder is defined");
137 return foundIt->second;
138 }
139 return std::nullopt;
140}
141
/// Returns the type caster registered for `mlirTypeID`, or std::nullopt when
/// typeCasterMap has no entry for it. The map is read under `mutex`.
142std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
143 MlirDialect dialect) {
144 // Try to load dialect module.
// NOTE(review): listing line 145 is absent here, so the statement this comment
// announces (and the only use of `dialect`) is not visible -- confirm against
// the full file before changing this function.
146 nb::ft_lock_guard lock(mutex);
147 const auto foundIt = typeCasterMap.find(mlirTypeID);
148 if (foundIt != typeCasterMap.end()) {
149 assert(foundIt->second && "type caster is defined");
150 return foundIt->second;
151 }
152 return std::nullopt;
153}
154
/// Returns the value caster registered for `mlirTypeID`, or std::nullopt when
/// valueCasterMap has no entry for it. The map is read under `mutex`.
155std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
156 MlirDialect dialect) {
157 // Try to load dialect module.
// NOTE(review): listing line 158 is absent here, so the statement this comment
// announces (and the only use of `dialect`) is not visible -- confirm against
// the full file before changing this function.
159 nb::ft_lock_guard lock(mutex);
160 const auto foundIt = valueCasterMap.find(mlirTypeID);
161 if (foundIt != valueCasterMap.end()) {
162 assert(foundIt->second && "value caster is defined");
163 return foundIt->second;
164 }
165 return std::nullopt;
166}
167
168std::optional<nb::object>
169PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
170 // Make sure dialect module is loaded.
171 if (!loadDialectModule(dialectNamespace))
172 return std::nullopt;
173 nb::ft_lock_guard lock(mutex);
174 const auto foundIt = dialectClassMap.find(dialectNamespace);
175 if (foundIt != dialectClassMap.end()) {
176 assert(foundIt->second && "dialect class is defined");
177 return foundIt->second;
178 }
179 // Not found and loading did not yield a registration.
180 return std::nullopt;
181}
182
183std::optional<nb::object>
184PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
185 // Make sure dialect module is loaded.
186 auto split = operationName.split('.');
187 llvm::StringRef dialectNamespace = split.first;
188 if (!loadDialectModule(dialectNamespace))
189 return std::nullopt;
190
191 nb::ft_lock_guard lock(mutex);
192 auto foundIt = operationClassMap.find(operationName);
193 if (foundIt != operationClassMap.end()) {
194 assert(foundIt->second && "OpView is defined");
195 return foundIt->second;
196 }
197 // Not found and loading did not yield a registration.
198 return std::nullopt;
199}
200
// NOTE(review): the signature line of this definition is absent from this
// listing. The body returns locTracebackEnabled_ while holding `mutex`.
202 nanobind::ft_lock_guard lock(mutex);
203 return locTracebackEnabled_;
204}
205
// NOTE(review): the signature line of this definition is absent from this
// listing. The body stores `value` into locTracebackEnabled_ while holding
// `mutex`.
207 nanobind::ft_lock_guard lock(mutex);
208 locTracebackEnabled_ = value;
209}
210
// NOTE(review): the signature line of this definition is absent from this
// listing. The body returns locTracebackFramesLimit_ while holding `mutex`.
212 nanobind::ft_lock_guard lock(mutex);
213 return locTracebackFramesLimit_;
214}
215
// NOTE(review): the signature line of this definition is absent from this
// listing; presumably setLocTracebackFramesLimit(size_t value) -- confirm.
// Stores the requested limit, capped at kMaxFrames, while holding `mutex`.
217 nanobind::ft_lock_guard lock(mutex);
218 locTracebackFramesLimit_ = std::min(value, kMaxFrames);
219}
220
222 const std::string &file) {
// NOTE(review): the first line of this signature is absent from this listing;
// presumably registerTracebackFileInclusion -- confirm.
// Builds a start-anchored ("^") pattern from the regex-escaped `file`, adds it
// to the include set and removes it from the exclude set. Each set that
// actually changes has its rebuild flag raised.
223 nanobind::ft_lock_guard lock(mutex);
224 auto reg = "^" + llvm::Regex::escape(file);
225 if (userTracebackIncludeFiles.insert(reg).second)
226 rebuildUserTracebackIncludeRegex = true;
227 if (userTracebackExcludeFiles.count(reg)) {
228 if (userTracebackExcludeFiles.erase(reg))
229 rebuildUserTracebackExcludeRegex = true;
230 }
231}
232
234 const std::string &file) {
// NOTE(review): the first line of this signature is absent from this listing;
// presumably registerTracebackFileExclusion -- confirm.
// Builds a start-anchored ("^") pattern from the regex-escaped `file`, adds it
// to the exclude set and removes it from the include set. Each set that
// actually changes has its rebuild flag raised.
235 nanobind::ft_lock_guard lock(mutex);
236 auto reg = "^" + llvm::Regex::escape(file);
237 if (userTracebackExcludeFiles.insert(reg).second)
238 rebuildUserTracebackExcludeRegex = true;
239 if (userTracebackIncludeFiles.count(reg)) {
240 if (userTracebackIncludeFiles.erase(reg))
241 rebuildUserTracebackIncludeRegex = true;
242 }
243}
244
246 const llvm::StringRef file) {
// NOTE(review): the first line of this signature is absent from this listing;
// presumably isUserTracebackFilename -- confirm.
// Returns the cached verdict for `file`, computing it on first use as
// `matches include regex || !matches exclude regex`.
247 nanobind::ft_lock_guard lock(mutex);
// Rebuild each combined regex lazily (patterns joined with "|") when its set
// was modified; any rebuild invalidates the per-file cache.
// NOTE(review): if a set has been emptied, the joined pattern is "", and an
// empty std::regex matches every input under regex_search -- for the include
// side that would make every file count as included. Confirm this is intended.
248 if (rebuildUserTracebackIncludeRegex) {
249 userTracebackIncludeRegex.assign(
250 llvm::join(userTracebackIncludeFiles, "|"));
251 rebuildUserTracebackIncludeRegex = false;
252 isUserTracebackFilenameCache.clear();
253 }
254 if (rebuildUserTracebackExcludeRegex) {
255 userTracebackExcludeRegex.assign(
256 llvm::join(userTracebackExcludeFiles, "|"));
257 rebuildUserTracebackExcludeRegex = false;
258 isUserTracebackFilenameCache.clear();
259 }
260 if (!isUserTracebackFilenameCache.contains(file)) {
261 std::string fileStr = file.str();
262 bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
263 bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
// Inclusion wins over exclusion; a file matching neither is treated as a user
// file.
264 isUserTracebackFilenameCache[file] = include || !exclude;
265 }
266 return isUserTracebackFilenameCache[file];
267}
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
void registerTracebackFileExclusion(const std::string &file)
Definition IRModule.cpp:233
static constexpr size_t kMaxFrames
Definition Globals.h:138
bool isUserTracebackFilename(llvm::StringRef file)
Definition IRModule.cpp:245
void setLocTracebackFramesLimit(size_t value)
Definition IRModule.cpp:216
void registerTracebackFileInclusion(const std::string &file)
Definition IRModule.cpp:221
Globals that are always accessible once the extension has been initialized.
Definition Globals.h:33
bool loadDialectModule(llvm::StringRef dialectNamespace)
Loads a python module corresponding to the given dialect namespace.
Definition IRModule.cpp:40
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition IRModule.cpp:87
void registerOperationImpl(const std::string &operationName, nanobind::object pyClass, bool replace=false)
Adds a concrete implementation operation class.
Definition IRModule.cpp:119
void registerAttributeBuilder(const std::string &attributeKind, nanobind::callable pyFunc, bool replace=false)
Adds a user-friendly Attribute builder.
Definition IRModule.cpp:73
void registerValueCaster(MlirTypeID mlirTypeID, nanobind::callable valueCaster, bool replace=false)
Adds a user-friendly value caster.
Definition IRModule.cpp:97
std::optional< nanobind::callable > lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom value caster for MlirTypeID mlirTypeID.
Definition IRModule.cpp:155
std::optional< nanobind::object > lookupDialectClass(const std::string &dialectNamespace)
Looks up a registered dialect class by namespace.
Definition IRModule.cpp:169
std::optional< nanobind::object > lookupOperationClass(llvm::StringRef operationName)
Looks up a registered operation class (deriving from OpView) by operation name.
Definition IRModule.cpp:184
std::optional< nanobind::callable > lookupAttributeBuilder(const std::string &attributeKind)
Returns the custom Attribute builder for Attribute kind.
Definition IRModule.cpp:132
void registerDialectImpl(const std::string &dialectNamespace, nanobind::object pyClass)
Adds a concrete implementation dialect class.
Definition IRModule.cpp:107
std::optional< nanobind::callable > lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect)
Returns the custom type caster for MlirTypeID mlirTypeID.
Definition IRModule.cpp:142
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect)
Returns the namespace of the given dialect.
Definition IR.cpp:137
Include the generated interface declarations.