MLIR  19.0.0git
PybindUtils.h
Go to the documentation of this file.
1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 
12 #include "mlir-c/Support.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/ADT/Twine.h"
15 #include "llvm/Support/DataTypes.h"
16 
17 #include <pybind11/pybind11.h>
18 #include <pybind11/stl.h>
19 
20 namespace mlir {
21 namespace python {
22 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";

template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Stores the address of `referrent` without taking ownership, so the
  /// referrent must outlive this wrapper. Not `explicit`: a `ReferrentTy &`
  /// converts implicitly.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer (null only for a default-constructed
  /// instance).
  ReferrentTy *get() const { return referrent; }
  /// Member access on the wrapped object. Const-qualified like `get()`: as
  /// with smart pointers, constness of the wrapper is shallow and does not
  /// propagate to the referrent.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
53 
54 } // namespace python
55 } // namespace mlir
56 
57 namespace pybind11 {
58 namespace detail {
59 
// Shared pybind11 caster logic for `Defaulting`-derived wrapper types
// (`DefaultingTy`). PYBIND11_TYPE_CASTER provides the `value` member that
// load() fills in, and uses `DefaultingTy::kTypeDescription` as the type's
// name in generated signatures and mismatch error messages.
60 template <typename DefaultingTy>
62  PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
63 
64  bool load(pybind11::handle src, bool) {
65  if (src.is_none()) {
66  // Note that we do want an exception to propagate from here as it will be
67  // the most informative.
68  value = DefaultingTy{DefaultingTy::resolve()};
69  return true;
70  }
71 
72  // Unlike many casters that chain, these casters are expected to always
73  // succeed, so instead of doing an isinstance check followed by a cast,
74  // just cast in one step and handle the exception. Returning false (vs
75  // letting the exception propagate) causes higher level signature parsing
76  // code to produce nice error messages (other than "Cannot cast...").
77  try {
78  value = DefaultingTy{
79  pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
80  return true;
81  } catch (std::exception &) {
82  return false;
83  }
84  }
85 
86  static handle cast(DefaultingTy src, return_value_policy policy,
87  handle parent) {
88  return pybind11::cast(src, policy);
89  }
90 };
91 } // namespace detail
92 } // namespace pybind11
93 
94 //------------------------------------------------------------------------------
95 // Conversion utilities.
96 //------------------------------------------------------------------------------
97 
98 namespace mlir {
99 
100 /// Accumulates into a python string from a method that accepts an
101 /// MlirStringCallback.
103  pybind11::list parts;
104 
105  void *getUserData() { return this; }
106 
  // The lambda below captures nothing, so it converts to the plain function
  // pointer type MlirStringCallback. `userData` is expected to be the value
  // returned by getUserData(), i.e. this accumulator.
108  return [](MlirStringRef part, void *userData) {
109  PyPrintAccumulator *printAccum =
110  static_cast<PyPrintAccumulator *>(userData);
  // Each fragment is stored as a separate Python str; join() concatenates
  // them afterwards.
111  pybind11::str pyPart(part.data,
112  part.length); // Decodes as UTF-8 by default.
113  printAccum->parts.append(std::move(pyPart));
114  };
115  }
116 
117  pybind11::str join() {
118  pybind11::str delim("", 0);
119  return delim.attr("join")(parts);
120  }
121 };
122 
123 /// Accumulates into a python file-like object, writing either text or, when
124 /// `binary` is true, bytes.
126 public:
127  PyFileAccumulator(const pybind11::object &fileObject, bool binary)
128  : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
129 
130  void *getUserData() { return this; }
131 
  // Captureless lambda, so it converts to the MlirStringCallback function
  // pointer type. `userData` is expected to be getUserData(), i.e. this
  // accumulator.
133  return [](MlirStringRef part, void *userData) {
  // Takes the GIL before creating or calling Python objects.
  // NOTE(review): presumably callers may invoke this with the GIL released —
  // confirm.
134  pybind11::gil_scoped_acquire acquire;
135  PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
  // NOTE(review): if `write` raises, the resulting C++ exception propagates
  // out of this callback into the C API caller — confirm that is tolerated.
136  if (accum->binary) {
137  // Note: Still has to copy and not avoidable with this API.
138  pybind11::bytes pyBytes(part.data, part.length);
139  accum->pyWriteFunction(pyBytes);
140  } else {
141  pybind11::str pyStr(part.data,
142  part.length); // Decodes as UTF-8 by default.
143  accum->pyWriteFunction(pyStr);
144  }
145  };
146  }
147 
148 private:
149  pybind11::object pyWriteFunction;
150  bool binary;
151 };
152 
153 /// Accumulates into a python string from a method that is expected to make
154 /// one (no more, no less) call to the callback (asserts internally on
155 /// violation).
157  void *getUserData() { return this; }
158 
  // Captureless lambda, so it converts to the MlirStringCallback function
  // pointer type. `userData` is expected to be getUserData(), i.e. this
  // accumulator.
160  return [](MlirStringRef part, void *userData) {
  // NOTE(review): the declaration of `accum` (a PySinglePartStringAccumulator
  // pointer initialized from the cast below) is absent from this listing;
  // restore it from the original source.
162  static_cast<PySinglePartStringAccumulator *>(userData);
  // Enforces the single-call contract; like all asserts, only checked when
  // NDEBUG is not defined.
163  assert(!accum->invoked &&
164  "PySinglePartStringAccumulator called back multiple times");
165  accum->invoked = true;
166  accum->value = pybind11::str(part.data, part.length);
167  };
168  }
169 
170  pybind11::str takeValue() {
171  assert(invoked && "PySinglePartStringAccumulator not called back");
172  return std::move(value);
173  }
174 
175 private:
176  pybind11::str value;
177  bool invoked = false;
178 };
179 
180 /// A CRTP base class for pseudo-containers willing to support Python-type
181 /// slicing access on top of indexed access. Calling ::bind on this class
182 /// will define `__len__` as well as `__getitem__` with integer and slice
183 /// arguments.
184 ///
185 /// This is intended for pseudo-containers that can refer to arbitrary slices of
186 /// underlying storage indexed by a single integer. Indexing those with an
187 /// integer produces an instance of ElementTy. Indexing those with a slice
188 /// produces a new instance of Derived, which can be sliced further.
189 ///
190 /// A derived class must provide the following:
191 /// - a `static const char *pyClassName ` field containing the name of the
192 /// Python class to bind;
193 /// - an instance method `intptr_t getRawNumElements()` that returns the
194 /// number
195 /// of elements in the backing container (NOT that of the slice);
196 /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
197 /// single element at the given linear index (NOT slice index);
198 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
199 /// constructs a new instance of the derived pseudo-container with the
200 /// given slice parameters (to be forwarded to the Sliceable constructor).
201 ///
202 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
203 /// throw.
204 ///
205 /// A derived class may additionally define:
206 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods
207 /// the python class.
208 template <typename Derived, typename ElementTy>
209 class Sliceable {
210 protected:
211  using ClassTy = pybind11::class_<Derived>;
212 
213  /// Transforms `index` into a legal value to access the underlying sequence.
214  /// Returns <0 on failure.
215  intptr_t wrapIndex(intptr_t index) {
216  if (index < 0)
217  index = length + index;
218  if (index < 0 || index >= length)
219  return -1;
220  return index;
221  }
222 
223  /// Computes the linear index given the current slice properties.
224  intptr_t linearizeIndex(intptr_t index) {
225  intptr_t linearIndex = index * step + startIndex;
226  assert(linearIndex >= 0 &&
227  linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
228  "linear index out of bounds, the slice is ill-formed");
229  return linearIndex;
230  }
231 
232  /// Trait to check if T provides a `maybeDownCast` method.
233  /// Note, you need the & to detect inherited members.
234  template <typename T, typename... Args>
235  using has_maybe_downcast = decltype(&T::maybeDownCast);
236 
237  /// Returns the element at the given slice index. Supports negative indices
238  /// by taking elements in inverse order. Returns a nullptr object if out
239  /// of bounds.
240  pybind11::object getItem(intptr_t index) {
241  // Negative indices mean we count from the end.
242  index = wrapIndex(index);
243  if (index < 0) {
244  PyErr_SetString(PyExc_IndexError, "index out of range");
245  return {};
246  }
247 
248  if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
249  return static_cast<Derived *>(this)
250  ->getRawElement(linearizeIndex(index))
251  .maybeDownCast();
252  else
253  return pybind11::cast(
254  static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
255  }
256 
257  /// Returns a new instance of the pseudo-container restricted to the given
258  /// slice. Returns a nullptr object on failure.
259  pybind11::object getItemSlice(PyObject *slice) {
260  ssize_t start, stop, extraStep, sliceLength;
261  if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
262  &sliceLength) != 0) {
263  PyErr_SetString(PyExc_IndexError, "index out of range");
264  return {};
265  }
266  return pybind11::cast(static_cast<Derived *>(this)->slice(
267  startIndex + start * step, sliceLength, step * extraStep));
268  }
269 
270 public:
271  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
272  : startIndex(startIndex), length(length), step(step) {
273  assert(length >= 0 && "expected non-negative slice length");
274  }
275 
276  /// Returns the `index`-th element in the slice, supports negative indices.
277  /// Throws if the index is out of bounds.
278  ElementTy getElement(intptr_t index) {
279  // Negative indices mean we count from the end.
280  index = wrapIndex(index);
281  if (index < 0) {
282  throw pybind11::index_error("index out of range");
283  }
284 
285  return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
286  }
287 
288  /// Returns the size of slice.
289  intptr_t size() { return length; }
290 
291  /// Returns a new vector (mapped to Python list) containing elements from two
292  /// slices. The new vector is necessary because slices may not be contiguous
293  /// or even come from the same original sequence.
294  std::vector<ElementTy> dunderAdd(Derived &other) {
295  std::vector<ElementTy> elements;
296  elements.reserve(length + other.length);
297  for (intptr_t i = 0; i < length; ++i) {
298  elements.push_back(static_cast<Derived *>(this)->getElement(i));
299  }
300  for (intptr_t i = 0; i < other.length; ++i) {
301  elements.push_back(static_cast<Derived *>(&other)->getElement(i));
302  }
303  return elements;
304  }
305 
306  /// Binds the indexing and length methods in the Python class.
307  static void bind(pybind11::module &m) {
308  auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
309  pybind11::module_local())
310  .def("__add__", &Sliceable::dunderAdd);
311  Derived::bindDerived(clazz);
312 
313  // Manually implement the sequence protocol via the C API. We do this
314  // because it is approx 4x faster than via pybind11, largely because that
315  // formulation requires a C++ exception to be thrown to detect end of
316  // sequence.
317  // Since we are in a C-context, any C++ exception that happens here
318  // will terminate the program. There is nothing in this implementation
319  // that should throw in a non-terminal way, so we forgo further
320  // exception marshalling.
321  // See: https://github.com/pybind/pybind11/issues/2842
322  auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
323  assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
324  "must be heap type");
325  heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
326  auto self = pybind11::cast<Derived *>(rawSelf);
327  return self->length;
328  };
329  // sq_item is called as part of the sequence protocol for iteration,
330  // list construction, etc.
331  heap_type->as_sequence.sq_item =
332  +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
333  auto self = pybind11::cast<Derived *>(rawSelf);
334  return self->getItem(index).release().ptr();
335  };
336  // mp_subscript is used for both slices and integer lookups.
337  heap_type->as_mapping.mp_subscript =
338  +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
339  auto self = pybind11::cast<Derived *>(rawSelf);
340  Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
341  if (!PyErr_Occurred()) {
342  // Integer indexing.
343  return self->getItem(index).release().ptr();
344  }
345  PyErr_Clear();
346 
347  // Assume slice-based indexing.
348  if (PySlice_Check(rawSubscript)) {
349  return self->getItemSlice(rawSubscript).release().ptr();
350  }
351 
352  PyErr_SetString(PyExc_ValueError, "expected integer or slice");
353  return nullptr;
354  };
355  }
356 
357  /// Hook for derived classes willing to bind more methods.
358  static void bindDerived(ClassTy &) {}
359 
360 private:
361  intptr_t startIndex;
362  intptr_t length;
363  intptr_t step;
364 };
365 
366 } // namespace mlir
367 
368 namespace llvm {
369 
370 template <>
371 struct DenseMapInfo<MlirTypeID> {
372  static inline MlirTypeID getEmptyKey() {
373  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
374  return mlirTypeIDCreate(pointer);
375  }
376  static inline MlirTypeID getTombstoneKey() {
378  return mlirTypeIDCreate(pointer);
379  }
380  static inline unsigned getHashValue(const MlirTypeID &val) {
381  return mlirTypeIDHashValue(val);
382  }
383  static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
384  return mlirTypeIDEqual(lhs, rhs);
385  }
386 };
387 } // namespace llvm
388 
389 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
Accumulates into a python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:125
PyFileAccumulator(const pybind11::object &fileObject, bool binary)
Definition: PybindUtils.h:127
MlirStringCallback getCallback()
Definition: PybindUtils.h:132
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:209
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
Definition: PybindUtils.h:224
static void bind(pybind11::module &m)
Binds the indexing and length methods in the Python class.
Definition: PybindUtils.h:307
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
Definition: PybindUtils.h:278
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
Definition: PybindUtils.h:294
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Definition: PybindUtils.h:358
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
Definition: PybindUtils.h:271
decltype(&T::maybeDownCast) has_maybe_downcast
Trait to check if T provides a maybeDownCast method.
Definition: PybindUtils.h:235
pybind11::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
Definition: PybindUtils.h:259
pybind11::class_< Derived > ClassTy
Definition: PybindUtils.h:211
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
Definition: PybindUtils.h:215
intptr_t size()
Returns the size of slice.
Definition: PybindUtils.h:289
pybind11::object getItem(intptr_t index)
Returns the element at the given slice index.
Definition: PybindUtils.h:240
CRTP template for special wrapper types that are allowed to be passed in as 'None' function arguments...
Definition: PybindUtils.h:39
ReferrentTy * get() const
Definition: PybindUtils.h:47
Defaulting(ReferrentTy &referrent)
Definition: PybindUtils.h:45
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
ReferrentTy * operator->()
Definition: PybindUtils.h:48
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr)
ptr must be 8 byte aligned and unique to a type valid for the duration of the returned type id's usag...
Definition: Support.cpp:38
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:51
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:47
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
Include the generated interface declarations.
Definition: CallGraph.h:229
Include the generated interface declarations.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
static MlirTypeID getTombstoneKey()
Definition: PybindUtils.h:376
static MlirTypeID getEmptyKey()
Definition: PybindUtils.h:372
static bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs)
Definition: PybindUtils.h:383
static unsigned getHashValue(const MlirTypeID &val)
Definition: PybindUtils.h:380
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::list parts
Definition: PybindUtils.h:103
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
Accumulates into a python string from a method that is expected to make one (no more,...
Definition: PybindUtils.h:156
MlirStringCallback getCallback()
Definition: PybindUtils.h:159
bool load(pybind11::handle src, bool)
Definition: PybindUtils.h:64
static handle cast(DefaultingTy src, return_value_policy policy, handle parent)
Definition: PybindUtils.h:86
PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription))