10#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
25struct std::iterator_traits<
nanobind::detail::fast_iterator> {
63template <
typename DerivedTy,
typename T>
86template <
typename... Ts>
87inline std::string
join(
const Ts &...args) {
88 std::ostringstream oss;
93template <
typename DefaultingTy>
95 NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
101 value = DefaultingTy{DefaultingTy::resolve()};
111 value = DefaultingTy{
112 nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
114 }
catch (std::exception &) {
119 static handle
from_cpp(DefaultingTy src, rv_policy policy,
120 cleanup_list *cleanup)
noexcept {
121 return nanobind::cast(src, policy);
144 nanobind::str pyPart(part.
data,
146 printAccum->
parts.append(std::move(pyPart));
151 nanobind::str delim(
"", 0);
152 return nanobind::cast<nanobind::str>(delim.attr(
"join")(
parts));
173 std::string filePath;
174 if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
175 std::string errorMessage;
176 auto errorCallback = +[](
MlirStringRef message,
void *userData) {
177 auto *storage =
static_cast<std::string *
>(userData);
178 storage->assign(message.
data, message.
length);
181 filePath.c_str(), binary, errorCallback, &errorMessage);
183 throw nanobind::value_error(
184 (std::string(
"Unable to open file for writing: ") + errorMessage)
189 writeTarget.emplace<nanobind::object>(fileOrStringObject.attr(
"write"));
194 return writeTarget.index() == 0 ? getPyWriteCallback()
195 : getOStreamCallback();
203 nanobind::gil_scoped_acquire acquire;
207 nanobind::bytes pyBytes(part.
data, part.
length);
208 std::get<nanobind::object>(accum->writeTarget)(pyBytes);
210 nanobind::str pyStr(part.
data,
212 std::get<nanobind::object>(accum->writeTarget)(pyStr);
221 std::get<RAIIMlirLlvmRawFdOStream>(accum->writeTarget), part);
225 std::variant<nanobind::object, RAIIMlirLlvmRawFdOStream> writeTarget;
239 assert(!accum->invoked &&
240 "PySinglePartStringAccumulator called back multiple times");
241 accum->invoked =
true;
242 accum->value = nanobind::str(part.
data, part.
length);
247 assert(invoked &&
"PySinglePartStringAccumulator not called back");
248 return std::move(value);
253 bool invoked =
false;
284template <
typename Derived,
typename ElementTy>
294 if (index < 0 || index >=
length)
302 assert(linearIndex >= 0 &&
303 linearIndex <
static_cast<Derived *
>(
this)->getRawNumElements() &&
304 "linear index out of bounds, the slice is ill-formed");
310 template <
typename T,
typename =
void>
313 template <
typename T>
324 PyErr_SetString(PyExc_IndexError,
"index out of range");
328 if constexpr (has_maybe_downcast<ElementTy>::value)
329 return static_cast<Derived *
>(
this)
333 return nanobind::cast(
340 Py_ssize_t start, stop, extraStep, sliceLength;
341 if (PySlice_GetIndicesEx(slice,
length, &start, &stop, &extraStep,
342 &sliceLength) != 0) {
343 PyErr_SetString(PyExc_IndexError,
"index out of range");
346 return nanobind::cast(
static_cast<Derived *
>(
this)->slice(
353 assert(
length >= 0 &&
"expected non-negative slice length");
362 throw nanobind::index_error(
"index out of range");
375 std::vector<ElementTy> elements;
376 elements.reserve(
length + other.length);
378 elements.push_back(
static_cast<Derived *
>(
this)->
getElement(i));
380 for (
intptr_t i = 0; i < other.length; ++i) {
381 elements.push_back(
static_cast<Derived *
>(&other)->
getElement(i));
387 static void bind(nanobind::module_ &m) {
388 const std::type_info &elemTy =
typeid(ElementTy);
389 PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
391 "expected nb_type_lookup to succeed for Sliceable elemTy");
392 nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
393 std::string sig = std::string(
"class ") + Derived::pyClassName +
394 "(collections.abc.Sequence[" +
395 nanobind::cast<std::string>(elemTyName) +
"])";
396 auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
397 nanobind::sig(sig.c_str()))
399 Derived::bindDerived(clazz);
410 auto heap_type =
reinterpret_cast<PyHeapTypeObject *
>(clazz.ptr());
411 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
412 "must be heap type");
413 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
414 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
419 heap_type->as_sequence.sq_item =
420 +[](PyObject *rawSelf, Py_ssize_t
index) -> PyObject * {
421 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
422 return self->getItem(
index).release().ptr();
425 heap_type->as_mapping.mp_subscript =
426 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
427 auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
428 Py_ssize_t
index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
429 if (!PyErr_Occurred()) {
431 return self->getItem(
index).release().ptr();
436 if (PySlice_Check(rawSubscript)) {
437 return self->getItemSlice(rawSubscript).release().ptr();
440 PyErr_SetString(PyExc_ValueError,
"expected integer or slice");
Accumulates into a file, either writing text (default) or binary.
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
MlirStringCallback getCallback()
nanobind::typed< nanobind::object, ElementTy > getItem(intptr_t index)
Returns the element at the given slice index.
intptr_t linearizeIndex(intptr_t index)
Computes the linear index given the current slice properties.
static void bind(nanobind::module_ &m)
Binds the indexing and length methods in the Python class.
std::vector< ElementTy > dunderAdd(Derived &other)
Returns a new vector (mapped to Python list) containing elements from two slices.
ElementTy getElement(intptr_t index)
Returns the index-th element in the slice, supports negative indices.
nanobind::object getItemSlice(PyObject *slice)
Returns a new instance of the pseudo-container restricted to the given slice.
nanobind::class_< PyOpResultList > ClassTy
static void bindDerived(ClassTy &)
Hook for derived classes willing to bind more methods.
Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
intptr_t wrapIndex(intptr_t index)
Transforms index into a legal value to access the underlying sequence.
intptr_t size()
Returns the size of slice.
ReferrentTy * operator->()
Defaulting(ReferrentTy &referrent)
ReferrentTy * get() const
Defaulting()=default
Type casters require the type to be default constructible, but using such an instance is illegal.
MLIR_CAPI_EXPORTED bool mlirLlvmRawFdOStreamIsNull(MlirLlvmRawFdOStream stream)
Checks if a raw_fd_ostream is null.
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamWrite(MlirLlvmRawFdOStream stream, MlirStringRef string)
Write a string to a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID)
Returns the hash value of the type id.
MLIR_CAPI_EXPORTED void mlirLlvmRawFdOStreamDestroy(MlirLlvmRawFdOStream stream)
Destroy a raw_fd_ostream created with mlirLlvmRawFdOStreamCreate.
struct MlirStringRef MlirStringRef
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2)
Checks if two type ids are equal.
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
MLIR_CAPI_EXPORTED MlirLlvmRawFdOStream mlirLlvmRawFdOStreamCreate(const char *path, bool binary, MlirStringCallback errorCallback, void *userData)
Create a raw_fd_ostream for the given path.
Include the generated interface declarations.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
Accumulates into a Python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Accumulates into a Python string from a method that is expected to make exactly one (no more, no less) call to the callback; asserts internally on violation.
nanobind::str takeValue()
MlirStringCallback getCallback()
RAII wrapper for MlirLlvmRawFdOStream that ensures destruction on scope exit.
RAIIMlirLlvmRawFdOStream & operator=(const RAIIMlirLlvmRawFdOStream &)=delete
RAIIMlirLlvmRawFdOStream(MlirLlvmRawFdOStream stream)
~RAIIMlirLlvmRawFdOStream()
RAIIMlirLlvmRawFdOStream(const RAIIMlirLlvmRawFdOStream &)=delete
Trait to check if T provides a maybeDownCast method.
bool operator()(MlirTypeID lhs, MlirTypeID rhs) const
size_t operator()(MlirTypeID typeID) const
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup)
static handle from_cpp(DefaultingTy src, rv_policy policy, cleanup_list *cleanup) noexcept
const value_type reference
std::ptrdiff_t difference_type
std::forward_iterator_tag iterator_category
nanobind::handle value_type