19#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
20#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
22#include <pybind11/functional.h>
23#include <pybind11/pybind11.h>
24#include <pybind11/pytypes.h>
25#include <pybind11/stl.h>
31#include "llvm/ADT/Twine.h"
34using namespace py::literals;
51 if (PyCapsule_CheckExact(apiObject.ptr()))
52 return py::reinterpret_borrow<py::object>(apiObject);
54 auto repr = py::repr(apiObject).cast<std::string>();
56 (llvm::Twine(
"Expected an MLIR object (got ") + repr +
").").str());
68struct type_caster<MlirAffineMap> {
70 bool load(handle src,
bool) {
78 static handle
cast(MlirAffineMap v, return_value_policy, handle) {
90struct type_caster<MlirAttribute> {
92 bool load(handle src,
bool) {
95 return !mlirAttributeIsNull(value);
97 static handle
cast(MlirAttribute v, return_value_policy, handle) {
110struct type_caster<MlirBlock> {
121struct type_caster<MlirContext> {
141struct type_caster<MlirDialectRegistry> {
148 static handle
cast(MlirDialectRegistry v, return_value_policy, handle) {
149 py::object capsule = py::reinterpret_steal<py::object>(
152 .attr(
"DialectRegistry")
160struct type_caster<MlirLocation> {
173 static handle
cast(MlirLocation v, return_value_policy, handle) {
185struct type_caster<MlirModule> {
190 return !mlirModuleIsNull(value);
192 static handle
cast(MlirModule v, return_value_policy, handle) {
204struct type_caster<MlirFrozenRewritePatternSet> {
206 _(
"MlirFrozenRewritePatternSet"));
210 return value.ptr !=
nullptr;
212 static handle
cast(MlirFrozenRewritePatternSet v, return_value_policy,
214 py::object capsule = py::reinterpret_steal<py::object>(
217 .attr(
"FrozenRewritePatternSet")
225struct type_caster<MlirOperation> {
230 return !mlirOperationIsNull(value);
232 static handle
cast(MlirOperation v, return_value_policy, handle) {
233 if (v.ptr ==
nullptr)
246struct type_caster<MlirValue> {
251 return !mlirValueIsNull(value);
253 static handle
cast(MlirValue v, return_value_policy, handle) {
254 if (v.ptr ==
nullptr)
268struct type_caster<MlirPassManager> {
279struct type_caster<MlirTypeID> {
286 static handle
cast(MlirTypeID v, return_value_policy, handle) {
287 if (v.ptr ==
nullptr)
300struct type_caster<MlirType> {
307 static handle
cast(MlirType t, return_value_policy, handle) {
341 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
346 metaclass(derivedClassName, py::make_tuple(
superClass), attributes);
347 scope.attr(derivedClassName) =
thisClass;
350 template <
typename Func,
typename... Extra>
353 std::forward<Func>(f), py::name(name), py::is_method(
thisClass),
354 py::sibling(py::getattr(
thisClass, name, py::none())), extra...);
359 template <
typename Func,
typename... Extra>
361 const Extra &...extra) {
363 std::forward<Func>(f), py::name(name), py::is_method(
thisClass),
364 py::sibling(py::getattr(
thisClass, name, py::none())), extra...);
365 auto builtinProperty =
366 py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
371 template <
typename Func,
typename... Extra>
373 const Extra &...extra) {
374 static_assert(!std::is_member_function_pointer<Func>::value,
375 "def_staticmethod(...) called with a non-static member "
377 py::cpp_function
cf(std::forward<Func>(f), py::name(name),
383 template <
typename Func,
typename... Extra>
385 const Extra &...extra) {
386 static_assert(!std::is_member_function_pointer<Func>::value,
387 "def_classmethod(...) called with a non-static member "
389 py::cpp_function
cf(std::forward<Func>(f), py::name(name),
392 py::reinterpret_borrow<py::object>(PyClassMethod_New(
cf.ptr()));
415 scope, attrClassName, isaFunction,
418 getTypeIDFunction) {}
436 std::string captureTypeName(
438 py::cpp_function newCf(
439 [superCls, isaFunction, captureTypeName](py::object cls,
440 py::object otherAttribute) {
441 MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
442 if (!isaFunction(rawAttribute)) {
443 auto origRepr = py::repr(otherAttribute).cast<std::string>();
444 throw std::invalid_argument(
445 (llvm::Twine(
"Cannot cast attribute to ") + captureTypeName +
446 " (from " + origRepr +
")")
449 py::object self = superCls.attr(
"__new__")(cls, otherAttribute);
452 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_attr"));
458 [isaFunction](MlirAttribute other) {
return isaFunction(other); },
459 py::arg(
"other_attribute"));
460 def(
"__repr__", [superCls, captureTypeName](py::object self) {
461 return py::repr(superCls(self))
462 .attr(
"replace")(superCls.attr(
"__name__"), captureTypeName);
464 if (getTypeIDFunction) {
466 [getTypeIDFunction]() {
return getTypeIDFunction(); });
469 getTypeIDFunction())(pybind11::cpp_function(
489 scope, typeClassName, isaFunction,
491 getTypeIDFunction) {}
509 std::string captureTypeName(
511 py::cpp_function newCf(
512 [superCls, isaFunction, captureTypeName](py::object cls,
513 py::object otherType) {
514 MlirType rawType = py::cast<MlirType>(otherType);
515 if (!isaFunction(rawType)) {
516 auto origRepr = py::repr(otherType).cast<std::string>();
517 throw std::invalid_argument((llvm::Twine(
"Cannot cast type to ") +
518 captureTypeName +
" (from " +
522 py::object self = superCls.attr(
"__new__")(cls, otherType);
525 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_type"));
531 [isaFunction](MlirType other) {
return isaFunction(other); },
532 py::arg(
"other_type"));
533 def(
"__repr__", [superCls, captureTypeName](py::object self) {
534 return py::repr(superCls(self))
535 .attr(
"replace")(superCls.attr(
"__name__"), captureTypeName);
537 if (getTypeIDFunction) {
544 [getTypeIDFunction]() {
return getTypeIDFunction(); });
547 getTypeIDFunction())(pybind11::cpp_function(
565 scope, valueClassName, isaFunction,
584 std::string captureValueName(
586 py::cpp_function newCf(
587 [superCls, isaFunction, captureValueName](py::object cls,
588 py::object otherValue) {
589 MlirValue rawValue = py::cast<MlirValue>(otherValue);
590 if (!isaFunction(rawValue)) {
591 auto origRepr = py::repr(otherValue).cast<std::string>();
592 throw std::invalid_argument((llvm::Twine(
"Cannot cast value to ") +
593 captureValueName +
" (from " +
597 py::object self = superCls.attr(
"__new__")(cls, otherValue);
600 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_value"));
606 [isaFunction](MlirValue other) {
return isaFunction(other); },
607 py::arg(
"other_value"));
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
static MlirBlock mlirPythonCapsuleToBlock(PyObject *capsule)
Extracts an MlirBlock from a capsule as produced from mlirPythonBlockToCapsule.
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatternSetToCapsule.
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
static MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule)
Extracts an MlirAffineMap from a capsule as produced from mlirPythonAffineMapToCapsule.
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
static PyObject * mlirPythonAffineMapToCapsule(MlirAffineMap affineMap)
Creates a capsule object encapsulating the raw C-API MlirAffineMap.
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
mlir_attribute_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const py::object &superCls, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses with a provided mlir.ir.Attribute super-class.
MlirTypeID(*)() GetTypeIDFunctionTy
bool(*)(MlirAttribute) IsAFunctionTy
mlir_attribute_subclass(py::handle scope, const char *attrClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses by looking up the super-class dynamically.
mlir_type_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const py::object &superCls, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses with a provided mlir.ir.Type super-class.
MlirTypeID(*)() GetTypeIDFunctionTy
mlir_type_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses by looking up the super-class dynamically.
bool(*)(MlirType) IsAFunctionTy
mlir_value_subclass(py::handle scope, const char *valueClassName, IsAFunctionTy isaFunction, const py::object &superCls)
Subclasses with a provided mlir.ir.Value super-class.
mlir_value_subclass(py::handle scope, const char *valueClassName, IsAFunctionTy isaFunction)
Subclasses by looking up the super-class dynamically.
bool(*)(MlirValue) IsAFunctionTy
pure_subclass & def_property_readonly(const char *name, Func &&f, const Extra &...extra)
pure_subclass & def_staticmethod(const char *name, Func &&f, const Extra &...extra)
pure_subclass & def(const char *name, Func &&f, const Extra &...extra)
py::object get_class() const
pure_subclass(py::handle scope, const char *derivedClassName, const py::object &superClass)
pure_subclass & def_classmethod(const char *name, Func &&f, const Extra &...extra)
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
static bool mlirAffineMapIsNull(MlirAffineMap affineMap)
Checks whether an affine map is null.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Include the generated interface declarations.
static py::object mlirApiObjectToCapsule(py::handle apiObject)
Helper to convert a presumed MLIR API object to a capsule, accepting either an explicit Capsule (whic...
PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"))
static handle cast(MlirAffineMap v, return_value_policy, handle)
bool load(handle src, bool)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"))
static handle cast(MlirAttribute v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"))
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"))
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"))
static handle cast(MlirDialectRegistry v, return_value_policy, handle)
bool load(handle src, bool)
bool load(handle src, bool)
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, _("MlirFrozenRewritePatternSet"))
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"))
bool load(handle src, bool)
static handle cast(MlirLocation v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"))
static handle cast(MlirModule v, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"))
bool load(handle src, bool)
static handle cast(MlirOperation v, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"))
PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"))
bool load(handle src, bool)
static handle cast(MlirTypeID v, return_value_policy, handle)
static handle cast(MlirType t, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"))
bool load(handle src, bool)
static handle cast(MlirValue v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"))