MLIR  21.0.0git
NanobindUtils.h
Go to the documentation of this file.
1 //===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
12 
13 #include "mlir-c/Support.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/DataTypes.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 #include <string>
22 #include <variant>
23 
24 template <>
25 struct std::iterator_traits<nanobind::detail::fast_iterator> {
26  using value_type = nanobind::handle;
27  using reference = const value_type;
28  using pointer = void;
29  using difference_type = std::ptrdiff_t;
30  using iterator_category = std::forward_iterator_tag;
31 };
32 
namespace mlir {
namespace python {

/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart pointer to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;

  /// Wraps `referrent` by address without taking ownership, so the referent
  /// must outlive this wrapper. Not `explicit`: a `ReferrentTy &` converts
  /// implicitly to the wrapper.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (nullptr only for a default-constructed,
  /// i.e. illegal, instance).
  ReferrentTy *get() const { return referrent; }

  /// Member access on the referent. Const-qualified like get(): the wrapper's
  /// constness does not propagate to the referent, so a `const` wrapper (for
  /// example a `const DerivedTy &` parameter) can still be dereferenced.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};

} // namespace python
} // namespace mlir
69 
70 namespace nanobind {
71 namespace detail {
72 
// Generic nanobind type caster for mlir::python::Defaulting-derived wrappers.
// The Python-visible type name comes from DefaultingTy::kTypeDescription.
// NOTE(review): the listing jumps from line 73 to line 75 here, so the line
// that opens this caster struct (and gives it its name) is missing from this
// extract; restore it from the upstream header rather than guessing.
73 template <typename DefaultingTy>
// Presumably declares the `value` member assigned in from_python below.
75  NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
76 
// Python -> C++. `None` is resolved through DefaultingTy::resolve() (whose
// exception is allowed to propagate); any other object is cast to
// `DefaultingTy::ReferrentTy &` and wrapped. Returns false (no match) when
// that cast throws. `flags` and `cleanup` are not used.
77  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
78  if (src.is_none()) {
79  // Note that we do want an exception to propagate from here as it will be
80  // the most informative.
81  value = DefaultingTy{DefaultingTy::resolve()};
82  return true;
83  }
84 
85  // Unlike many casters that chain, these casters are expected to always
86  // succeed, so instead of doing an isinstance check followed by a cast,
87  // just cast in one step and handle the exception. Returning false (vs
88  // letting the exception propagate) causes higher level signature parsing
89  // code to produce nice error messages (other than "Cannot cast...").
90  try {
91  value = DefaultingTy{
92  nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
93  return true;
94  } catch (std::exception &) {
95  return false;
96  }
97  }
98 
// C++ -> Python: delegates to nanobind::cast on the wrapper value itself.
// NOTE(review): if this caster is the one nanobind selects for DefaultingTy,
// this call dispatches back into from_cpp; confirm it is never reached, or
// cast the referent (`*src.get()`) instead. Because the function is
// `noexcept`, any exception thrown by nanobind::cast terminates the process.
99  static handle from_cpp(DefaultingTy src, rv_policy policy,
100  cleanup_list *cleanup) noexcept {
101  return nanobind::cast(src, policy);
102  }
103 };
104 } // namespace detail
105 } // namespace nanobind
106 
107 //------------------------------------------------------------------------------
108 // Conversion utilities.
109 //------------------------------------------------------------------------------
110 
111 namespace mlir {
112 
113 /// Accumulates into a python string from a method that accepts an
114 /// MlirStringCallback.
116  nanobind::list parts;
117 
118  void *getUserData() { return this; }
119 
121  return [](MlirStringRef part, void *userData) {
122  PyPrintAccumulator *printAccum =
123  static_cast<PyPrintAccumulator *>(userData);
124  nanobind::str pyPart(part.data,
125  part.length); // Decodes as UTF-8 by default.
126  printAccum->parts.append(std::move(pyPart));
127  };
128  }
129 
130  nanobind::str join() {
131  nanobind::str delim("", 0);
132  return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
133  }
134 };
135 
136 /// Accumulates into a file, either writing text (default)
137 /// or binary. The file may be a Python file-like object or a path to a file.
139 public:
140  PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
141  : binary(binary) {
142  std::string filePath;
143  if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
144  std::error_code ec;
145  writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
146  if (ec) {
147  throw nanobind::value_error(
148  (std::string("Unable to open file for writing: ") + ec.message())
149  .c_str());
150  }
151  } else {
152  writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
153  }
154  }
155 
157  return writeTarget.index() == 0 ? getPyWriteCallback()
158  : getOstreamCallback();
159  }
160 
161  void *getUserData() { return this; }
162 
163 private:
164  MlirStringCallback getPyWriteCallback() {
165  return [](MlirStringRef part, void *userData) {
166  nanobind::gil_scoped_acquire acquire;
167  PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
168  if (accum->binary) {
169  // Note: Still has to copy and not avoidable with this API.
170  nanobind::bytes pyBytes(part.data, part.length);
171  std::get<nanobind::object>(accum->writeTarget)(pyBytes);
172  } else {
173  nanobind::str pyStr(part.data,
174  part.length); // Decodes as UTF-8 by default.
175  std::get<nanobind::object>(accum->writeTarget)(pyStr);
176  }
177  };
178  }
179 
180  MlirStringCallback getOstreamCallback() {
181  return [](MlirStringRef part, void *userData) {
182  PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
183  std::get<llvm::raw_fd_ostream>(accum->writeTarget)
184  .write(part.data, part.length);
185  };
186  }
187 
188  std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget;
189  bool binary;
190 };
191 
192 /// Accumulates into a python string from a method that is expected to make
193 /// one (no more, no less) call to the callback (asserts internally on
194 /// violation).
196  void *getUserData() { return this; }
197 
199  return [](MlirStringRef part, void *userData) {
201  static_cast<PySinglePartStringAccumulator *>(userData);
202  assert(!accum->invoked &&
203  "PySinglePartStringAccumulator called back multiple times");
204  accum->invoked = true;
205  accum->value = nanobind::str(part.data, part.length);
206  };
207  }
208 
209  nanobind::str takeValue() {
210  assert(invoked && "PySinglePartStringAccumulator not called back");
211  return std::move(value);
212  }
213 
214 private:
215  nanobind::str value;
216  bool invoked = false;
217 };
218 
219 /// A CRTP base class for pseudo-containers willing to support Python-type
220 /// slicing access on top of indexed access. Calling ::bind on this class
221 /// will define `__len__` as well as `__getitem__` with integer and slice
222 /// arguments.
223 ///
224 /// This is intended for pseudo-containers that can refer to arbitrary slices of
225 /// underlying storage indexed by a single integer. Indexing those with an
226 /// integer produces an instance of ElementTy. Indexing those with a slice
227 /// produces a new instance of Derived, which can be sliced further.
228 ///
229 /// A derived class must provide the following:
230 /// - a `static const char *pyClassName ` field containing the name of the
231 /// Python class to bind;
232 /// - an instance method `intptr_t getRawNumElements()` that returns the
233 /// number
234 /// of elements in the backing container (NOT that of the slice);
235 /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
236 /// single element at the given linear index (NOT slice index);
237 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
238 /// constructs a new instance of the derived pseudo-container with the
239 /// given slice parameters (to be forwarded to the Sliceable constructor).
240 ///
241 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
242 /// throw.
243 ///
244 /// A derived class may additionally define:
245 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods
246 /// the python class.
247 template <typename Derived, typename ElementTy>
248 class Sliceable {
249 protected:
250  using ClassTy = nanobind::class_<Derived>;
251 
252  /// Transforms `index` into a legal value to access the underlying sequence.
253  /// Returns <0 on failure.
254  intptr_t wrapIndex(intptr_t index) {
255  if (index < 0)
256  index = length + index;
257  if (index < 0 || index >= length)
258  return -1;
259  return index;
260  }
261 
262  /// Computes the linear index given the current slice properties.
263  intptr_t linearizeIndex(intptr_t index) {
264  intptr_t linearIndex = index * step + startIndex;
265  assert(linearIndex >= 0 &&
266  linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
267  "linear index out of bounds, the slice is ill-formed");
268  return linearIndex;
269  }
270 
271  /// Trait to check if T provides a `maybeDownCast` method.
272  /// Note, you need the & to detect inherited members.
273  template <typename T, typename... Args>
274  using has_maybe_downcast = decltype(&T::maybeDownCast);
275 
276  /// Returns the element at the given slice index. Supports negative indices
277  /// by taking elements in inverse order. Returns a nullptr object if out
278  /// of bounds.
279  nanobind::object getItem(intptr_t index) {
280  // Negative indices mean we count from the end.
281  index = wrapIndex(index);
282  if (index < 0) {
283  PyErr_SetString(PyExc_IndexError, "index out of range");
284  return {};
285  }
286 
287  if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
288  return static_cast<Derived *>(this)
289  ->getRawElement(linearizeIndex(index))
290  .maybeDownCast();
291  else
292  return nanobind::cast(
293  static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
294  }
295 
296  /// Returns a new instance of the pseudo-container restricted to the given
297  /// slice. Returns a nullptr object on failure.
298  nanobind::object getItemSlice(PyObject *slice) {
299  ssize_t start, stop, extraStep, sliceLength;
300  if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
301  &sliceLength) != 0) {
302  PyErr_SetString(PyExc_IndexError, "index out of range");
303  return {};
304  }
305  return nanobind::cast(static_cast<Derived *>(this)->slice(
306  startIndex + start * step, sliceLength, step * extraStep));
307  }
308 
309 public:
310  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
311  : startIndex(startIndex), length(length), step(step) {
312  assert(length >= 0 && "expected non-negative slice length");
313  }
314 
315  /// Returns the `index`-th element in the slice, supports negative indices.
316  /// Throws if the index is out of bounds.
317  ElementTy getElement(intptr_t index) {
318  // Negative indices mean we count from the end.
319  index = wrapIndex(index);
320  if (index < 0) {
321  throw nanobind::index_error("index out of range");
322  }
323 
324  return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
325  }
326 
327  /// Returns the size of slice.
328  intptr_t size() { return length; }
329 
330  /// Returns a new vector (mapped to Python list) containing elements from two
331  /// slices. The new vector is necessary because slices may not be contiguous
332  /// or even come from the same original sequence.
333  std::vector<ElementTy> dunderAdd(Derived &other) {
334  std::vector<ElementTy> elements;
335  elements.reserve(length + other.length);
336  for (intptr_t i = 0; i < length; ++i) {
337  elements.push_back(static_cast<Derived *>(this)->getElement(i));
338  }
339  for (intptr_t i = 0; i < other.length; ++i) {
340  elements.push_back(static_cast<Derived *>(&other)->getElement(i));
341  }
342  return elements;
343  }
344 
345  /// Binds the indexing and length methods in the Python class.
346  static void bind(nanobind::module_ &m) {
347  auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
348  .def("__add__", &Sliceable::dunderAdd);
349  Derived::bindDerived(clazz);
350 
351  // Manually implement the sequence protocol via the C API. We do this
352  // because it is approx 4x faster than via nanobind, largely because that
353  // formulation requires a C++ exception to be thrown to detect end of
354  // sequence.
355  // Since we are in a C-context, any C++ exception that happens here
356  // will terminate the program. There is nothing in this implementation
357  // that should throw in a non-terminal way, so we forgo further
358  // exception marshalling.
359  // See: https://github.com/pybind/nanobind/issues/2842
360  auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
361  assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
362  "must be heap type");
363  heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
364  auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
365  return self->length;
366  };
367  // sq_item is called as part of the sequence protocol for iteration,
368  // list construction, etc.
369  heap_type->as_sequence.sq_item =
370  +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
371  auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
372  return self->getItem(index).release().ptr();
373  };
374  // mp_subscript is used for both slices and integer lookups.
375  heap_type->as_mapping.mp_subscript =
376  +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
377  auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
378  Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
379  if (!PyErr_Occurred()) {
380  // Integer indexing.
381  return self->getItem(index).release().ptr();
382  }
383  PyErr_Clear();
384 
385  // Assume slice-based indexing.
386  if (PySlice_Check(rawSubscript)) {
387  return self->getItemSlice(rawSubscript).release().ptr();
388  }
389 
390  PyErr_SetString(PyExc_ValueError, "expected integer or slice");
391  return nullptr;
392  };
393  }
394 
395  /// Hook for derived classes willing to bind more methods.
396  static void bindDerived(ClassTy &) {}
397 
398 private:
399  intptr_t startIndex;
400  intptr_t length;
401  intptr_t step;
402 };
403 
404 } // namespace mlir
405 
406 namespace llvm {
407 
408 template <>
409 struct DenseMapInfo<MlirTypeID> {
410  static inline MlirTypeID getEmptyKey() {
411  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
412  return mlirTypeIDCreate(pointer);
413  }
414  static inline MlirTypeID getTombstoneKey() {
416  return mlirTypeIDCreate(pointer);
417  }
418  static inline unsigned getHashValue(const MlirTypeID &val) {
419  return mlirTypeIDHashValue(val);
420  }
421  static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
422  return mlirTypeIDEqual(lhs, rhs);
423  }
424 };
425 } // namespace llvm
426 
427 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
Accumulates into a file, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
MlirStringCallback getCallback()
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< Derived > ClassTy
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
decltype(&T::maybeDownCast) has_maybe_downcast
Trait to check if T provides a maybeDownCast method.
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
nanobind::object getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t size()
Returns the size of slice.
CRTP template for special wrapper types that are allowed to be passed in as 'None' function arguments...
Definition: NanobindUtils.h:52
ReferrentTy * get() const
Definition: NanobindUtils.h:60
Defaulting(ReferrentTy &referrent)
Definition: NanobindUtils.h:58
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
ReferrentTy * operator->()
Definition: NanobindUtils.h:61
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr)
ptr must be 8 byte aligned and unique to a type valid for the duration of the returned type id's usag...
Definition: Support.cpp:38
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Include the generated interface declarations.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static MlirTypeID getTombstoneKey()
static MlirTypeID getEmptyKey()
static bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs)
static unsigned getHashValue(const MlirTypeID &val)
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a python string from a method that is expected to make one (no more,...
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
Definition: NanobindUtils.h:77
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept
Definition: NanobindUtils.h:99