MLIR 23.0.0git
NanobindUtils.h
Go to the documentation of this file.
1//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12
13#include "mlir-c/Support.h"
15
16#include <atomic>
17#include <fstream>
18#include <memory>
19#include <sstream>
20#include <string>
21#include <string_view>
22#include <type_traits>
23#include <typeinfo>
24#include <variant>
25
26template <>
27struct std::iterator_traits<nanobind::detail::fast_iterator> {
28 using value_type = nanobind::handle;
29 using reference = const value_type;
30 using pointer = void;
31 using difference_type = std::ptrdiff_t;
32 using iterator_category = std::forward_iterator_tag;
33};
34
35namespace mlir {
36namespace python {
37
38/// Safely calls Python initialization code on first use, avoiding deadlocks.
39template <typename T>
40class SafeInit {
41public:
42 typedef std::unique_ptr<T> (*F)();
43
44 explicit SafeInit(F init_fn) : initFn(init_fn) {}
45
46 T &get() {
47 if (T *result = output.load()) {
48 return *result;
49 }
50
51 // Note: init_fn() may be called multiple times if, for example, the GIL is
52 // released during its execution. The intended use case is for module
53 // imports which are safe to perform multiple times. We are careful not to
54 // hold a lock across init_fn() to avoid lock ordering problems.
55 std::unique_ptr<T> m = initFn();
56 {
57 nanobind::ft_lock_guard lock(mu);
58 if (T *result = output.load()) {
59 return *result;
60 }
61 T *p = m.release();
62 output.store(p);
63 return *p;
64 }
65 }
66
67private:
68 nanobind::ft_mutex mu;
69 std::atomic<T *> output{nullptr};
70 F initFn;
71};
72
74 size_t operator()(MlirTypeID typeID) const {
75 return mlirTypeIDHashValue(typeID);
76 }
77};
78
80 bool operator()(MlirTypeID lhs, MlirTypeID rhs) const {
81 return mlirTypeIDEqual(lhs, rhs);
82 }
83};
84
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps an already resolved object; only its address is stored, so the
  /// referent must outlive this wrapper.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (null for a default-constructed instance).
  ReferrentTy *get() const { return referrent; }
  /// Smart-pointer style access; const-qualified to match `get()`, since the
  /// wrapper itself is not modified.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
115
116} // namespace python
117} // namespace mlir
118
119namespace nanobind {
120namespace detail {
121
/// Streams every argument, in order, into a single buffer and returns the
/// accumulated text as a `std::string`. With no arguments the result is empty.
template <typename... Ts>
inline std::string join(const Ts &...args) {
  std::ostringstream buffer;
  ((buffer << args), ...);
  return buffer.str();
}
129
130template <typename DefaultingTy>
132 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
133
134 bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
135 if (src.is_none()) {
136 // Note that we do want an exception to propagate from here as it will be
137 // the most informative.
138 value = DefaultingTy{DefaultingTy::resolve()};
139 return true;
140 }
141
142 // Unlike many casters that chain, these casters are expected to always
143 // succeed, so instead of doing an isinstance check followed by a cast,
144 // just cast in one step and handle the exception. Returning false (vs
145 // letting the exception propagate) causes higher level signature parsing
146 // code to produce nice error messages (other than "Cannot cast...").
147 try {
148 value = DefaultingTy{
149 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
150 return true;
151 } catch (std::exception &) {
152 return false;
153 }
154 }
155
156 static handle from_cpp(DefaultingTy src, rv_policy policy,
157 cleanup_list *cleanup) noexcept {
158 return nanobind::cast(src, policy);
159 }
160};
161} // namespace detail
162} // namespace nanobind
163
164//------------------------------------------------------------------------------
165// Conversion utilities.
166//------------------------------------------------------------------------------
167
168namespace mlir {
169
170/// Accumulates into a python string from a method that accepts an
171/// MlirStringCallback.
173 nanobind::list parts;
174
175 void *getUserData() { return this; }
176
178 return [](MlirStringRef part, void *userData) {
179 PyPrintAccumulator *printAccum =
180 static_cast<PyPrintAccumulator *>(userData);
181 nanobind::str pyPart(part.data,
182 part.length); // Decodes as UTF-8 by default.
183 printAccum->parts.append(std::move(pyPart));
184 };
185 }
186
187 nanobind::str join() {
188 nanobind::str delim("", 0);
189 return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
190 }
191};
192
193/// RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope
194/// exit.
203
204/// Accumulates into a file, either writing text (default)
205/// or binary. The file may be a Python file-like object or a path to a file.
207public:
208 PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
209 : binary(binary) {
210 std::string filePath;
211 if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
212 std::string errorMessage;
213 auto errorCallback = +[](MlirStringRef message, void *userData) {
214 auto *storage = static_cast<std::string *>(userData);
215 storage->assign(message.data, message.length);
216 };
218 filePath.c_str(), binary, errorCallback, &errorMessage);
219 if (mlirLlvmRawFdOStreamIsNull(stream)) {
220 throw nanobind::value_error(
221 (std::string("Unable to open file for writing: ") + errorMessage)
222 .c_str());
223 }
224 writeTarget.emplace<RAIIMlirLlvmRawFdOStream>(stream);
225 } else {
226 writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
227 }
228 }
229
231 return writeTarget.index() == 0 ? getPyWriteCallback()
232 : getOStreamCallback();
233 }
234
235 void *getUserData() { return this; }
236
237private:
238 MlirStringCallback getPyWriteCallback() {
239 return [](MlirStringRef part, void *userData) {
240 nanobind::gil_scoped_acquire acquire;
241 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
242 if (accum->binary) {
243 // Note: Still has to copy and not avoidable with this API.
244 nanobind::bytes pyBytes(part.data, part.length);
245 std::get<nanobind::object>(accum->writeTarget)(pyBytes);
246 } else {
247 nanobind::str pyStr(part.data,
248 part.length); // Decodes as UTF-8 by default.
249 std::get<nanobind::object>(accum->writeTarget)(pyStr);
250 }
251 };
252 }
253
254 MlirStringCallback getOStreamCallback() {
255 return [](MlirStringRef part, void *userData) {
256 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
258 std::get<RAIIMlirLlvmRawFdOStream>(accum->writeTarget), part);
259 };
260 }
261
262 std::variant<nanobind::object, RAIIMlirLlvmRawFdOStream> writeTarget;
263 bool binary;
264};
265
266/// Accumulates into a python string from a method that is expected to make
267/// one (no more, no less) call to the callback (asserts internally on
268/// violation).
270 void *getUserData() { return this; }
271
273 return [](MlirStringRef part, void *userData) {
275 static_cast<PySinglePartStringAccumulator *>(userData);
276 assert(!accum->invoked &&
277 "PySinglePartStringAccumulator called back multiple times");
278 accum->invoked = true;
279 accum->value = nanobind::str(part.data, part.length);
280 };
281 }
282
283 nanobind::str takeValue() {
284 assert(invoked && "PySinglePartStringAccumulator not called back");
285 return std::move(value);
286 }
287
288private:
289 nanobind::str value;
290 bool invoked = false;
291};
292
293/// A CRTP base class for pseudo-containers willing to support Python-type
294/// slicing access on top of indexed access. Calling ::bind on this class
295/// will define `__len__` as well as `__getitem__` with integer and slice
296/// arguments.
297///
298/// This is intended for pseudo-containers that can refer to arbitrary slices of
299/// underlying storage indexed by a single integer. Indexing those with an
300/// integer produces an instance of ElementTy. Indexing those with a slice
301/// produces a new instance of Derived, which can be sliced further.
302///
303/// A derived class must provide the following:
304/// - a `static const char *pyClassName ` field containing the name of the
305/// Python class to bind;
306/// - an instance method `intptr_t getRawNumElements()` that returns the
307/// number
308/// of elements in the backing container (NOT that of the slice);
309/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
310/// single element at the given linear index (NOT slice index);
311/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
312/// constructs a new instance of the derived pseudo-container with the
313/// given slice parameters (to be forwarded to the Sliceable constructor).
314///
315/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
316/// throw.
317///
318/// A derived class may additionally define:
319/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
320/// the python class.
321template <typename Derived, typename ElementTy>
323protected:
324 using ClassTy = nanobind::class_<Derived>;
325
326 /// Transforms `index` into a legal value to access the underlying sequence.
327 /// Returns <0 on failure.
329 if (index < 0)
330 index = length + index;
331 if (index < 0 || index >= length)
332 return -1;
333 return index;
334 }
335
336 /// Computes the linear index given the current slice properties.
338 intptr_t linearIndex = index * step + startIndex;
339 assert(linearIndex >= 0 &&
340 linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
341 "linear index out of bounds, the slice is ill-formed");
342 return linearIndex;
343 }
344
345 /// Trait to check if T provides a `maybeDownCast` method.
346 /// Note, you need the & to detect inherited members.
347 template <typename T, typename = void>
348 struct has_maybe_downcast : std::false_type {};
349
350 template <typename T>
351 struct has_maybe_downcast<T, std::void_t<decltype(&T::maybeDownCast)>>
352 : std::true_type {};
353
354 /// Returns the element at the given slice index. Supports negative indices
355 /// by taking elements in inverse order. Returns a nullptr object if out
356 /// of bounds.
357 nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
358 // Negative indices mean we count from the end.
360 if (index < 0) {
361 PyErr_SetString(PyExc_IndexError, "index out of range");
362 return {};
363 }
364
365 if constexpr (has_maybe_downcast<ElementTy>::value)
366 return static_cast<Derived *>(this)
367 ->getRawElement(linearizeIndex(index))
368 .maybeDownCast();
369 else
370 return nanobind::cast(
371 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
372 }
373
374 /// Returns a new instance of the pseudo-container restricted to the given
375 /// slice. Returns a nullptr object on failure.
376 nanobind::object getItemSlice(PyObject *slice) {
377 Py_ssize_t start, stop, extraStep, sliceLength;
378 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
379 &sliceLength) != 0) {
380 PyErr_SetString(PyExc_IndexError, "index out of range");
381 return {};
382 }
383 return nanobind::cast(static_cast<Derived *>(this)->slice(
384 startIndex + start * step, sliceLength, step * extraStep));
385 }
386
387public:
390 assert(length >= 0 && "expected non-negative slice length");
391 }
392
393 /// Returns the `index`-th element in the slice, supports negative indices.
394 /// Throws if the index is out of bounds.
396 // Negative indices mean we count from the end.
398 if (index < 0) {
399 throw nanobind::index_error("index out of range");
400 }
401
402 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
403 }
404
405 /// Returns the size of slice.
406 intptr_t size() { return length; }
407
408 /// Returns a new vector (mapped to Python list) containing elements from two
409 /// slices. The new vector is necessary because slices may not be contiguous
410 /// or even come from the same original sequence.
411 std::vector<ElementTy> dunderAdd(Derived &other) {
412 std::vector<ElementTy> elements;
413 elements.reserve(length + other.length);
414 for (intptr_t i = 0; i < length; ++i) {
415 elements.push_back(static_cast<Derived *>(this)->getElement(i));
416 }
417 for (intptr_t i = 0; i < other.length; ++i) {
418 elements.push_back(static_cast<Derived *>(&other)->getElement(i));
419 }
420 return elements;
421 }
422
423 // Manually implement the sequence protocol via the C API. We do this
424 // because it is approx 4x faster than via nanobind, largely because that
425 // formulation requires a C++ exception to be thrown to detect end of
426 // sequence.
427 // Since we are in a C-context, any C++ exception that happens here
428 // will terminate the program. There is nothing in this implementation
429 // that should throw in a non-terminal way, so we forgo further
430 // exception marshalling.
431 // See: https://github.com/pybind/pybind11/issues/2842
432 //
433 /// Binds the indexing and length methods in the Python class.
434 static void bind(nanobind::module_ &m) {
435 // These slots are passed via nanobind::type_slots() at class creation
436 // time, which is compatible with both the full and limited (stable ABI)
437 // Python APIs.
438 static PyType_Slot sequenceSlots[] = {
439 {Py_sq_length, (void *)(+[](PyObject *rawSelf) -> Py_ssize_t {
440 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
441 return self->length;
442 })},
443 // sq_item is called as part of the sequence protocol for iteration,
444 // list construction, etc.
445 {Py_sq_item,
446 (void *)(+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
447 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
448 return self->getItem(index).release().ptr();
449 })},
450 // mp_subscript is used for both slices and integer lookups.
451 {Py_mp_subscript,
452 (void *)(+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
453 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
454 Py_ssize_t index =
455 PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
456 if (!PyErr_Occurred()) {
457 // Integer indexing.
458 return self->getItem(index).release().ptr();
459 }
460 PyErr_Clear();
461
462 // Assume slice-based indexing.
463 if (PySlice_Check(rawSubscript)) {
464 return self->getItemSlice(rawSubscript).release().ptr();
465 }
466
467 PyErr_SetString(PyExc_ValueError, "expected integer or slice");
468 return nullptr;
469 })},
470 {0, nullptr}};
471 const std::type_info &elemTy = typeid(ElementTy);
472 PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
473 assert(elemTyInfo &&
474 "expected nb_type_lookup to succeed for Sliceable elemTy");
475 nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
476 std::string sig = std::string("class ") + Derived::pyClassName +
477 "(collections.abc.Sequence[" +
478 nanobind::cast<std::string>(elemTyName) + "])";
479 auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
480 nanobind::type_slots(sequenceSlots),
481 nanobind::sig(sig.c_str()))
482 .def("__add__", &Sliceable::dunderAdd);
483 Derived::bindDerived(clazz);
484 }
485
486 /// Hook for derived classes willing to bind more methods.
487 static void bindDerived(ClassTy &) {}
488
492};
493
494} // namespace mlir
495
496#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
lhs
Accumulates into a file, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
MlirStringCallback getCallback()
nanobind::typed< nanobind::object, ElementTy > getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< PyOpResultList > ClassTy
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
intptr_t size()
Returns the size of slice.
ReferrentTy * operator->()
Defaulting(ReferrentTy &referrent)
ReferrentTy * get() const
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
std::unique_ptr< T >(* F)()
MLIR_CAPI_EXPORTED bool mlirLlvmRawFdOStreamIsNull(MlirLlvmRawFdOStream stream)
Checks if a raw_fd_ostream is null.
Definition Support.cpp:69
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamWrite(MlirLlvmRawFdOStream stream, MlirStringRef string)
Write a string to a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
Definition Support.cpp:64
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamDestroy(MlirLlvmRawFdOStream stream)
Destroy a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
Definition Support.cpp:73
struct MlirStringRef MlirStringRef
Definition Support.h:82
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition Support.h:110
MLIR_CAPI_EXPORTED MlirLlvmRawFdOStream mlirLlvmRawFdOStreamCreate(const char *path, bool binary, MlirStringCallback errorCallback, void *userData)
Create a raw_fd_ostream for the given path.
Definition Support.cpp:47
Include the generated interface declarations.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a python string from a method that is expected to make one (no more,...
RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope exit.
RAIIMlirLlvmRawFdOStream & operator=(const RAIIMlirLlvmRawFdOStream &)=delete
RAIIMlirLlvmRawFdOStream(MlirLlvmRawFdOStream stream)
RAIIMlirLlvmRawFdOStream(const RAIIMlirLlvmRawFdOStream &)=delete
Trait to check if T provides a maybeDownCast method.
bool operator()(MlirTypeID lhs, MlirTypeID rhs) const
size_t operator()(MlirTypeID typeID) const
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept