MLIR  16.0.0git
PybindUtils.h
Go to the documentation of this file.
1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 
12 #include "mlir-c/Support.h"
13 #include "llvm/ADT/Optional.h"
14 #include "llvm/ADT/Twine.h"
15 
16 #include <pybind11/pybind11.h>
17 #include <pybind11/stl.h>
18 
19 namespace mlir {
20 namespace python {
21 
22 // Sets a python error, ready to be thrown to return control back to the
23 // python runtime.
24 // Correct usage:
25 // throw SetPyError(PyExc_ValueError, "Foobar'd");
26 pybind11::error_already_set SetPyError(PyObject *excClass,
27  const llvm::Twine &message);
28 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanic if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a smart ptr to the underlying type (i.e.
/// 'get' method and operator-> overloaded).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";
///
/// The wrapper does not own the referent; it only stores a pointer to it.
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}

  /// Returns the referent (nullptr only for a default-constructed instance).
  ReferrentTy *get() const { return referrent; }
  /// Smart-pointer style access. Const like `get()`: it does not modify the
  /// wrapper, so it must also be usable through const wrappers.
  ReferrentTy *operator->() const { return referrent; }

private:
  ReferrentTy *referrent = nullptr;
};
59 
60 } // namespace python
61 } // namespace mlir
62 
63 namespace pybind11 {
64 namespace detail {
65 
66 template <typename DefaultingTy>
68  PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
69 
70  bool load(pybind11::handle src, bool) {
71  if (src.is_none()) {
72  // Note that we do want an exception to propagate from here as it will be
73  // the most informative.
74  value = DefaultingTy{DefaultingTy::resolve()};
75  return true;
76  }
77 
78  // Unlike many casters that chain, these casters are expected to always
79  // succeed, so instead of doing an isinstance check followed by a cast,
80  // just cast in one step and handle the exception. Returning false (vs
81  // letting the exception propagate) causes higher level signature parsing
82  // code to produce nice error messages (other than "Cannot cast...").
83  try {
84  value = DefaultingTy{
85  pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
86  return true;
87  } catch (std::exception &) {
88  return false;
89  }
90  }
91 
92  static handle cast(DefaultingTy src, return_value_policy policy,
93  handle parent) {
94  return pybind11::cast(src, policy);
95  }
96 };
97 
98 template <typename T>
99 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
100 } // namespace detail
101 } // namespace pybind11
102 
103 //------------------------------------------------------------------------------
104 // Conversion utilities.
105 //------------------------------------------------------------------------------
106 
107 namespace mlir {
108 
109 /// Accumulates into a python string from a method that accepts an
110 /// MlirStringCallback.
112  pybind11::list parts;
113 
114  void *getUserData() { return this; }
115 
117  return [](MlirStringRef part, void *userData) {
118  PyPrintAccumulator *printAccum =
119  static_cast<PyPrintAccumulator *>(userData);
120  pybind11::str pyPart(part.data,
121  part.length); // Decodes as UTF-8 by default.
122  printAccum->parts.append(std::move(pyPart));
123  };
124  }
125 
126  pybind11::str join() {
127  pybind11::str delim("", 0);
128  return delim.attr("join")(parts);
129  }
130 };
131 
132 /// Accumulates int a python file-like object, either writing text (default)
133 /// or binary.
135 public:
136  PyFileAccumulator(const pybind11::object &fileObject, bool binary)
137  : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
138 
139  void *getUserData() { return this; }
140 
142  return [](MlirStringRef part, void *userData) {
143  pybind11::gil_scoped_acquire acquire;
144  PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
145  if (accum->binary) {
146  // Note: Still has to copy and not avoidable with this API.
147  pybind11::bytes pyBytes(part.data, part.length);
148  accum->pyWriteFunction(pyBytes);
149  } else {
150  pybind11::str pyStr(part.data,
151  part.length); // Decodes as UTF-8 by default.
152  accum->pyWriteFunction(pyStr);
153  }
154  };
155  }
156 
157 private:
158  pybind11::object pyWriteFunction;
159  bool binary;
160 };
161 
162 /// Accumulates into a python string from a method that is expected to make
163 /// one (no more, no less) call to the callback (asserts internally on
164 /// violation).
166  void *getUserData() { return this; }
167 
169  return [](MlirStringRef part, void *userData) {
171  static_cast<PySinglePartStringAccumulator *>(userData);
172  assert(!accum->invoked &&
173  "PySinglePartStringAccumulator called back multiple times");
174  accum->invoked = true;
175  accum->value = pybind11::str(part.data, part.length);
176  };
177  }
178 
179  pybind11::str takeValue() {
180  assert(invoked && "PySinglePartStringAccumulator not called back");
181  return std::move(value);
182  }
183 
184 private:
185  pybind11::str value;
186  bool invoked = false;
187 };
188 
189 /// A CRTP base class for pseudo-containers willing to support Python-type
190 /// slicing access on top of indexed access. Calling ::bind on this class
191 /// will define `__len__` as well as `__getitem__` with integer and slice
192 /// arguments.
193 ///
194 /// This is intended for pseudo-containers that can refer to arbitrary slices of
195 /// underlying storage indexed by a single integer. Indexing those with an
196 /// integer produces an instance of ElementTy. Indexing those with a slice
197 /// produces a new instance of Derived, which can be sliced further.
198 ///
199 /// A derived class must provide the following:
200 /// - a `static const char *pyClassName ` field containing the name of the
201 /// Python class to bind;
202 /// - an instance method `intptr_t getRawNumElements()` that returns the
203 /// number
204 /// of elements in the backing container (NOT that of the slice);
205 /// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
206 /// single element at the given linear index (NOT slice index);
207 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
208 /// constructs a new instance of the derived pseudo-container with the
209 /// given slice parameters (to be forwarded to the Sliceable constructor).
210 ///
211 /// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
212 /// throw.
213 ///
214 /// A derived class may additionally define:
215 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods
216 /// the python class.
217 template <typename Derived, typename ElementTy>
218 class Sliceable {
219 protected:
220  using ClassTy = pybind11::class_<Derived>;
221 
222  /// Transforms `index` into a legal value to access the underlying sequence.
223  /// Returns <0 on failure.
224  intptr_t wrapIndex(intptr_t index) {
225  if (index < 0)
226  index = length + index;
227  if (index < 0 || index >= length)
228  return -1;
229  return index;
230  }
231 
232  /// Computes the linear index given the current slice properties.
233  intptr_t linearizeIndex(intptr_t index) {
234  intptr_t linearIndex = index * step + startIndex;
235  assert(linearIndex >= 0 &&
236  linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
237  "linear index out of bounds, the slice is ill-formed");
238  return linearIndex;
239  }
240 
241  /// Returns the element at the given slice index. Supports negative indices
242  /// by taking elements in inverse order. Returns a nullptr object if out
243  /// of bounds.
244  pybind11::object getItem(intptr_t index) {
245  // Negative indices mean we count from the end.
246  index = wrapIndex(index);
247  if (index < 0) {
248  PyErr_SetString(PyExc_IndexError, "index out of range");
249  return {};
250  }
251 
252  return pybind11::cast(
253  static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
254  }
255 
256  /// Returns a new instance of the pseudo-container restricted to the given
257  /// slice. Returns a nullptr object on failure.
258  pybind11::object getItemSlice(PyObject *slice) {
259  ssize_t start, stop, extraStep, sliceLength;
260  if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
261  &sliceLength) != 0) {
262  PyErr_SetString(PyExc_IndexError, "index out of range");
263  return {};
264  }
265  return pybind11::cast(static_cast<Derived *>(this)->slice(
266  startIndex + start * step, sliceLength, step * extraStep));
267  }
268 
269 public:
270  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
271  : startIndex(startIndex), length(length), step(step) {
272  assert(length >= 0 && "expected non-negative slice length");
273  }
274 
275  /// Returns the `index`-th element in the slice, supports negative indices.
276  /// Throws if the index is out of bounds.
277  ElementTy getElement(intptr_t index) {
278  // Negative indices mean we count from the end.
279  index = wrapIndex(index);
280  if (index < 0) {
281  throw pybind11::index_error("index out of range");
282  }
283 
284  return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
285  }
286 
287  /// Returns the size of slice.
288  intptr_t size() { return length; }
289 
290  /// Returns a new vector (mapped to Python list) containing elements from two
291  /// slices. The new vector is necessary because slices may not be contiguous
292  /// or even come from the same original sequence.
293  std::vector<ElementTy> dunderAdd(Derived &other) {
294  std::vector<ElementTy> elements;
295  elements.reserve(length + other.length);
296  for (intptr_t i = 0; i < length; ++i) {
297  elements.push_back(static_cast<Derived *>(this)->getElement(i));
298  }
299  for (intptr_t i = 0; i < other.length; ++i) {
300  elements.push_back(static_cast<Derived *>(&other)->getElement(i));
301  }
302  return elements;
303  }
304 
305  /// Binds the indexing and length methods in the Python class.
306  static void bind(pybind11::module &m) {
307  auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
308  pybind11::module_local())
309  .def("__add__", &Sliceable::dunderAdd);
310  Derived::bindDerived(clazz);
311 
312  // Manually implement the sequence protocol via the C API. We do this
313  // because it is approx 4x faster than via pybind11, largely because that
314  // formulation requires a C++ exception to be thrown to detect end of
315  // sequence.
316  // Since we are in a C-context, any C++ exception that happens here
317  // will terminate the program. There is nothing in this implementation
318  // that should throw in a non-terminal way, so we forgo further
319  // exception marshalling.
320  // See: https://github.com/pybind/pybind11/issues/2842
321  auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
322  assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
323  "must be heap type");
324  heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
325  auto self = pybind11::cast<Derived *>(rawSelf);
326  return self->length;
327  };
328  // sq_item is called as part of the sequence protocol for iteration,
329  // list construction, etc.
330  heap_type->as_sequence.sq_item =
331  +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
332  auto self = pybind11::cast<Derived *>(rawSelf);
333  return self->getItem(index).release().ptr();
334  };
335  // mp_subscript is used for both slices and integer lookups.
336  heap_type->as_mapping.mp_subscript =
337  +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
338  auto self = pybind11::cast<Derived *>(rawSelf);
339  Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
340  if (!PyErr_Occurred()) {
341  // Integer indexing.
342  return self->getItem(index).release().ptr();
343  }
344  PyErr_Clear();
345 
346  // Assume slice-based indexing.
347  if (PySlice_Check(rawSubscript)) {
348  return self->getItemSlice(rawSubscript).release().ptr();
349  }
350 
351  PyErr_SetString(PyExc_ValueError, "expected integer or slice");
352  return nullptr;
353  };
354  }
355 
356  /// Hook for derived classes willing to bind more methods.
357  static void bindDerived(ClassTy &) {}
358 
359 private:
360  intptr_t startIndex;
361  intptr_t length;
362  intptr_t step;
363 };
364 
365 } // namespace mlir
366 
367 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
Include the generated interface declarations.
const char * data
Pointer to the first symbol.
Definition: Support.h:72
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
Definition: PybindUtils.h:224
pybind11::list parts
Definition: PybindUtils.h:112
pybind11::str join()
Definition: PybindUtils.h:126
static constexpr const bool value
PyFileAccumulator(const pybind11::object &fileObject, bool binary)
Definition: PybindUtils.h:136
MlirStringCallback getCallback()
Definition: PybindUtils.h:168
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
Definition: PybindUtils.h:233
static void bind(pybind11::module &m)
Binds the indexing and length methods in the Python class.
Definition: PybindUtils.h:306
Defaulting(ReferrentTy &referrent)
Definition: PybindUtils.h:51
pybind11::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
Definition: PybindUtils.h:258
MlirStringCallback getCallback()
Definition: PybindUtils.h:141
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices. ...
Definition: PybindUtils.h:293
pybind11::error_already_set SetPyError(PyObject *excClass, const llvm::Twine &message)
Definition: PybindUtils.cpp:12
pybind11::class_< Derived > ClassTy
Definition: PybindUtils.h:220
intptr_t size()
Returns the size of slice.
Definition: PybindUtils.h:288
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Definition: PybindUtils.h:357
size_t length
Length of the fragment.
Definition: Support.h:73
Wrapper around an MlirLocation.
Definition: IRModule.h:420
MlirStringCallback getCallback()
Definition: PybindUtils.h:116
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:111
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:71
A CRTP base class for pseudo-containers willing to support Python-type slicing access on top of index...
Definition: PybindUtils.h:218
Accumulates into a python string from a method that is expected to make one (no more, no less) call to the callback (asserts internally on violation).
Definition: PybindUtils.h:165
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides...
pybind11::object getItem(intptr_t index)
Returns the element at the given slice index.
Definition: PybindUtils.h:244
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:103
ReferrentTy * operator->()
Definition: PybindUtils.h:54
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
Definition: PybindUtils.h:277
bool load(pybind11::handle src, bool)
Definition: PybindUtils.h:70
CRTP template for special wrapper types that are allowed to be passed in as 'None' function arguments...
Definition: PybindUtils.h:45
Accumulates into a python file-like object, either writing text (default) or binary.
Definition: PybindUtils.h:134
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
Definition: PybindUtils.h:270
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal...
static handle cast(DefaultingTy src, return_value_policy policy, handle parent)
Definition: PybindUtils.h:92