MLIR 23.0.0git
NanobindUtils.h
Go to the documentation of this file.
1//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12
13#include "mlir-c/Support.h"
15
16#include <fstream>
17#include <sstream>
18#include <string>
19#include <string_view>
20#include <type_traits>
21#include <typeinfo>
22#include <variant>
23
24template <>
25struct std::iterator_traits<nanobind::detail::fast_iterator> {
26 using value_type = nanobind::handle;
27 using reference = const value_type;
28 using pointer = void;
29 using difference_type = std::ptrdiff_t;
30 using iterator_category = std::forward_iterator_tag;
31};
32
33namespace mlir {
34namespace python {
35
37 size_t operator()(MlirTypeID typeID) const {
38 return mlirTypeIDHashValue(typeID);
39 }
40};
41
43 bool operator()(MlirTypeID lhs, MlirTypeID rhs) const {
44 return mlirTypeIDEqual(lhs, rhs);
45 }
46};
47
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps `referrent` without taking ownership; only its address is stored.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (nullptr for a default-constructed wrapper).
  ReferrentTy *get() const { return referrent; }
  /// Smart-pointer style access. Const-qualified like get() so that it is
  /// also usable on const wrappers; the pointee itself stays mutable.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
78
79} // namespace python
80} // namespace mlir
81
82namespace nanobind {
83namespace detail {
84
/// Streams every argument, in the order given, into a single buffer and
/// returns the accumulated text as a `std::string`. Calling it with no
/// arguments yields an empty string.
template <typename... Ts>
inline std::string join(const Ts &...args) {
  std::ostringstream buffer;
  // Comma fold: inserts each argument left to right; an empty pack inserts
  // nothing.
  ((buffer << args), ...);
  return buffer.str();
}
92
93template <typename DefaultingTy>
95 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
96
97 bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
98 if (src.is_none()) {
99 // Note that we do want an exception to propagate from here as it will be
100 // the most informative.
101 value = DefaultingTy{DefaultingTy::resolve()};
102 return true;
103 }
104
105 // Unlike many casters that chain, these casters are expected to always
106 // succeed, so instead of doing an isinstance check followed by a cast,
107 // just cast in one step and handle the exception. Returning false (vs
108 // letting the exception propagate) causes higher level signature parsing
109 // code to produce nice error messages (other than "Cannot cast...").
110 try {
111 value = DefaultingTy{
112 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
113 return true;
114 } catch (std::exception &) {
115 return false;
116 }
117 }
118
119 static handle from_cpp(DefaultingTy src, rv_policy policy,
120 cleanup_list *cleanup) noexcept {
121 return nanobind::cast(src, policy);
122 }
123};
124} // namespace detail
125} // namespace nanobind
126
127//------------------------------------------------------------------------------
128// Conversion utilities.
129//------------------------------------------------------------------------------
130
131namespace mlir {
132
133/// Accumulates into a python string from a method that accepts an
134/// MlirStringCallback.
136 nanobind::list parts;
137
138 void *getUserData() { return this; }
139
141 return [](MlirStringRef part, void *userData) {
142 PyPrintAccumulator *printAccum =
143 static_cast<PyPrintAccumulator *>(userData);
144 nanobind::str pyPart(part.data,
145 part.length); // Decodes as UTF-8 by default.
146 printAccum->parts.append(std::move(pyPart));
147 };
148 }
149
150 nanobind::str join() {
151 nanobind::str delim("", 0);
152 return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
153 }
154};
155
156/// RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope
157/// exit.
166
167/// Accumulates into a file, either writing text (default)
168/// or binary. The file may be a Python file-like object or a path to a file.
170public:
171 PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
172 : binary(binary) {
173 std::string filePath;
174 if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
175 std::string errorMessage;
176 auto errorCallback = +[](MlirStringRef message, void *userData) {
177 auto *storage = static_cast<std::string *>(userData);
178 storage->assign(message.data, message.length);
179 };
181 filePath.c_str(), binary, errorCallback, &errorMessage);
182 if (mlirLlvmRawFdOStreamIsNull(stream)) {
183 throw nanobind::value_error(
184 (std::string("Unable to open file for writing: ") + errorMessage)
185 .c_str());
186 }
187 writeTarget.emplace<RAIIMlirLlvmRawFdOStream>(stream);
188 } else {
189 writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
190 }
191 }
192
194 return writeTarget.index() == 0 ? getPyWriteCallback()
195 : getOStreamCallback();
196 }
197
198 void *getUserData() { return this; }
199
200private:
201 MlirStringCallback getPyWriteCallback() {
202 return [](MlirStringRef part, void *userData) {
203 nanobind::gil_scoped_acquire acquire;
204 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
205 if (accum->binary) {
206 // Note: Still has to copy and not avoidable with this API.
207 nanobind::bytes pyBytes(part.data, part.length);
208 std::get<nanobind::object>(accum->writeTarget)(pyBytes);
209 } else {
210 nanobind::str pyStr(part.data,
211 part.length); // Decodes as UTF-8 by default.
212 std::get<nanobind::object>(accum->writeTarget)(pyStr);
213 }
214 };
215 }
216
217 MlirStringCallback getOStreamCallback() {
218 return [](MlirStringRef part, void *userData) {
219 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
221 std::get<RAIIMlirLlvmRawFdOStream>(accum->writeTarget), part);
222 };
223 }
224
225 std::variant<nanobind::object, RAIIMlirLlvmRawFdOStream> writeTarget;
226 bool binary;
227};
228
229/// Accumulates into a python string from a method that is expected to make
230/// one (no more, no less) call to the callback (asserts internally on
231/// violation).
233 void *getUserData() { return this; }
234
236 return [](MlirStringRef part, void *userData) {
238 static_cast<PySinglePartStringAccumulator *>(userData);
239 assert(!accum->invoked &&
240 "PySinglePartStringAccumulator called back multiple times");
241 accum->invoked = true;
242 accum->value = nanobind::str(part.data, part.length);
243 };
244 }
245
246 nanobind::str takeValue() {
247 assert(invoked && "PySinglePartStringAccumulator not called back");
248 return std::move(value);
249 }
250
251private:
252 nanobind::str value;
253 bool invoked = false;
254};
255
256/// A CRTP base class for pseudo-containers willing to support Python-type
257/// slicing access on top of indexed access. Calling ::bind on this class
258/// will define `__len__` as well as `__getitem__` with integer and slice
259/// arguments.
260///
261/// This is intended for pseudo-containers that can refer to arbitrary slices of
262/// underlying storage indexed by a single integer. Indexing those with an
263/// integer produces an instance of ElementTy. Indexing those with a slice
264/// produces a new instance of Derived, which can be sliced further.
265///
266/// A derived class must provide the following:
267/// - a `static const char *pyClassName ` field containing the name of the
268/// Python class to bind;
269/// - an instance method `intptr_t getRawNumElements()` that returns the
270/// number
271/// of elements in the backing container (NOT that of the slice);
272/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
273/// single element at the given linear index (NOT slice index);
274/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
275/// constructs a new instance of the derived pseudo-container with the
276/// given slice parameters (to be forwarded to the Sliceable constructor).
277///
278/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
279/// throw.
280///
281/// A derived class may additionally define:
282/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
283/// the python class.
284template <typename Derived, typename ElementTy>
286protected:
287 using ClassTy = nanobind::class_<Derived>;
288
289 /// Transforms `index` into a legal value to access the underlying sequence.
290 /// Returns <0 on failure.
292 if (index < 0)
293 index = length + index;
294 if (index < 0 || index >= length)
295 return -1;
296 return index;
297 }
298
299 /// Computes the linear index given the current slice properties.
301 intptr_t linearIndex = index * step + startIndex;
302 assert(linearIndex >= 0 &&
303 linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
304 "linear index out of bounds, the slice is ill-formed");
305 return linearIndex;
306 }
307
308 /// Trait to check if T provides a `maybeDownCast` method.
309 /// Note, you need the & to detect inherited members.
310 template <typename T, typename = void>
311 struct has_maybe_downcast : std::false_type {};
312
313 template <typename T>
314 struct has_maybe_downcast<T, std::void_t<decltype(&T::maybeDownCast)>>
315 : std::true_type {};
316
317 /// Returns the element at the given slice index. Supports negative indices
318 /// by taking elements in inverse order. Returns a nullptr object if out
319 /// of bounds.
320 nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
321 // Negative indices mean we count from the end.
323 if (index < 0) {
324 PyErr_SetString(PyExc_IndexError, "index out of range");
325 return {};
326 }
327
328 if constexpr (has_maybe_downcast<ElementTy>::value)
329 return static_cast<Derived *>(this)
330 ->getRawElement(linearizeIndex(index))
331 .maybeDownCast();
332 else
333 return nanobind::cast(
334 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
335 }
336
337 /// Returns a new instance of the pseudo-container restricted to the given
338 /// slice. Returns a nullptr object on failure.
339 nanobind::object getItemSlice(PyObject *slice) {
340 Py_ssize_t start, stop, extraStep, sliceLength;
341 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
342 &sliceLength) != 0) {
343 PyErr_SetString(PyExc_IndexError, "index out of range");
344 return {};
345 }
346 return nanobind::cast(static_cast<Derived *>(this)->slice(
347 startIndex + start * step, sliceLength, step * extraStep));
348 }
349
350public:
353 assert(length >= 0 && "expected non-negative slice length");
354 }
355
356 /// Returns the `index`-th element in the slice, supports negative indices.
357 /// Throws if the index is out of bounds.
359 // Negative indices mean we count from the end.
361 if (index < 0) {
362 throw nanobind::index_error("index out of range");
363 }
364
365 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
366 }
367
368 /// Returns the size of slice.
369 intptr_t size() { return length; }
370
371 /// Returns a new vector (mapped to Python list) containing elements from two
372 /// slices. The new vector is necessary because slices may not be contiguous
373 /// or even come from the same original sequence.
374 std::vector<ElementTy> dunderAdd(Derived &other) {
375 std::vector<ElementTy> elements;
376 elements.reserve(length + other.length);
377 for (intptr_t i = 0; i < length; ++i) {
378 elements.push_back(static_cast<Derived *>(this)->getElement(i));
379 }
380 for (intptr_t i = 0; i < other.length; ++i) {
381 elements.push_back(static_cast<Derived *>(&other)->getElement(i));
382 }
383 return elements;
384 }
385
386 /// Binds the indexing and length methods in the Python class.
387 static void bind(nanobind::module_ &m) {
388 const std::type_info &elemTy = typeid(ElementTy);
389 PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
390 assert(elemTyInfo &&
391 "expected nb_type_lookup to succeed for Sliceable elemTy");
392 nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
393 std::string sig = std::string("class ") + Derived::pyClassName +
394 "(collections.abc.Sequence[" +
395 nanobind::cast<std::string>(elemTyName) + "])";
396 auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
397 nanobind::sig(sig.c_str()))
398 .def("__add__", &Sliceable::dunderAdd);
399 Derived::bindDerived(clazz);
400
401 // Manually implement the sequence protocol via the C API. We do this
402 // because it is approx 4x faster than via nanobind, largely because that
403 // formulation requires a C++ exception to be thrown to detect end of
404 // sequence.
405 // Since we are in a C-context, any C++ exception that happens here
406 // will terminate the program. There is nothing in this implementation
407 // that should throw in a non-terminal way, so we forgo further
408 // exception marshalling.
409 // See: https://github.com/pybind/nanobind/issues/2842
410 auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
411 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
412 "must be heap type");
413 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
414 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
415 return self->length;
416 };
417 // sq_item is called as part of the sequence protocol for iteration,
418 // list construction, etc.
419 heap_type->as_sequence.sq_item =
420 +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
421 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
422 return self->getItem(index).release().ptr();
423 };
424 // mp_subscript is used for both slices and integer lookups.
425 heap_type->as_mapping.mp_subscript =
426 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
427 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
428 Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
429 if (!PyErr_Occurred()) {
430 // Integer indexing.
431 return self->getItem(index).release().ptr();
432 }
433 PyErr_Clear();
434
435 // Assume slice-based indexing.
436 if (PySlice_Check(rawSubscript)) {
437 return self->getItemSlice(rawSubscript).release().ptr();
438 }
439
440 PyErr_SetString(PyExc_ValueError, "expected integer or slice");
441 return nullptr;
442 };
443 }
444
445 /// Hook for derived classes willing to bind more methods.
446 static void bindDerived(ClassTy &) {}
447
451};
452
453} // namespace mlir
454
455#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
lhs
Accumulates into a file, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
MlirStringCallback getCallback()
nanobind::typed< nanobind::object, ElementTy > getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< PyOpResultList > ClassTy
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
intptr_t size()
Returns the size of slice.
ReferrentTy * operator->()
Defaulting(ReferrentTy &referrent)
ReferrentTy * get() const
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
MLIR_CAPI_EXPORTED bool mlirLlvmRawFdOStreamIsNull(MlirLlvmRawFdOStream stream)
Checks if a raw_fd_ostream is null.
Definition Support.cpp:69
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamWrite(MlirLlvmRawFdOStream stream, MlirStringRef string)
Write a string to a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
Definition Support.cpp:64
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:93
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamDestroy(MlirLlvmRawFdOStream stream)
Destroy a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
Definition Support.cpp:73
struct MlirStringRef MlirStringRef
Definition Support.h:82
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:89
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition Support.h:110
MLIR_CAPI_EXPORTED MlirLlvmRawFdOStream mlirLlvmRawFdOStreamCreate(const char *path, bool binary, MlirStringCallback errorCallback, void *userData)
Create a raw_fd_ostream for the given path.
Definition Support.cpp:47
Include the generated interface declarations.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a python string from a method that is expected to make one (no more,...
RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope exit.
RAIIMlirLlvmRawFdOStream & operator=(const RAIIMlirLlvmRawFdOStream &)=delete
RAIIMlirLlvmRawFdOStream(MlirLlvmRawFdOStream stream)
RAIIMlirLlvmRawFdOStream(const RAIIMlirLlvmRawFdOStream &)=delete
Trait to check if T provides a maybeDownCast method.
bool operator()(MlirTypeID lhs, MlirTypeID rhs) const
size_t operator()(MlirTypeID typeID) const
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept