19 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
20 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
22 #include <pybind11/functional.h>
23 #include <pybind11/pybind11.h>
24 #include <pybind11/pytypes.h>
25 #include <pybind11/stl.h>
31 #include "llvm/ADT/Twine.h"
34 using namespace py::literals;
51 if (PyCapsule_CheckExact(apiObject.ptr()))
52 return py::reinterpret_borrow<py::object>(apiObject);
54 auto repr = py::repr(apiObject).cast<std::string>();
56 (llvm::Twine(
"Expected an MLIR object (got ") + repr +
").").str());
68 struct type_caster<MlirAffineMap> {
70 bool load(handle src,
bool) {
78 static handle
cast(MlirAffineMap v, return_value_policy, handle) {
90 struct type_caster<MlirAttribute> {
92 bool load(handle src,
bool) {
97 static handle
cast(MlirAttribute v, return_value_policy, handle) {
110 struct type_caster<MlirBlock> {
121 struct type_caster<MlirContext> {
141 struct type_caster<MlirDialectRegistry> {
148 static handle
cast(MlirDialectRegistry v, return_value_policy, handle) {
149 py::object capsule = py::reinterpret_steal<py::object>(
152 .attr(
"DialectRegistry")
160 struct type_caster<MlirLocation> {
173 static handle
cast(MlirLocation v, return_value_policy, handle) {
185 struct type_caster<MlirModule> {
192 static handle
cast(MlirModule v, return_value_policy, handle) {
204 struct type_caster<MlirFrozenRewritePatternSet> {
206 _(
"MlirFrozenRewritePatternSet"));
210 return value.ptr !=
nullptr;
212 static handle
cast(MlirFrozenRewritePatternSet v, return_value_policy,
214 py::object capsule = py::reinterpret_steal<py::object>(
217 .attr(
"FrozenRewritePatternSet")
225 struct type_caster<MlirOperation> {
232 static handle
cast(MlirOperation v, return_value_policy, handle) {
233 if (v.ptr ==
nullptr)
246 struct type_caster<MlirValue> {
253 static handle
cast(MlirValue v, return_value_policy, handle) {
254 if (v.ptr ==
nullptr)
268 struct type_caster<MlirPassManager> {
279 struct type_caster<MlirTypeID> {
286 static handle
cast(MlirTypeID v, return_value_policy, handle) {
287 if (v.ptr ==
nullptr)
300 struct type_caster<MlirType> {
307 static handle
cast(MlirType t, return_value_policy, handle) {
339 const py::object &superClass) {
341 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
342 py::object metaclass = pyType(superClass);
346 metaclass(derivedClassName, py::make_tuple(superClass), attributes);
347 scope.attr(derivedClassName) = thisClass;
350 template <
typename Func,
typename... Extra>
353 std::forward<Func>(f), py::name(name), py::is_method(thisClass),
354 py::sibling(py::getattr(thisClass, name, py::none())), extra...);
355 thisClass.attr(cf.name()) = cf;
359 template <
typename Func,
typename... Extra>
361 const Extra &...extra) {
363 std::forward<Func>(f), py::name(name), py::is_method(thisClass),
364 py::sibling(py::getattr(thisClass, name, py::none())), extra...);
365 auto builtinProperty =
366 py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
367 thisClass.attr(name) = builtinProperty(cf);
371 template <
typename Func,
typename... Extra>
373 const Extra &...extra) {
374 static_assert(!std::is_member_function_pointer<Func>::value,
375 "def_staticmethod(...) called with a non-static member "
377 py::cpp_function cf(std::forward<Func>(f), py::name(name),
378 py::scope(thisClass), extra...);
379 thisClass.attr(cf.name()) = py::staticmethod(cf);
383 template <
typename Func,
typename... Extra>
385 const Extra &...extra) {
386 static_assert(!std::is_member_function_pointer<Func>::value,
387 "def_classmethod(...) called with a non-static member "
389 py::cpp_function cf(std::forward<Func>(f), py::name(name),
390 py::scope(thisClass), extra...);
391 thisClass.attr(cf.name()) =
392 py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
415 scope, attrClassName, isaFunction,
418 getTypeIDFunction) {}
436 std::string captureTypeName(
438 py::cpp_function newCf(
439 [superCls, isaFunction, captureTypeName](py::object cls,
440 py::object otherAttribute) {
441 MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
442 if (!isaFunction(rawAttribute)) {
443 auto origRepr = py::repr(otherAttribute).cast<std::string>();
444 throw std::invalid_argument(
445 (llvm::Twine(
"Cannot cast attribute to ") + captureTypeName +
446 " (from " + origRepr +
")")
449 py::object
self = superCls.attr(
"__new__")(cls, otherAttribute);
452 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_attr"));
453 thisClass.attr(
"__new__") = newCf;
458 [isaFunction](MlirAttribute other) {
return isaFunction(other); },
459 py::arg(
"other_attribute"));
460 def(
"__repr__", [superCls, captureTypeName](py::object
self) {
461 return py::repr(superCls(
self))
462 .attr(
"replace")(superCls.attr(
"__name__"), captureTypeName);
464 if (getTypeIDFunction) {
465 def_staticmethod(
"get_static_typeid",
466 [getTypeIDFunction]() {
return getTypeIDFunction(); });
469 getTypeIDFunction())(pybind11::cpp_function(
470 [thisClass = thisClass](
const py::object &mlirAttribute) {
471 return thisClass(mlirAttribute);
489 scope, typeClassName, isaFunction,
491 getTypeIDFunction) {}
509 std::string captureTypeName(
511 py::cpp_function newCf(
512 [superCls, isaFunction, captureTypeName](py::object cls,
513 py::object otherType) {
514 MlirType rawType = py::cast<MlirType>(otherType);
515 if (!isaFunction(rawType)) {
516 auto origRepr = py::repr(otherType).cast<std::string>();
517 throw std::invalid_argument((llvm::Twine(
"Cannot cast type to ") +
518 captureTypeName +
" (from " +
522 py::object
self = superCls.attr(
"__new__")(cls, otherType);
525 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_type"));
526 thisClass.attr(
"__new__") = newCf;
531 [isaFunction](MlirType other) {
return isaFunction(other); },
532 py::arg(
"other_type"));
533 def(
"__repr__", [superCls, captureTypeName](py::object
self) {
534 return py::repr(superCls(
self))
535 .attr(
"replace")(superCls.attr(
"__name__"), captureTypeName);
537 if (getTypeIDFunction) {
543 def_staticmethod(
"get_static_typeid",
544 [getTypeIDFunction]() {
return getTypeIDFunction(); });
547 getTypeIDFunction())(pybind11::cpp_function(
548 [thisClass = thisClass](
const py::object &mlirType) {
549 return thisClass(mlirType);
565 scope, valueClassName, isaFunction,
584 std::string captureValueName(
586 py::cpp_function newCf(
587 [superCls, isaFunction, captureValueName](py::object cls,
588 py::object otherValue) {
589 MlirValue rawValue = py::cast<MlirValue>(otherValue);
590 if (!isaFunction(rawValue)) {
591 auto origRepr = py::repr(otherValue).cast<std::string>();
592 throw std::invalid_argument((llvm::Twine(
"Cannot cast value to ") +
593 captureValueName +
" (from " +
597 py::object
self = superCls.attr(
"__new__")(cls, otherValue);
600 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_value"));
601 thisClass.attr(
"__new__") = newCf;
606 [isaFunction](MlirValue other) {
return isaFunction(other); },
607 py::arg(
"other_value"));
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
static MlirBlock mlirPythonCapsuleToBlock(PyObject *capsule)
Extracts an MlirBlock from a capsule as produced from mlirPythonBlockToCapsule.
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperations from a capsule as produced from mlirPythonOperationToCapsule.
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
static MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule)
Extracts an MlirAffineMap from a capsule as produced from mlirPythonAffineMapToCapsule.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts a MlirContext from a capsule as produced from mlirPythonContextToCapsule.
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
static PyObject * mlirPythonAffineMapToCapsule(MlirAffineMap affineMap)
Creates a capsule object encapsulating the raw C-API MlirAffineMap.
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
Creates a custom subclass of mlir.ir.Attribute, implementing a casting constructor and type checking ...
mlir_attribute_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const py::object &superCls, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses with a provided mlir.ir.Attribute super-class.
MlirTypeID(*)() GetTypeIDFunctionTy
bool(*)(MlirAttribute) IsAFunctionTy
mlir_attribute_subclass(py::handle scope, const char *attrClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses by looking up the super-class dynamically.
Creates a custom subclass of mlir.ir.Type, implementing a casting constructor and type checking metho...
mlir_type_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const py::object &superCls, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses with a provided mlir.ir.Type super-class.
MlirTypeID(*)() GetTypeIDFunctionTy
mlir_type_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses by looking up the super-class dynamically.
bool(*)(MlirType) IsAFunctionTy
Creates a custom subclass of mlir.ir.Value, implementing a casting constructor and type checking meth...
mlir_value_subclass(py::handle scope, const char *valueClassName, IsAFunctionTy isaFunction, const py::object &superCls)
Subclasses with a provided mlir.ir.Value super-class.
mlir_value_subclass(py::handle scope, const char *valueClassName, IsAFunctionTy isaFunction)
Subclasses by looking up the super-class dynamically.
bool(*)(MlirValue) IsAFunctionTy
Provides a facility like py::class_ for defining a new class in a scope, but this allows extension of...
pure_subclass & def_classmethod(const char *name, Func &&f, const Extra &...extra)
pure_subclass & def(const char *name, Func &&f, const Extra &...extra)
py::object get_class() const
pure_subclass(py::handle scope, const char *derivedClassName, const py::object &superClass)
pure_subclass & def_property_readonly(const char *name, Func &&f, const Extra &...extra)
pure_subclass & def_staticmethod(const char *name, Func &&f, const Extra &...extra)
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
static bool mlirAffineMapIsNull(MlirAffineMap affineMap)
Checks whether an affine map is null.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Include the generated interface declarations.
static py::object mlirApiObjectToCapsule(py::handle apiObject)
Helper to convert a presumed MLIR API object to a capsule, accepting either an explicit Capsule (whic...
PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"))
static handle cast(MlirAffineMap v, return_value_policy, handle)
bool load(handle src, bool)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"))
static handle cast(MlirAttribute v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"))
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"))
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"))
static handle cast(MlirDialectRegistry v, return_value_policy, handle)
bool load(handle src, bool)
bool load(handle src, bool)
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, _("MlirFrozenRewritePatternSet"))
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"))
bool load(handle src, bool)
static handle cast(MlirLocation v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"))
static handle cast(MlirModule v, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"))
bool load(handle src, bool)
static handle cast(MlirOperation v, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"))
PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"))
bool load(handle src, bool)
static handle cast(MlirTypeID v, return_value_policy, handle)
static handle cast(MlirType t, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"))
bool load(handle src, bool)
static handle cast(MlirValue v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"))