MLIR  17.0.0git
PybindUtils.h
Go to the documentation of this file.
1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 
12 #include "mlir-c/Support.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Support/DataTypes.h"
15 
16 #include <pybind11/pybind11.h>
17 #include <pybind11/stl.h>
18 
19 namespace mlir {
20 namespace python {
21 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;

  /// Wraps an existing referent. Only the address is stored; the wrapper does
  /// not take ownership. Intentionally implicit so that a referent can be
  /// passed wherever the wrapper is expected.
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the wrapped pointer, or nullptr for a default-constructed
  /// instance.
  ReferrentTy *get() const { return referrent; }

  /// Smart-pointer style member access. Const-qualified, like `get()`, so that
  /// a const wrapper can still be dereferenced; the pointee itself is not
  /// const.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
52 
53 } // namespace python
54 } // namespace mlir
55 
56 namespace pybind11 {
57 namespace detail {
58 
59 template <typename DefaultingTy>
61  PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
62 
63  bool load(pybind11::handle src, bool) {
64  if (src.is_none()) {
65  // Note that we do want an exception to propagate from here as it will be
66  // the most informative.
67  value = DefaultingTy{DefaultingTy::resolve()};
68  return true;
69  }
70 
71  // Unlike many casters that chain, these casters are expected to always
72  // succeed, so instead of doing an isinstance check followed by a cast,
73  // just cast in one step and handle the exception. Returning false (vs
74  // letting the exception propagate) causes higher level signature parsing
75  // code to produce nice error messages (other than "Cannot cast...").
76  try {
77  value = DefaultingTy{
78  pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
79  return true;
80  } catch (std::exception &) {
81  return false;
82  }
83  }
84 
85  static handle cast(DefaultingTy src, return_value_policy policy,
86  handle parent) {
87  return pybind11::cast(src, policy);
88  }
89 };
90 } // namespace detail
91 } // namespace pybind11
92 
93 //------------------------------------------------------------------------------
94 // Conversion utilities.
95 //------------------------------------------------------------------------------
96 
97 namespace mlir {
98 
99 /// Accumulates into a python string from a method that accepts an
100 /// MlirStringCallback.
102  pybind11::list parts;
103 
104  void *getUserData() { return this; }
105 
107  return [](MlirStringRef part, void *userData) {
108  PyPrintAccumulator *printAccum =
109  static_cast<PyPrintAccumulator *>(userData);
110  pybind11::str pyPart(part.data,
111  part.length); // Decodes as UTF-8 by default.
112  printAccum->parts.append(std::move(pyPart));
113  };
114  }
115 
116  pybind11::str join() {
117  pybind11::str delim("", 0);
118  return delim.attr("join")(parts);
119  }
120 };
121 
122 /// Accumulates int a python file-like object, either writing text (default)
123 /// or binary.
125 public:
126  PyFileAccumulator(const pybind11::object &fileObject, bool binary)
127  : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
128 
129  void *getUserData() { return this; }
130 
132  return [](MlirStringRef part, void *userData) {
133  pybind11::gil_scoped_acquire acquire;
134  PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
135  if (accum->binary) {
136  // Note: Still has to copy and not avoidable with this API.
137  pybind11::bytes pyBytes(part.data, part.length);
138  accum->pyWriteFunction(pyBytes);
139  } else {
140  pybind11::str pyStr(part.data,
141  part.length); // Decodes as UTF-8 by default.
142  accum->pyWriteFunction(pyStr);
143  }
144  };
145  }
146 
147 private:
148  pybind11::object pyWriteFunction;
149  bool binary;
150 };
151 
152 /// Accumulates into a python string from a method that is expected to make
153 /// one (no more, no less) call to the callback (asserts internally on
154 /// violation).
156  void *getUserData() { return this; }
157 
159  return [](MlirStringRef part, void *userData) {
161  static_cast<PySinglePartStringAccumulator *>(userData);
162  assert(!accum->invoked &&
163  "PySinglePartStringAccumulator called back multiple times");
164  accum->invoked = true;
165  accum->value = pybind11::str(part.data, part.length);
166  };
167  }
168 
169  pybind11::str takeValue() {
170  assert(invoked && "PySinglePartStringAccumulator not called back");
171  return std::move(value);
172  }
173 
174 private:
175  pybind11::str value;
176  bool invoked = false;
177 };
178 
179 /// A CRTP base class for pseudo-containers willing to support Python-type
180 /// slicing access on top of indexed access. Calling ::bind on this class
181 /// will define `__len__` as well as `__getitem__` with integer and slice
182 /// arguments.
183 ///
184 /// This is intended for pseudo-containers that can refer to arbitrary slices of
185 /// underlying storage indexed by a single integer. Indexing those with an
186 /// integer produces an instance of ElementTy. Indexing those with a slice
187 /// produces a new instance of Derived, which can be sliced further.
188 ///
189 /// A derived class must provide the following:
190 /// - a `static const char *pyClassName ` field containing the name of the
191 /// Python class to bind;
192 /// - an instance method `intptr_t getRawNumElements()` that returns the
193 /// number
194 /// of elements in the backing container (NOT that of the slice);
195 /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
196 /// single element at the given linear index (NOT slice index);
197 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
198 /// constructs a new instance of the derived pseudo-container with the
199 /// given slice parameters (to be forwarded to the Sliceable constructor).
200 ///
201 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
202 /// throw.
203 ///
204 /// A derived class may additionally define:
205 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods
206 /// the python class.
207 template <typename Derived, typename ElementTy>
208 class Sliceable {
209 protected:
210  using ClassTy = pybind11::class_<Derived>;
211 
212  /// Transforms `index` into a legal value to access the underlying sequence.
213  /// Returns <0 on failure.
214  intptr_t wrapIndex(intptr_t index) {
215  if (index < 0)
216  index = length + index;
217  if (index < 0 || index >= length)
218  return -1;
219  return index;
220  }
221 
222  /// Computes the linear index given the current slice properties.
223  intptr_t linearizeIndex(intptr_t index) {
224  intptr_t linearIndex = index * step + startIndex;
225  assert(linearIndex >= 0 &&
226  linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
227  "linear index out of bounds, the slice is ill-formed");
228  return linearIndex;
229  }
230 
231  /// Returns the element at the given slice index. Supports negative indices
232  /// by taking elements in inverse order. Returns a nullptr object if out
233  /// of bounds.
234  pybind11::object getItem(intptr_t index) {
235  // Negative indices mean we count from the end.
236  index = wrapIndex(index);
237  if (index < 0) {
238  PyErr_SetString(PyExc_IndexError, "index out of range");
239  return {};
240  }
241 
242  return pybind11::cast(
243  static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
244  }
245 
246  /// Returns a new instance of the pseudo-container restricted to the given
247  /// slice. Returns a nullptr object on failure.
248  pybind11::object getItemSlice(PyObject *slice) {
249  ssize_t start, stop, extraStep, sliceLength;
250  if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
251  &sliceLength) != 0) {
252  PyErr_SetString(PyExc_IndexError, "index out of range");
253  return {};
254  }
255  return pybind11::cast(static_cast<Derived *>(this)->slice(
256  startIndex + start * step, sliceLength, step * extraStep));
257  }
258 
259 public:
260  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
261  : startIndex(startIndex), length(length), step(step) {
262  assert(length >= 0 && "expected non-negative slice length");
263  }
264 
265  /// Returns the `index`-th element in the slice, supports negative indices.
266  /// Throws if the index is out of bounds.
267  ElementTy getElement(intptr_t index) {
268  // Negative indices mean we count from the end.
269  index = wrapIndex(index);
270  if (index < 0) {
271  throw pybind11::index_error("index out of range");
272  }
273 
274  return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
275  }
276 
277  /// Returns the size of slice.
278  intptr_t size() { return length; }
279 
280  /// Returns a new vector (mapped to Python list) containing elements from two
281  /// slices. The new vector is necessary because slices may not be contiguous
282  /// or even come from the same original sequence.
283  std::vector<ElementTy> dunderAdd(Derived &other) {
284  std::vector<ElementTy> elements;
285  elements.reserve(length + other.length);
286  for (intptr_t i = 0; i < length; ++i) {
287  elements.push_back(static_cast<Derived *>(this)->getElement(i));
288  }
289  for (intptr_t i = 0; i < other.length; ++i) {
290  elements.push_back(static_cast<Derived *>(&other)->getElement(i));
291  }
292  return elements;
293  }
294 
295  /// Binds the indexing and length methods in the Python class.
296  static void bind(pybind11::module &m) {
297  auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
298  pybind11::module_local())
299  .def("__add__", &Sliceable::dunderAdd);
300  Derived::bindDerived(clazz);
301 
302  // Manually implement the sequence protocol via the C API. We do this
303  // because it is approx 4x faster than via pybind11, largely because that
304  // formulation requires a C++ exception to be thrown to detect end of
305  // sequence.
306  // Since we are in a C-context, any C++ exception that happens here
307  // will terminate the program. There is nothing in this implementation
308  // that should throw in a non-terminal way, so we forgo further
309  // exception marshalling.
310  // See: https://github.com/pybind/pybind11/issues/2842
311  auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
312  assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
313  "must be heap type");
314  heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
315  auto self = pybind11::cast<Derived *>(rawSelf);
316  return self->length;
317  };
318  // sq_item is called as part of the sequence protocol for iteration,
319  // list construction, etc.
320  heap_type->as_sequence.sq_item =
321  +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
322  auto self = pybind11::cast<Derived *>(rawSelf);
323  return self->getItem(index).release().ptr();
324  };
325  // mp_subscript is used for both slices and integer lookups.
326  heap_type->as_mapping.mp_subscript =
327  +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
328  auto self = pybind11::cast<Derived *>(rawSelf);
329  Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
330  if (!PyErr_Occurred()) {
331  // Integer indexing.
332  return self->getItem(index).release().ptr();
333  }
334  PyErr_Clear();
335 
336  // Assume slice-based indexing.
337  if (PySlice_Check(rawSubscript)) {
338  return self->getItemSlice(rawSubscript).release().ptr();
339  }
340 
341  PyErr_SetString(PyExc_ValueError, "expected integer or slice");
342  return nullptr;
343  };
344  }
345 
346  /// Hook for derived classes willing to bind more methods.
347  static void bindDerived(ClassTy &) {}
348 
349 private:
350  intptr_t startIndex;
351  intptr_t length;
352  intptr_t step;
353 };
354 
355 } // namespace mlir
356 
357 namespace llvm {
358 
359 template <>
360 struct DenseMapInfo<MlirTypeID> {
361  static inline MlirTypeID getEmptyKey() {
362  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
363  return mlirTypeIDCreate(pointer);
364  }
365  static inline MlirTypeID getTombstoneKey() {
367  return mlirTypeIDCreate(pointer);
368  }
369  static inline unsigned getHashValue(const MlirTypeID &val) {
370  return mlirTypeIDHashValue(val);
371  }
372  static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
373  return mlirTypeIDEqual(lhs, rhs);
374  }
375 };
376 } // namespace llvm
377 
378 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
Accumulates into a python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:124
PyFileAccumulator(const pybind11::object &fileObject, bool binary)
Definition: PybindUtils.h:126
MlirStringCallback getCallback()
Definition: PybindUtils.h:131
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:208
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
Definition: PybindUtils.h:223
static void bind(pybind11::module &m)
Binds the indexing and length methods in the Python class.
Definition: PybindUtils.h:296
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
Definition: PybindUtils.h:267
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
Definition: PybindUtils.h:283
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Definition: PybindUtils.h:347
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
Definition: PybindUtils.h:260
pybind11::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
Definition: PybindUtils.h:248
pybind11::class_< Derived > ClassTy
Definition: PybindUtils.h:210
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
Definition: PybindUtils.h:214
intptr_t size()
Returns the size of slice.
Definition: PybindUtils.h:278
pybind11::object getItem(intptr_t index)
Returns the element at the given slice index.
Definition: PybindUtils.h:234
CRTP template for special wrapper types that are allowed to be passed in as 'None' function arguments...
Definition: PybindUtils.h:38
ReferrentTy * get() const
Definition: PybindUtils.h:46
Defaulting(ReferrentTy &referrent)
Definition: PybindUtils.h:44
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
ReferrentTy * operator->()
Definition: PybindUtils.h:47
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr)
ptr must be 8 byte aligned and unique to a type valid for the duration of the returned type id's usag...
Definition: Support.cpp:26
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
Definition: Support.cpp:39
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
Definition: Support.cpp:35
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:103
Include the generated interface declarations.
Definition: CallGraph.h:229
This header declares functions that assist transformations in the MemRef dialect.
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:71
const char * data
Pointer to the first symbol.
Definition: Support.h:72
size_t length
Length of the fragment.
Definition: Support.h:73
static MlirTypeID getTombstoneKey()
Definition: PybindUtils.h:365
static MlirTypeID getEmptyKey()
Definition: PybindUtils.h:361
static bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs)
Definition: PybindUtils.h:372
static unsigned getHashValue(const MlirTypeID &val)
Definition: PybindUtils.h:369
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:101
pybind11::list parts
Definition: PybindUtils.h:102
pybind11::str join()
Definition: PybindUtils.h:116
MlirStringCallback getCallback()
Definition: PybindUtils.h:106
Accumulates into a python string from a method that is expected to make one (no more,...
Definition: PybindUtils.h:155
MlirStringCallback getCallback()
Definition: PybindUtils.h:158
bool load(pybind11::handle src, bool)
Definition: PybindUtils.h:63
static handle cast(DefaultingTy src, return_value_policy policy, handle parent)
Definition: PybindUtils.h:85
PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription))