MLIR 22.0.0git
NanobindUtils.h
Go to the documentation of this file.
1//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12
13#include "mlir-c/Support.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/ADT/Twine.h"
18#include "llvm/Support/DataTypes.h"
19#include "llvm/Support/raw_ostream.h"
20
21#include <string>
22#include <typeinfo>
23#include <variant>
24
25template <>
26struct std::iterator_traits<nanobind::detail::fast_iterator> {
27 using value_type = nanobind::handle;
28 using reference = const value_type;
29 using pointer = void;
30 using difference_type = std::ptrdiff_t;
31 using iterator_category = std::forward_iterator_tag;
32};
33
34namespace mlir {
35namespace python {
36
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;

  /// Wraps an already-resolved referent. The wrapper stores only the address
  /// of `referrent`; it does not take ownership.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (null only for a default-constructed
  /// instance).
  ReferrentTy *get() const { return referrent; }

  /// Smart-pointer style access to the referent. Const-qualified to match
  /// `get()`, so that a const wrapper can also be dereferenced.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
67
68} // namespace python
69} // namespace mlir
70
71namespace nanobind {
72namespace detail {
73
74template <typename DefaultingTy>
76 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
77
78 bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
79 if (src.is_none()) {
80 // Note that we do want an exception to propagate from here as it will be
81 // the most informative.
82 value = DefaultingTy{DefaultingTy::resolve()};
83 return true;
84 }
85
86 // Unlike many casters that chain, these casters are expected to always
87 // succeed, so instead of doing an isinstance check followed by a cast,
88 // just cast in one step and handle the exception. Returning false (vs
89 // letting the exception propagate) causes higher level signature parsing
90 // code to produce nice error messages (other than "Cannot cast...").
91 try {
92 value = DefaultingTy{
93 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
94 return true;
95 } catch (std::exception &) {
96 return false;
97 }
98 }
99
100 static handle from_cpp(DefaultingTy src, rv_policy policy,
101 cleanup_list *cleanup) noexcept {
102 return nanobind::cast(src, policy);
103 }
104};
105} // namespace detail
106} // namespace nanobind
107
108//------------------------------------------------------------------------------
109// Conversion utilities.
110//------------------------------------------------------------------------------
111
112namespace mlir {
113
114/// Accumulates into a python string from a method that accepts an
115/// MlirStringCallback.
117 nanobind::list parts;
118
119 void *getUserData() { return this; }
120
122 return [](MlirStringRef part, void *userData) {
123 PyPrintAccumulator *printAccum =
124 static_cast<PyPrintAccumulator *>(userData);
125 nanobind::str pyPart(part.data,
126 part.length); // Decodes as UTF-8 by default.
127 printAccum->parts.append(std::move(pyPart));
128 };
129 }
130
131 nanobind::str join() {
132 nanobind::str delim("", 0);
133 return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
134 }
135};
136
137/// Accumulates into a file, either writing text (default)
138/// or binary. The file may be a Python file-like object or a path to a file.
140public:
141 PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
142 : binary(binary) {
143 std::string filePath;
144 if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
145 std::error_code ec;
146 writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
147 if (ec) {
148 throw nanobind::value_error(
149 (std::string("Unable to open file for writing: ") + ec.message())
150 .c_str());
151 }
152 } else {
153 writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
154 }
155 }
156
158 return writeTarget.index() == 0 ? getPyWriteCallback()
159 : getOstreamCallback();
160 }
161
162 void *getUserData() { return this; }
163
164private:
165 MlirStringCallback getPyWriteCallback() {
166 return [](MlirStringRef part, void *userData) {
167 nanobind::gil_scoped_acquire acquire;
168 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
169 if (accum->binary) {
170 // Note: Still has to copy and not avoidable with this API.
171 nanobind::bytes pyBytes(part.data, part.length);
172 std::get<nanobind::object>(accum->writeTarget)(pyBytes);
173 } else {
174 nanobind::str pyStr(part.data,
175 part.length); // Decodes as UTF-8 by default.
176 std::get<nanobind::object>(accum->writeTarget)(pyStr);
177 }
178 };
179 }
180
181 MlirStringCallback getOstreamCallback() {
182 return [](MlirStringRef part, void *userData) {
183 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
184 std::get<llvm::raw_fd_ostream>(accum->writeTarget)
185 .write(part.data, part.length);
186 };
187 }
188
189 std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget;
190 bool binary;
191};
192
193/// Accumulates into a python string from a method that is expected to make
194/// one (no more, no less) call to the callback (asserts internally on
195/// violation).
197 void *getUserData() { return this; }
198
200 return [](MlirStringRef part, void *userData) {
202 static_cast<PySinglePartStringAccumulator *>(userData);
203 assert(!accum->invoked &&
204 "PySinglePartStringAccumulator called back multiple times");
205 accum->invoked = true;
206 accum->value = nanobind::str(part.data, part.length);
207 };
208 }
209
210 nanobind::str takeValue() {
211 assert(invoked && "PySinglePartStringAccumulator not called back");
212 return std::move(value);
213 }
214
215private:
216 nanobind::str value;
217 bool invoked = false;
218};
219
220/// A CRTP base class for pseudo-containers willing to support Python-type
221/// slicing access on top of indexed access. Calling ::bind on this class
222/// will define `__len__` as well as `__getitem__` with integer and slice
223/// arguments.
224///
225/// This is intended for pseudo-containers that can refer to arbitrary slices of
226/// underlying storage indexed by a single integer. Indexing those with an
227/// integer produces an instance of ElementTy. Indexing those with a slice
228/// produces a new instance of Derived, which can be sliced further.
229///
230/// A derived class must provide the following:
231/// - a `static const char *pyClassName ` field containing the name of the
232/// Python class to bind;
233/// - an instance method `intptr_t getRawNumElements()` that returns the
234/// number
235/// of elements in the backing container (NOT that of the slice);
236/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
237/// single element at the given linear index (NOT slice index);
238/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
239/// constructs a new instance of the derived pseudo-container with the
240/// given slice parameters (to be forwarded to the Sliceable constructor).
241///
242/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
243/// throw.
244///
245/// A derived class may additionally define:
246/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
247/// the python class.
248template <typename Derived, typename ElementTy>
250protected:
251 using ClassTy = nanobind::class_<Derived>;
252
253 /// Transforms `index` into a legal value to access the underlying sequence.
254 /// Returns <0 on failure.
255 intptr_t wrapIndex(intptr_t index) {
256 if (index < 0)
257 index = length + index;
258 if (index < 0 || index >= length)
259 return -1;
260 return index;
261 }
262
263 /// Computes the linear index given the current slice properties.
264 intptr_t linearizeIndex(intptr_t index) {
265 intptr_t linearIndex = index * step + startIndex;
266 assert(linearIndex >= 0 &&
267 linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
268 "linear index out of bounds, the slice is ill-formed");
269 return linearIndex;
270 }
271
272 /// Trait to check if T provides a `maybeDownCast` method.
273 /// Note, you need the & to detect inherited members.
274 template <typename T, typename... Args>
275 using has_maybe_downcast = decltype(&T::maybeDownCast);
276
277 /// Returns the element at the given slice index. Supports negative indices
278 /// by taking elements in inverse order. Returns a nullptr object if out
279 /// of bounds.
280 nanobind::object getItem(intptr_t index) {
281 // Negative indices mean we count from the end.
283 if (index < 0) {
284 PyErr_SetString(PyExc_IndexError, "index out of range");
285 return {};
286 }
287
288 if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
289 return static_cast<Derived *>(this)
290 ->getRawElement(linearizeIndex(index))
291 .maybeDownCast();
292 else
293 return nanobind::cast(
294 static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
295 }
296
297 /// Returns a new instance of the pseudo-container restricted to the given
298 /// slice. Returns a nullptr object on failure.
299 nanobind::object getItemSlice(PyObject *slice) {
300 ssize_t start, stop, extraStep, sliceLength;
301 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
302 &sliceLength) != 0) {
303 PyErr_SetString(PyExc_IndexError, "index out of range");
304 return {};
305 }
306 return nanobind::cast(static_cast<Derived *>(this)->slice(
307 startIndex + start * step, sliceLength, step * extraStep));
308 }
309
310public:
311 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
313 assert(length >= 0 && "expected non-negative slice length");
314 }
315
316 /// Returns the `index`-th element in the slice, supports negative indices.
317 /// Throws if the index is out of bounds.
318 ElementTy getElement(intptr_t index) {
319 // Negative indices mean we count from the end.
321 if (index < 0) {
322 throw nanobind::index_error("index out of range");
323 }
324
325 return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
326 }
327
328 /// Returns the size of slice.
329 intptr_t size() { return length; }
330
331 /// Returns a new vector (mapped to Python list) containing elements from two
332 /// slices. The new vector is necessary because slices may not be contiguous
333 /// or even come from the same original sequence.
334 std::vector<ElementTy> dunderAdd(Derived &other) {
335 std::vector<ElementTy> elements;
336 elements.reserve(length + other.length);
337 for (intptr_t i = 0; i < length; ++i) {
338 elements.push_back(static_cast<Derived *>(this)->getElement(i));
339 }
340 for (intptr_t i = 0; i < other.length; ++i) {
341 elements.push_back(static_cast<Derived *>(&other)->getElement(i));
342 }
343 return elements;
344 }
345
346 /// Binds the indexing and length methods in the Python class.
347 static void bind(nanobind::module_ &m) {
348 const std::type_info &elemTy = typeid(ElementTy);
349 PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
350 assert(elemTyInfo &&
351 "expected nb_type_lookup to succeed for Sliceable elemTy");
352 nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
353 std::string sig = std::string("class ") + Derived::pyClassName +
354 "(collections.abc.Sequence[" +
355 nanobind::cast<std::string>(elemTyName) + "])";
356 auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
357 nanobind::sig(sig.c_str()))
358 .def("__add__", &Sliceable::dunderAdd);
359 Derived::bindDerived(clazz);
360
361 // Manually implement the sequence protocol via the C API. We do this
362 // because it is approx 4x faster than via nanobind, largely because that
363 // formulation requires a C++ exception to be thrown to detect end of
364 // sequence.
365 // Since we are in a C-context, any C++ exception that happens here
366 // will terminate the program. There is nothing in this implementation
367 // that should throw in a non-terminal way, so we forgo further
368 // exception marshalling.
369 // See: https://github.com/pybind/nanobind/issues/2842
370 auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
371 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
372 "must be heap type");
373 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
374 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
375 return self->length;
376 };
377 // sq_item is called as part of the sequence protocol for iteration,
378 // list construction, etc.
379 heap_type->as_sequence.sq_item =
380 +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
381 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
382 return self->getItem(index).release().ptr();
383 };
384 // mp_subscript is used for both slices and integer lookups.
385 heap_type->as_mapping.mp_subscript =
386 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
387 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
388 Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
389 if (!PyErr_Occurred()) {
390 // Integer indexing.
391 return self->getItem(index).release().ptr();
392 }
393 PyErr_Clear();
394
395 // Assume slice-based indexing.
396 if (PySlice_Check(rawSubscript)) {
397 return self->getItemSlice(rawSubscript).release().ptr();
398 }
399
400 PyErr_SetString(PyExc_ValueError, "expected integer or slice");
401 return nullptr;
402 };
403 }
404
405 /// Hook for derived classes willing to bind more methods.
406 static void bindDerived(ClassTy &) {}
407
408 intptr_t startIndex;
409 intptr_t length;
410 intptr_t step;
411};
412
413} // namespace mlir
414
415namespace llvm {
416
417template <>
418struct DenseMapInfo<MlirTypeID> {
419 static inline MlirTypeID getEmptyKey() {
421 return mlirTypeIDCreate(pointer);
422 }
423 static inline MlirTypeID getTombstoneKey() {
425 return mlirTypeIDCreate(pointer);
426 }
427 static inline unsigned getHashValue(const MlirTypeID &val) {
428 return mlirTypeIDHashValue(val);
429 }
430 static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
431 return mlirTypeIDEqual(lhs, rhs);
432 }
433};
434} // namespace llvm
435
436#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
lhs
Accumulates into a file, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
MlirStringCallback getCallback()
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< PyOpResultList > ClassTy
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
decltype(&T::maybeDownCast) has_maybe_downcast
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
nanobind::object getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t size()
Returns the size of slice.
ReferrentTy * operator->()
Defaulting(ReferrentTy &referrent)
ReferrentTy * get() const
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr)
ptr must be 8 byte aligned and unique to a type valid for the duration of the returned type id's usage.
Definition Support.cpp:38
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition Support.cpp:51
struct MlirStringRef MlirStringRef
Definition Support.h:77
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition Support.cpp:47
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition Support.h:105
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
Include the generated interface declarations.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:73
const char * data
Pointer to the first symbol.
Definition Support.h:74
size_t length
Length of the fragment.
Definition Support.h:75
static MlirTypeID getTombstoneKey()
static bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs)
static unsigned getHashValue(const MlirTypeID &val)
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a python string from a method that is expected to make one (no more, no less) call to the callback (asserts internally on violation).
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept