MLIR 23.0.0git
NanobindUtils.h
Go to the documentation of this file.
1//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12
13#include "mlir-c/Support.h"
15
16#include <array>
17#include <atomic>
18#include <fstream>
19#include <memory>
20#include <sstream>
21#include <string>
22#include <string_view>
23#include <type_traits>
24#include <typeinfo>
25#include <variant>
26
// Iterator traits for nanobind::detail::fast_iterator: dereferencing is
// described as yielding a nanobind::handle by (const) value, no pointer type
// is exposed, and the iterator is advertised as a forward iterator.
// NOTE(review): presumably required so that generic iterator-based utilities
// accept fast_iterator — confirm against the call sites.
27template <>
28struct std::iterator_traits<nanobind::detail::fast_iterator> {
29 using value_type = nanobind::handle;
30 using reference = const value_type;
31 using pointer = void;
32 using difference_type = std::ptrdiff_t;
33 using iterator_category = std::forward_iterator_tag;
34};
35
36namespace mlir {
37namespace python {
38
39/// Safely calls Python initialization code on first use, avoiding deadlocks.
40template <typename T>
41class SafeInit {
42public:
43 typedef std::unique_ptr<T> (*F)();
44
45 explicit SafeInit(F init_fn) : initFn(init_fn) {}
46
47 T &get() {
48 if (T *result = output.load()) {
49 return *result;
50 }
51
52 // Note: init_fn() may be called multiple times if, for example, the GIL is
53 // released during its execution. The intended use case is for module
54 // imports which are safe to perform multiple times. We are careful not to
55 // hold a lock across init_fn() to avoid lock ordering problems.
56 std::unique_ptr<T> m = initFn();
57 {
58 nanobind::ft_lock_guard lock(mu);
59 if (T *result = output.load()) {
60 return *result;
61 }
62 T *p = m.release();
63 output.store(p);
64 return *p;
65 }
66 }
67
68private:
69 nanobind::ft_mutex mu;
70 std::atomic<T *> output{nullptr};
71 F initFn;
72};
73
// Hash functor call: delegates to the C API's mlirTypeIDHashValue for the
// given type id.
75 size_t operator()(MlirTypeID typeID) const {
76 return mlirTypeIDHashValue(typeID);
77 }
78};
79
// Equality functor call: two type ids compare equal exactly when the C API's
// mlirTypeIDEqual says so.
81 bool operator()(MlirTypeID lhs, MlirTypeID rhs) const {
82 return mlirTypeIDEqual(lhs, rhs);
83 }
84};
85
86/// CRTP template for special wrapper types that are allowed to be passed in as
87/// 'None' function arguments and can be resolved by some global mechanic if
88/// so. Such types will raise an error if this global resolution fails, and
89/// it is actually illegal for them to ever be unresolved. From a user
90/// perspective, they behave like a smart ptr to the underlying type (i.e.
91/// 'get' method and operator-> overloaded).
92///
93/// Derived types must provide a method, which is called when an environmental
94/// resolution is required. It must raise an exception if resolution fails:
95/// static ReferrentTy &resolve()
96///
97/// They must also provide a parameter description that will be used in
98/// error messages about mismatched types:
99/// static constexpr const char kTypeDescription[] = "<Description>";
100
101template <typename DerivedTy, typename T>
103public:
104 using ReferrentTy = T;
105 /// Type casters require the type to be default constructible, but using
106 /// such an instance is illegal.
// A default-constructed instance holds a null referent pointer.
107 Defaulting() = default;
// Wraps an already-resolved referent by storing its address.
108 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}
109
110 ReferrentTy *get() const { return referrent; }
111 ReferrentTy *operator->() { return referrent; }
112
113private:
114 ReferrentTy *referrent = nullptr;
115};
116
117} // namespace python
118} // namespace mlir
119
120namespace nanobind {
121namespace detail {
122
/// Streams every argument, in order, into a single buffer and returns the
/// resulting `std::string`.
template <typename... Ts>
inline std::string join(const Ts &...args) {
  std::ostringstream buffer;
  // Comma fold: each argument is inserted left to right.
  ((buffer << args), ...);
  return buffer.str();
}
130
131template <typename DefaultingTy>
133 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
134
135 bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
136 if (src.is_none()) {
137 // Note that we do want an exception to propagate from here as it will be
138 // the most informative.
139 value = DefaultingTy{DefaultingTy::resolve()};
140 return true;
141 }
142
143 // Unlike many casters that chain, these casters are expected to always
144 // succeed, so instead of doing an isinstance check followed by a cast,
145 // just cast in one step and handle the exception. Returning false (vs
146 // letting the exception propagate) causes higher level signature parsing
147 // code to produce nice error messages (other than "Cannot cast...").
148 try {
149 value = DefaultingTy{
150 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
151 return true;
152 } catch (std::exception &) {
153 return false;
154 }
155 }
156
// Converts a DefaultingTy to a Python handle by forwarding to nanobind::cast
// with the caller's return-value policy; `cleanup` is not used.
// NOTE(review): the function is declared noexcept, so an exception escaping
// nanobind::cast would terminate the program — confirm this path cannot
// throw.
157 static handle from_cpp(DefaultingTy src, rv_policy policy,
158 cleanup_list *cleanup) noexcept {
159 return nanobind::cast(src, policy);
160 }
161};
162} // namespace detail
163} // namespace nanobind
164
165//------------------------------------------------------------------------------
166// Conversion utilities.
167//------------------------------------------------------------------------------
168
169namespace mlir {
170
171/// Accumulates into a python string from a method that accepts an
172/// MlirStringCallback.
174 nanobind::list parts;
175
// Returns this accumulator as the opaque user-data pointer; the string
// callback casts it back to PyPrintAccumulator *.
176 void *getUserData() { return this; }
177
179 return [](MlirStringRef part, void *userData) {
180 PyPrintAccumulator *printAccum =
181 static_cast<PyPrintAccumulator *>(userData);
182 nanobind::str pyPart(part.data,
183 part.length); // Decodes as UTF-8 by default.
184 printAccum->parts.append(std::move(pyPart));
185 };
186 }
187
188 nanobind::str join() {
189 nanobind::str delim("", 0);
190 return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
191 }
192};
193
194/// RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope
195/// exit.
204
205/// Accumulates into a file, either writing text (default)
206/// or binary. The file may be a Python file-like object or a path to a file.
208public:
209 PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
210 : binary(binary) {
211 std::string filePath;
212 if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
213 std::string errorMessage;
214 auto errorCallback = +[](MlirStringRef message, void *userData) {
215 auto *storage = static_cast<std::string *>(userData);
216 storage->assign(message.data, message.length);
217 };
219 filePath.c_str(), binary, errorCallback, &errorMessage);
220 if (mlirLlvmRawFdOStreamIsNull(stream)) {
221 throw nanobind::value_error(
222 (std::string("Unable to open file for writing: ") + errorMessage)
223 .c_str());
224 }
225 writeTarget.emplace<RAIIMlirLlvmRawFdOStream>(stream);
226 } else {
227 writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
228 }
229 }
230
232 return writeTarget.index() == 0 ? getPyWriteCallback()
233 : getOStreamCallback();
234 }
235
// Returns this accumulator as the opaque user-data pointer; both write
// callbacks cast it back to PyFileAccumulator *.
236 void *getUserData() { return this; }
237
238private:
239 MlirStringCallback getPyWriteCallback() {
240 return [](MlirStringRef part, void *userData) {
241 nanobind::gil_scoped_acquire acquire;
242 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
243 if (accum->binary) {
244 // Note: Still has to copy and not avoidable with this API.
245 nanobind::bytes pyBytes(part.data, part.length);
246 std::get<nanobind::object>(accum->writeTarget)(pyBytes);
247 } else {
248 nanobind::str pyStr(part.data,
249 part.length); // Decodes as UTF-8 by default.
250 std::get<nanobind::object>(accum->writeTarget)(pyStr);
251 }
252 };
253 }
254
255 MlirStringCallback getOStreamCallback() {
256 return [](MlirStringRef part, void *userData) {
257 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
259 std::get<RAIIMlirLlvmRawFdOStream>(accum->writeTarget), part);
260 };
261 }
262
263 std::variant<nanobind::object, RAIIMlirLlvmRawFdOStream> writeTarget;
264 bool binary;
265};
266
267/// Accumulates into a python string from a method that is expected to make
268/// one (no more, no less) call to the callback (asserts internally on
269/// violation).
// Returns this accumulator as the opaque user-data pointer; the string
// callback casts it back to PySinglePartStringAccumulator *.
271 void *getUserData() { return this; }
272
274 return [](MlirStringRef part, void *userData) {
276 static_cast<PySinglePartStringAccumulator *>(userData);
277 assert(!accum->invoked &&
278 "PySinglePartStringAccumulator called back multiple times");
279 accum->invoked = true;
280 accum->value = nanobind::str(part.data, part.length);
281 };
282 }
283
// Hands the accumulated string to the caller by moving it out of this
// object. Asserts that the callback has been invoked beforehand.
284 nanobind::str takeValue() {
285 assert(invoked && "PySinglePartStringAccumulator not called back");
// `value` is a member, so the explicit move is what avoids a copy here.
286 return std::move(value);
287 }
288
289private:
290 nanobind::str value;
291 bool invoked = false;
292};
293
294/// A CRTP base class for pseudo-containers willing to support Python-type
295/// slicing access on top of indexed access. Calling ::bind on this class
296/// will define `__len__` as well as `__getitem__` with integer and slice
297/// arguments.
298///
299/// This is intended for pseudo-containers that can refer to arbitrary slices of
300/// underlying storage indexed by a single integer. Indexing those with an
301/// integer produces an instance of ElementTy. Indexing those with a slice
302/// produces a new instance of Derived, which can be sliced further.
303///
304/// A derived class must provide the following:
305/// - a `static const char *pyClassName ` field containing the name of the
306/// Python class to bind;
307/// - an instance method `intptr_t getRawNumElements()` that returns the
308/// number
309/// of elements in the backing container (NOT that of the slice);
310/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
311/// single element at the given linear index (NOT slice index);
312/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
313/// constructs a new instance of the derived pseudo-container with the
314/// given slice parameters (to be forwarded to the Sliceable constructor).
315///
316/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
317/// throw.
318///
319/// A derived class may additionally define:
320/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
321/// on the Python class.
322/// - a `static constexpr std::array<const char *, N> typeParams` to make the
323/// Python class generic, parameterizable with the given type parameters.
324template <typename Derived, typename ElementTy>
326protected:
327 using ClassTy = nanobind::class_<Derived>;
328
329 /// Type parameter names for generic classes. When non-empty, the Python
330 /// class will be made generic with `typing.Generic[...]`.
331 static constexpr std::array<const char *, 0> typeParams = {};
332
333 /// Transforms `index` into a legal value to access the underlying sequence.
334 /// Returns <0 on failure.
336 if (index < 0)
337 index = length + index;
338 if (index < 0 || index >= length)
339 return -1;
340 return index;
341 }
342
343 /// Computes the linear index given the current slice properties.
345 intptr_t linearIndex = index * step + startIndex;
346 assert(linearIndex >= 0 &&
347 linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
348 "linear index out of bounds, the slice is ill-formed");
349 return linearIndex;
350 }
351
/// Detection trait: `value` is true when `&T::maybeDownCast` is a valid
/// expression, false otherwise. Taking the member's address (rather than
/// calling it) is what makes inherited members detectable as well.
template <typename T, typename Enable = void>
struct has_maybe_downcast : std::false_type {};

/// Chosen only when `&T::maybeDownCast` is well-formed.
template <typename T>
struct has_maybe_downcast<T,
                          decltype(static_cast<void>(&T::maybeDownCast))>
    : std::true_type {};
360
361 /// Returns the element at the given slice index. Supports negative indices
362 /// by taking elements in inverse order. Returns a nullptr object if out
363 /// of bounds.
364 nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
365 // Negative indices mean we count from the end.
367 if (index < 0) {
368 PyErr_SetString(PyExc_IndexError, "index out of range");
369 return {};
370 }
371
372 if constexpr (has_maybe_downcast<ElementTy>::value)
373 return static_cast<Derived *>(this)
374 ->getRawElement(linearizeIndex(index))
375 .maybeDownCast();
376 else
377 return nanobind::cast(
378 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
379 }
380
381 /// Returns a new instance of the pseudo-container restricted to the given
382 /// slice. Returns a nullptr object on failure.
383 nanobind::object getItemSlice(PyObject *slice) {
384 Py_ssize_t start, stop, extraStep, sliceLength;
385 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
386 &sliceLength) != 0) {
387 PyErr_SetString(PyExc_IndexError, "index out of range");
388 return {};
389 }
390 return nanobind::cast(static_cast<Derived *>(this)->slice(
391 startIndex + start * step, sliceLength, step * extraStep));
392 }
393
394public:
397 assert(length >= 0 && "expected non-negative slice length");
398 }
399
400 /// Returns the `index`-th element in the slice, supports negative indices.
401 /// Throws if the index is out of bounds.
403 // Negative indices mean we count from the end.
405 if (index < 0) {
406 throw nanobind::index_error("index out of range");
407 }
408
409 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
410 }
411
412 /// Returns the size of slice.
/// This is the slice's own `length`, not the element count of the backing
/// container.
413 intptr_t size() { return length; }
414
415 /// Returns a new vector (mapped to Python list) containing elements from two
416 /// slices. The new vector is necessary because slices may not be contiguous
417 /// or even come from the same original sequence.
418 std::vector<ElementTy> dunderAdd(Derived &other) {
419 std::vector<ElementTy> elements;
420 elements.reserve(length + other.length);
421 for (intptr_t i = 0; i < length; ++i) {
422 elements.push_back(static_cast<Derived *>(this)->getElement(i));
423 }
424 for (intptr_t i = 0; i < other.length; ++i) {
425 elements.push_back(static_cast<Derived *>(&other)->getElement(i));
426 }
427 return elements;
428 }
429
430 // Manually implement the sequence protocol via the C API. We do this
431 // because it is approx 4x faster than via nanobind, largely because that
432 // formulation requires a C++ exception to be thrown to detect end of
433 // sequence.
434 // Since we are in a C-context, any C++ exception that happens here
435 // will terminate the program. There is nothing in this implementation
436 // that should throw in a non-terminal way, so we forgo further
437 // exception marshalling.
438 // See: https://github.com/pybind/pybind11/issues/2842
439 //
440 /// Binds the indexing and length methods in the Python class.
/// The class is registered on module `m` under `Derived::pyClassName` with
/// length, item and subscript slots, receives an `__add__` method, and is
/// finally passed to `Derived::bindDerived` for further customization.
441 static void bind(nanobind::module_ &m) {
442 // These slots are passed via nanobind::type_slots() at class creation
443 // time, which is compatible with both the full and limited (stable ABI)
444 // Python APIs.
445 static PyType_Slot sequenceSlots[] = {
// Length slot: reports the slice's `length`.
446 {Py_sq_length, (void *)(+[](PyObject *rawSelf) -> Py_ssize_t {
447 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
448 return self->length;
449 })},
450 // sq_item is called as part of the sequence protocol for iteration,
451 // list construction, etc.
452 {Py_sq_item,
453 (void *)(+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
454 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
455 return self->getItem(index).release().ptr();
456 })},
457 // mp_subscript is used for both slices and integer lookups.
458 {Py_mp_subscript,
459 (void *)(+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
460 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
461 Py_ssize_t index =
462 PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
463 if (!PyErr_Occurred()) {
464 // Integer indexing.
465 return self->getItem(index).release().ptr();
466 }
// The subscript was not usable as an integer: drop that error and try the
// slice path before reporting a failure.
467 PyErr_Clear();
468
469 // Assume slice-based indexing.
470 if (PySlice_Check(rawSubscript)) {
471 return self->getItemSlice(rawSubscript).release().ptr();
472 }
473
474 PyErr_SetString(PyExc_ValueError, "expected integer or slice");
475 return nullptr;
476 })},
// Zero entry terminating the slot table.
477 {0, nullptr}};
// Compose the signature string handed to nanobind::sig():
//   class <pyClassName>(collections.abc.Sequence[<element type name>]
//                       [, typing.Generic[<typeParams...>]])
// The element type name is looked up from nanobind's registry for ElementTy.
478 const std::type_info &elemTy = typeid(ElementTy);
479 PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
480 assert(elemTyInfo &&
481 "expected nb_type_lookup to succeed for Sliceable elemTy");
482 nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
483 std::string sig = std::string("class ") + Derived::pyClassName +
484 "(collections.abc.Sequence[" +
485 nanobind::cast<std::string>(elemTyName) + "]";
486 if constexpr (!Derived::typeParams.empty()) {
487 sig += ", typing.Generic[";
488 for (size_t i = 0; i < Derived::typeParams.size(); ++i) {
489 if (i > 0)
490 sig += ", ";
491 const char *tp = Derived::typeParams[i];
492 sig += tp;
// Define a type variable named `tp` on the module unless the module already
// has an attribute with that name.
493 if (!nanobind::hasattr(m, tp))
494 m.attr(tp) = nanobind::type_var(tp);
495 }
496 sig += "]";
497 }
498 sig += ")";
// Classes with type parameters are additionally created with
// nanobind::is_generic(); otherwise the two construction paths are the same.
499 ClassTy clazz;
500 if constexpr (!Derived::typeParams.empty()) {
501 clazz =
502 ClassTy(m, Derived::pyClassName, nanobind::type_slots(sequenceSlots),
503 nanobind::is_generic(), nanobind::sig(sig.c_str()));
504 } else {
505 clazz =
506 ClassTy(m, Derived::pyClassName, nanobind::type_slots(sequenceSlots),
507 nanobind::sig(sig.c_str()));
508 }
509 clazz.def("__add__", &Sliceable::dunderAdd);
// Give the derived class a chance to bind its own methods last.
510 Derived::bindDerived(clazz);
511 }
512
513 /// Hook for derived classes willing to bind more methods.
/// This default does nothing; `bind` calls `Derived::bindDerived` after
/// `__add__` has been registered on the class.
514 static void bindDerived(ClassTy &) {}
515
519};
520
521} // namespace mlir
522
523#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
lhs
Accumulates into a file, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
MlirStringCallback getCallback()
nanobind::typed< nanobind::object, ElementTy > getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< PyOpResultList > ClassTy
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
static constexpr std::array< const char *, 0 > typeParams
intptr_t size()
Returns the size of slice.
ReferrentTy * operator->()
Defaulting(ReferrentTy &referrent)
ReferrentTy * get() const
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
std::unique_ptr< T >(* F)()
MLIR_CAPI_EXPORTED bool mlirLlvmRawFdOStreamIsNull(MlirLlvmRawFdOStream stream)
Checks if a raw_fd_ostream is null.
Definition Support.cpp:69
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamWrite(MlirLlvmRawFdOStream stream, MlirStringRef string)
Write a string to a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
Definition Support.cpp:64
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamDestroy(MlirLlvmRawFdOStream stream)
Destroy a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
Definition Support.cpp:73
struct MlirStringRef MlirStringRef
Definition Support.h:82
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition Support.h:110
MLIR_CAPI_EXPORTED MlirLlvmRawFdOStream mlirLlvmRawFdOStreamCreate(const char *path, bool binary, MlirStringCallback errorCallback, void *userData)
Create a raw_fd_ostream for the given path.
Definition Support.cpp:47
Include the generated interface declarations.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a python string from a method that is expected to make one (no more,...
RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope exit.
RAIIMlirLlvmRawFdOStream & operator=(const RAIIMlirLlvmRawFdOStream &)=delete
RAIIMlirLlvmRawFdOStream(MlirLlvmRawFdOStream stream)
RAIIMlirLlvmRawFdOStream(const RAIIMlirLlvmRawFdOStream &)=delete
Trait to check if T provides a maybeDownCast method.
bool operator()(MlirTypeID lhs, MlirTypeID rhs) const
size_t operator()(MlirTypeID typeID) const
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept