MLIR  20.0.0git
NanobindUtils.h
Go to the documentation of this file.
1 //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12 
13 #include "mlir-c/Support.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/DataTypes.h"
18 
19 template <>
20 struct std::iterator_traits<nanobind::detail::fast_iterator> {
21  using value_type = nanobind::handle;
22  using reference = const value_type;
23  using pointer = void;
24  using difference_type = std::ptrdiff_t;
25  using iterator_category = std::forward_iterator_tag;
26 };
27 
28 namespace mlir {
29 namespace python {
30 
31 /// CRTP template for special wrapper types that are allowed to be passed in as
32 /// 'None' function arguments and can be resolved by some global mechanic if
33 /// so. Such types will raise an error if this global resolution fails, and
34 /// it is actually illegal for them to ever be unresolved. From a user
35 /// perspective, they behave like a smart ptr to the underlying type (i.e.
36 /// 'get' method and operator-> overloaded).
37 ///
38 /// Derived types must provide a method, which is called when an environmental
39 /// resolution is required. It must raise an exception if resolution fails:
40 /// static ReferrentTy &resolve()
41 ///
42 /// They must also provide a parameter description that will be used in
43 /// error messages about mismatched types:
44 /// static constexpr const char kTypeDescription[] = "<Description>";
45 
template <typename DerivedTy, typename T>
class Defaulting {
public:
  /// The type this wrapper points at.
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;

  /// Wraps an existing referrent. Only its address is stored; the wrapper
  /// neither copies nor owns the object.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer; null for a default-constructed instance.
  ReferrentTy *get() const { return referrent; }

  /// Smart-pointer style member access. Const-qualified like `get()` so that
  /// a const wrapper can still reach its (non-const) referrent.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
61 
62 } // namespace python
63 } // namespace mlir
64 
65 namespace nanobind {
66 namespace detail {
67 
68 template <typename DefaultingTy>
70  NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
71 
72  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
73  if (src.is_none()) {
74  // Note that we do want an exception to propagate from here as it will be
75  // the most informative.
76  value = DefaultingTy{DefaultingTy::resolve()};
77  return true;
78  }
79 
80  // Unlike many casters that chain, these casters are expected to always
81  // succeed, so instead of doing an isinstance check followed by a cast,
82  // just cast in one step and handle the exception. Returning false (vs
83  // letting the exception propagate) causes higher level signature parsing
84  // code to produce nice error messages (other than "Cannot cast...").
85  try {
86  value = DefaultingTy{
87  nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
88  return true;
89  } catch (std::exception &) {
90  return false;
91  }
92  }
93 
94  static handle from_cpp(DefaultingTy src, rv_policy policy,
95  cleanup_list *cleanup) noexcept {
96  return nanobind::cast(src, policy);
97  }
98 };
99 } // namespace detail
100 } // namespace nanobind
101 
102 //------------------------------------------------------------------------------
103 // Conversion utilities.
104 //------------------------------------------------------------------------------
105 
106 namespace mlir {
107 
108 /// Accumulates into a python string from a method that accepts an
109 /// MlirStringCallback.
111  nanobind::list parts;
112 
113  void *getUserData() { return this; }
114 
116  return [](MlirStringRef part, void *userData) {
117  PyPrintAccumulator *printAccum =
118  static_cast<PyPrintAccumulator *>(userData);
119  nanobind::str pyPart(part.data,
120  part.length); // Decodes as UTF-8 by default.
121  printAccum->parts.append(std::move(pyPart));
122  };
123  }
124 
125  nanobind::str join() {
126  nanobind::str delim("", 0);
127  return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
128  }
129 };
130 
131 /// Accumulates into a python file-like object, either writing text (default)
132 /// or binary.
134 public:
135  PyFileAccumulator(const nanobind::object &fileObject, bool binary)
136  : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
137 
138  void *getUserData() { return this; }
139 
141  return [](MlirStringRef part, void *userData) {
142  nanobind::gil_scoped_acquire acquire;
143  PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
144  if (accum->binary) {
145  // Note: Still has to copy and not avoidable with this API.
146  nanobind::bytes pyBytes(part.data, part.length);
147  accum->pyWriteFunction(pyBytes);
148  } else {
149  nanobind::str pyStr(part.data,
150  part.length); // Decodes as UTF-8 by default.
151  accum->pyWriteFunction(pyStr);
152  }
153  };
154  }
155 
156 private:
157  nanobind::object pyWriteFunction;
158  bool binary;
159 };
160 
161 /// Accumulates into a python string from a method that is expected to make
162 /// one (no more, no less) call to the callback (asserts internally on
163 /// violation).
165  void *getUserData() { return this; }
166 
168  return [](MlirStringRef part, void *userData) {
170  static_cast<PySinglePartStringAccumulator *>(userData);
171  assert(!accum->invoked &&
172  "PySinglePartStringAccumulator called back multiple times");
173  accum->invoked = true;
174  accum->value = nanobind::str(part.data, part.length);
175  };
176  }
177 
178  nanobind::str takeValue() {
179  assert(invoked && "PySinglePartStringAccumulator not called back");
180  return std::move(value);
181  }
182 
183 private:
184  nanobind::str value;
185  bool invoked = false;
186 };
187 
188 /// A CRTP base class for pseudo-containers willing to support Python-type
189 /// slicing access on top of indexed access. Calling ::bind on this class
190 /// will define `__len__` as well as `__getitem__` with integer and slice
191 /// arguments.
192 ///
193 /// This is intended for pseudo-containers that can refer to arbitrary slices of
194 /// underlying storage indexed by a single integer. Indexing those with an
195 /// integer produces an instance of ElementTy. Indexing those with a slice
196 /// produces a new instance of Derived, which can be sliced further.
197 ///
198 /// A derived class must provide the following:
199 /// - a `static const char *pyClassName` field containing the name of the
200 /// Python class to bind;
201 /// - an instance method `intptr_t getRawNumElements()` that returns the
202 /// number
203 /// of elements in the backing container (NOT that of the slice);
204 /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
205 /// single element at the given linear index (NOT slice index);
206 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
207 /// constructs a new instance of the derived pseudo-container with the
208 /// given slice parameters (to be forwarded to the Sliceable constructor).
209 ///
210 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
211 /// throw.
212 ///
213 /// A derived class may additionally define:
214 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods
215 /// to the python class.
216 template <typename Derived, typename ElementTy>
217 class Sliceable {
218 protected:
219  using ClassTy = nanobind::class_<Derived>;
220 
221  /// Transforms `index` into a legal value to access the underlying sequence.
222  /// Returns <0 on failure.
223  intptr_t wrapIndex(intptr_t index) {
224  if (index < 0)
225  index = length + index;
226  if (index < 0 || index >= length)
227  return -1;
228  return index;
229  }
230 
231  /// Computes the linear index given the current slice properties.
232  intptr_t linearizeIndex(intptr_t index) {
233  intptr_t linearIndex = index * step + startIndex;
234  assert(linearIndex >= 0 &&
235  linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
236  "linear index out of bounds, the slice is ill-formed");
237  return linearIndex;
238  }
239 
240  /// Trait to check if T provides a `maybeDownCast` method.
241  /// Note, you need the & to detect inherited members.
242  template <typename T, typename... Args>
243  using has_maybe_downcast = decltype(&T::maybeDownCast);
244 
245  /// Returns the element at the given slice index. Supports negative indices
246  /// by taking elements in inverse order. Returns a nullptr object if out
247  /// of bounds.
248  nanobind::object getItem(intptr_t index) {
249  // Negative indices mean we count from the end.
250  index = wrapIndex(index);
251  if (index < 0) {
252  PyErr_SetString(PyExc_IndexError, "index out of range");
253  return {};
254  }
255 
256  if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
257  return static_cast<Derived *>(this)
258  ->getRawElement(linearizeIndex(index))
259  .maybeDownCast();
260  else
261  return nanobind::cast(
262  static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
263  }
264 
265  /// Returns a new instance of the pseudo-container restricted to the given
266  /// slice. Returns a nullptr object on failure.
267  nanobind::object getItemSlice(PyObject *slice) {
268  ssize_t start, stop, extraStep, sliceLength;
269  if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
270  &sliceLength) != 0) {
271  PyErr_SetString(PyExc_IndexError, "index out of range");
272  return {};
273  }
274  return nanobind::cast(static_cast<Derived *>(this)->slice(
275  startIndex + start * step, sliceLength, step * extraStep));
276  }
277 
278 public:
279  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
280  : startIndex(startIndex), length(length), step(step) {
281  assert(length >= 0 && "expected non-negative slice length");
282  }
283 
284  /// Returns the `index`-th element in the slice, supports negative indices.
285  /// Throws if the index is out of bounds.
286  ElementTy getElement(intptr_t index) {
287  // Negative indices mean we count from the end.
288  index = wrapIndex(index);
289  if (index < 0) {
290  throw nanobind::index_error("index out of range");
291  }
292 
293  return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
294  }
295 
296  /// Returns the size of slice.
297  intptr_t size() { return length; }
298 
299  /// Returns a new vector (mapped to Python list) containing elements from two
300  /// slices. The new vector is necessary because slices may not be contiguous
301  /// or even come from the same original sequence.
302  std::vector<ElementTy> dunderAdd(Derived &other) {
303  std::vector<ElementTy> elements;
304  elements.reserve(length + other.length);
305  for (intptr_t i = 0; i < length; ++i) {
306  elements.push_back(static_cast<Derived *>(this)->getElement(i));
307  }
308  for (intptr_t i = 0; i < other.length; ++i) {
309  elements.push_back(static_cast<Derived *>(&other)->getElement(i));
310  }
311  return elements;
312  }
313 
314  /// Binds the indexing and length methods in the Python class.
315  static void bind(nanobind::module_ &m) {
316  auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
317  .def("__add__", &Sliceable::dunderAdd);
318  Derived::bindDerived(clazz);
319 
320  // Manually implement the sequence protocol via the C API. We do this
321  // because it is approx 4x faster than via nanobind, largely because that
322  // formulation requires a C++ exception to be thrown to detect end of
323  // sequence.
324  // Since we are in a C-context, any C++ exception that happens here
325  // will terminate the program. There is nothing in this implementation
326  // that should throw in a non-terminal way, so we forgo further
327  // exception marshalling.
328  // See: https://github.com/pybind/nanobind/issues/2842
329  auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
330  assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
331  "must be heap type");
332  heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
333  auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
334  return self->length;
335  };
336  // sq_item is called as part of the sequence protocol for iteration,
337  // list construction, etc.
338  heap_type->as_sequence.sq_item =
339  +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
340  auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
341  return self->getItem(index).release().ptr();
342  };
343  // mp_subscript is used for both slices and integer lookups.
344  heap_type->as_mapping.mp_subscript =
345  +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
346  auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
347  Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
348  if (!PyErr_Occurred()) {
349  // Integer indexing.
350  return self->getItem(index).release().ptr();
351  }
352  PyErr_Clear();
353 
354  // Assume slice-based indexing.
355  if (PySlice_Check(rawSubscript)) {
356  return self->getItemSlice(rawSubscript).release().ptr();
357  }
358 
359  PyErr_SetString(PyExc_ValueError, "expected integer or slice");
360  return nullptr;
361  };
362  }
363 
364  /// Hook for derived classes willing to bind more methods.
365  static void bindDerived(ClassTy &) {}
366 
367 private:
368  intptr_t startIndex;
369  intptr_t length;
370  intptr_t step;
371 };
372 
373 } // namespace mlir
374 
375 namespace llvm {
376 
377 template <>
378 struct DenseMapInfo<MlirTypeID> {
379  static inline MlirTypeID getEmptyKey() {
380  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
381  return mlirTypeIDCreate(pointer);
382  }
383  static inline MlirTypeID getTombstoneKey() {
385  return mlirTypeIDCreate(pointer);
386  }
387  static inline unsigned getHashValue(const MlirTypeID &val) {
388  return mlirTypeIDHashValue(val);
389  }
390  static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
391  return mlirTypeIDEqual(lhs, rhs);
392  }
393 };
394 } // namespace llvm
395 
396 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
Accumulates into a python file-like object, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileObject, bool binary)
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< Derived > ClassTy
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
decltype(&T::maybeDownCast) has_maybe_downcast
Trait to check if T provides a maybeDownCast method.
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
nanobind::object getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t size()
Returns the size of slice.
CRTP template for special wrapper types that are allowed to be passed in as 'None' function arguments...
Definition: NanobindUtils.h:47
ReferrentTy * get() const
Definition: NanobindUtils.h:55
Defaulting(ReferrentTy &referrent)
Definition: NanobindUtils.h:53
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
ReferrentTy * operator->()
Definition: NanobindUtils.h:56
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr)
ptr must be 8 byte aligned and unique to a type valid for the duration of the returned type id's usag...
Definition: Support.cpp:38
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static MlirTypeID getTombstoneKey()
static MlirTypeID getEmptyKey()
static bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs)
static unsigned getHashValue(const MlirTypeID &val)
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a python string from a method that is expected to make one (no more,...
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
Definition: NanobindUtils.h:72
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept
Definition: NanobindUtils.h:94