18 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
19 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
21 #include <pybind11/functional.h>
22 #include <pybind11/pybind11.h>
23 #include <pybind11/pytypes.h>
24 #include <pybind11/stl.h>
30 #include "llvm/ADT/Twine.h"
33 using namespace py::literals;
50 if (PyCapsule_CheckExact(apiObject.ptr()))
51 return py::reinterpret_borrow<py::object>(apiObject);
53 auto repr = py::repr(apiObject).cast<std::string>();
55 (llvm::Twine(
"Expected an MLIR object (got ") + repr +
").").str());
67 struct type_caster<MlirAffineMap> {
69 bool load(handle src,
bool) {
77 static handle
cast(MlirAffineMap v, return_value_policy, handle) {
89 struct type_caster<MlirAttribute> {
91 bool load(handle src,
bool) {
96 static handle
cast(MlirAttribute v, return_value_policy, handle) {
109 struct type_caster<MlirBlock> {
120 struct type_caster<MlirContext> {
140 struct type_caster<MlirDialectRegistry> {
147 static handle
cast(MlirDialectRegistry v, return_value_policy, handle) {
148 py::object capsule = py::reinterpret_steal<py::object>(
151 .attr(
"DialectRegistry")
159 struct type_caster<MlirLocation> {
172 static handle
cast(MlirLocation v, return_value_policy, handle) {
184 struct type_caster<MlirModule> {
191 static handle
cast(MlirModule v, return_value_policy, handle) {
203 struct type_caster<MlirFrozenRewritePatternSet> {
205 _(
"MlirFrozenRewritePatternSet"));
209 return value.ptr !=
nullptr;
211 static handle
cast(MlirFrozenRewritePatternSet v, return_value_policy,
213 py::object capsule = py::reinterpret_steal<py::object>(
216 .attr(
"FrozenRewritePatternSet")
224 struct type_caster<MlirOperation> {
231 static handle
cast(MlirOperation v, return_value_policy, handle) {
232 if (v.ptr ==
nullptr)
245 struct type_caster<MlirValue> {
252 static handle
cast(MlirValue v, return_value_policy, handle) {
253 if (v.ptr ==
nullptr)
267 struct type_caster<MlirPassManager> {
278 struct type_caster<MlirTypeID> {
285 static handle
cast(MlirTypeID v, return_value_policy, handle) {
286 if (v.ptr ==
nullptr)
299 struct type_caster<MlirType> {
306 static handle
cast(MlirType t, return_value_policy, handle) {
338 const py::object &superClass) {
340 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
341 py::object metaclass = pyType(superClass);
345 metaclass(derivedClassName, py::make_tuple(superClass), attributes);
346 scope.attr(derivedClassName) = thisClass;
349 template <
typename Func,
typename... Extra>
352 std::forward<Func>(f), py::name(name), py::is_method(thisClass),
353 py::sibling(py::getattr(thisClass, name, py::none())), extra...);
354 thisClass.attr(cf.name()) = cf;
358 template <
typename Func,
typename... Extra>
360 const Extra &...extra) {
362 std::forward<Func>(f), py::name(name), py::is_method(thisClass),
363 py::sibling(py::getattr(thisClass, name, py::none())), extra...);
364 auto builtinProperty =
365 py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
366 thisClass.attr(name) = builtinProperty(cf);
370 template <
typename Func,
typename... Extra>
372 const Extra &...extra) {
373 static_assert(!std::is_member_function_pointer<Func>::value,
374 "def_staticmethod(...) called with a non-static member "
377 std::forward<Func>(f), py::name(name), py::scope(thisClass),
378 py::sibling(py::getattr(thisClass, name, py::none())), extra...);
379 thisClass.attr(cf.name()) = py::staticmethod(cf);
383 template <
typename Func,
typename... Extra>
385 const Extra &...extra) {
386 static_assert(!std::is_member_function_pointer<Func>::value,
387 "def_classmethod(...) called with a non-static member "
390 std::forward<Func>(f), py::name(name), py::scope(thisClass),
391 py::sibling(py::getattr(thisClass, name, py::none())), extra...);
392 thisClass.attr(cf.name()) =
393 py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
416 scope, attrClassName, isaFunction,
419 getTypeIDFunction) {}
437 std::string captureTypeName(
439 py::cpp_function newCf(
440 [superCls, isaFunction, captureTypeName](py::object cls,
441 py::object otherAttribute) {
442 MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
443 if (!isaFunction(rawAttribute)) {
444 auto origRepr = py::repr(otherAttribute).cast<std::string>();
445 throw std::invalid_argument(
446 (llvm::Twine(
"Cannot cast attribute to ") + captureTypeName +
447 " (from " + origRepr +
")")
450 py::object
self = superCls.attr(
"__new__")(cls, otherAttribute);
453 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_attr"));
454 thisClass.attr(
"__new__") = newCf;
459 [isaFunction](MlirAttribute other) {
return isaFunction(other); },
460 py::arg(
"other_attribute"));
461 def(
"__repr__", [superCls, captureTypeName](py::object
self) {
462 return py::repr(superCls(
self))
463 .attr(
"replace")(superCls.attr(
"__name__"), captureTypeName);
465 if (getTypeIDFunction) {
466 def_staticmethod(
"get_static_typeid",
467 [getTypeIDFunction]() {
return getTypeIDFunction(); });
470 getTypeIDFunction())(pybind11::cpp_function(
471 [thisClass = thisClass](
const py::object &mlirAttribute) {
472 return thisClass(mlirAttribute);
490 scope, typeClassName, isaFunction,
492 getTypeIDFunction) {}
510 std::string captureTypeName(
512 py::cpp_function newCf(
513 [superCls, isaFunction, captureTypeName](py::object cls,
514 py::object otherType) {
515 MlirType rawType = py::cast<MlirType>(otherType);
516 if (!isaFunction(rawType)) {
517 auto origRepr = py::repr(otherType).cast<std::string>();
518 throw std::invalid_argument((llvm::Twine(
"Cannot cast type to ") +
519 captureTypeName +
" (from " +
523 py::object
self = superCls.attr(
"__new__")(cls, otherType);
526 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_type"));
527 thisClass.attr(
"__new__") = newCf;
532 [isaFunction](MlirType other) {
return isaFunction(other); },
533 py::arg(
"other_type"));
534 def(
"__repr__", [superCls, captureTypeName](py::object
self) {
535 return py::repr(superCls(
self))
536 .attr(
"replace")(superCls.attr(
"__name__"), captureTypeName);
538 if (getTypeIDFunction) {
544 def_staticmethod(
"get_static_typeid",
545 [getTypeIDFunction]() {
return getTypeIDFunction(); });
548 getTypeIDFunction())(pybind11::cpp_function(
549 [thisClass = thisClass](
const py::object &mlirType) {
550 return thisClass(mlirType);
566 scope, valueClassName, isaFunction,
585 std::string captureValueName(
587 py::cpp_function newCf(
588 [superCls, isaFunction, captureValueName](py::object cls,
589 py::object otherValue) {
590 MlirValue rawValue = py::cast<MlirValue>(otherValue);
591 if (!isaFunction(rawValue)) {
592 auto origRepr = py::repr(otherValue).cast<std::string>();
593 throw std::invalid_argument((llvm::Twine(
"Cannot cast value to ") +
594 captureValueName +
" (from " +
598 py::object
self = superCls.attr(
"__new__")(cls, otherValue);
601 py::name(
"__new__"), py::arg(
"cls"), py::arg(
"cast_from_value"));
602 thisClass.attr(
"__new__") = newCf;
607 [isaFunction](MlirValue other) {
return isaFunction(other); },
608 py::arg(
"other_value"));
623 assert(errorMessage.empty() &&
"unchecked error message");
627 [[nodiscard]] std::string
takeMessage() {
return std::move(errorMessage); }
632 *
static_cast<std::string *
>(data) +=
633 llvm::StringRef(message.
data, message.
length);
636 *
static_cast<std::string *
>(data) +=
"at ";
638 *
static_cast<std::string *
>(data) +=
": ";
645 std::string errorMessage =
"";
static PyObject * mlirPythonModuleToCapsule(MlirModule module)
Creates a capsule object encapsulating the raw C-API MlirModule.
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR
Attribute on MLIR Python objects that expose a function for downcasting the corresponding Python obje...
static MlirBlock mlirPythonCapsuleToBlock(PyObject *capsule)
Extracts an MlirBlock from a capsule as produced from mlirPythonBlockToCapsule.
static PyObject * mlirPythonTypeIDToCapsule(MlirTypeID typeID)
Creates a capsule object encapsulating the raw C-API MlirTypeID.
static MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule)
Extracts an MlirOperation from a capsule as produced from mlirPythonOperationToCapsule.
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
static MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule)
Extracts an MlirAttribute from a capsule as produced from mlirPythonAttributeToCapsule.
static PyObject * mlirPythonAttributeToCapsule(MlirAttribute attribute)
Creates a capsule object encapsulating the raw C-API MlirAttribute.
static PyObject * mlirPythonLocationToCapsule(MlirLocation loc)
Creates a capsule object encapsulating the raw C-API MlirLocation.
static MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule)
Extracts an MlirAffineMap from a capsule as produced from mlirPythonAffineMapToCapsule.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static MlirModule mlirPythonCapsuleToModule(PyObject *capsule)
Extracts an MlirModule from a capsule as produced from mlirPythonModuleToCapsule.
static MlirContext mlirPythonCapsuleToContext(PyObject *capsule)
Extracts an MlirContext from a capsule as produced from mlirPythonContextToCapsule.
static MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule)
Extracts an MlirTypeID from a capsule as produced from mlirPythonTypeIDToCapsule.
static PyObject * mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry)
Creates a capsule object encapsulating the raw C-API MlirDialectRegistry.
static PyObject * mlirPythonTypeToCapsule(MlirType type)
Creates a capsule object encapsulating the raw C-API MlirType.
static MlirDialectRegistry mlirPythonCapsuleToDialectRegistry(PyObject *capsule)
Extracts an MlirDialectRegistry from a capsule as produced from mlirPythonDialectRegistryToCapsule.
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
static MlirType mlirPythonCapsuleToType(PyObject *capsule)
Extracts an MlirType from a capsule as produced from mlirPythonTypeToCapsule.
static MlirValue mlirPythonCapsuleToValue(PyObject *capsule)
Extracts an MlirValue from a capsule as produced from mlirPythonValueToCapsule.
static PyObject * mlirPythonAffineMapToCapsule(MlirAffineMap affineMap)
Creates a capsule object encapsulating the raw C-API MlirAffineMap.
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
static PyObject * mlirPythonOperationToCapsule(MlirOperation operation)
Creates a capsule object encapsulating the raw C-API MlirOperation.
static MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule)
Extracts an MlirLocation from a capsule as produced from mlirPythonLocationToCapsule.
static PyObject * mlirPythonValueToCapsule(MlirValue value)
Creates a capsule object encapsulating the raw C-API MlirValue.
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR
Attribute on main C extension module (_mlir) that corresponds to the type caster registration binding...
static std::string diag(const llvm::Value &value)
RAII scope intercepting all diagnostics into a string.
std::string takeMessage()
~CollectDiagnosticsToStringScope()
CollectDiagnosticsToStringScope(MlirContext ctx)
Creates a custom subclass of mlir.ir.Attribute, implementing a casting constructor and type checking ...
mlir_attribute_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const py::object &superCls, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses with a provided mlir.ir.Attribute super-class.
MlirTypeID(*)() GetTypeIDFunctionTy
bool(*)(MlirAttribute) IsAFunctionTy
mlir_attribute_subclass(py::handle scope, const char *attrClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses by looking up the super-class dynamically.
Creates a custom subclass of mlir.ir.Type, implementing a casting constructor and type checking metho...
mlir_type_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, const py::object &superCls, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses with a provided mlir.ir.Type super-class.
MlirTypeID(*)() GetTypeIDFunctionTy
mlir_type_subclass(py::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction=nullptr)
Subclasses by looking up the super-class dynamically.
bool(*)(MlirType) IsAFunctionTy
Creates a custom subclass of mlir.ir.Value, implementing a casting constructor and type checking meth...
mlir_value_subclass(py::handle scope, const char *valueClassName, IsAFunctionTy isaFunction, const py::object &superCls)
Subclasses with a provided mlir.ir.Value super-class.
mlir_value_subclass(py::handle scope, const char *valueClassName, IsAFunctionTy isaFunction)
Subclasses by looking up the super-class dynamically.
bool(*)(MlirValue) IsAFunctionTy
Provides a facility like py::class_ for defining a new class in a scope, but this allows extension of...
pure_subclass & def_classmethod(const char *name, Func &&f, const Extra &...extra)
pure_subclass & def(const char *name, Func &&f, const Extra &...extra)
py::object get_class() const
pure_subclass(py::handle scope, const char *derivedClassName, const py::object &superClass)
pure_subclass & def_property_readonly(const char *name, Func &&f, const Extra &...extra)
pure_subclass & def_staticmethod(const char *name, Func &&f, const Extra &...extra)
MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback, void *userData)
Prints a diagnostic using the provided callback.
MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(MlirContext context, MlirDiagnosticHandler handler, void *userData, void(*deleteUserData)(void *))
Attaches the diagnostic handler to the context.
MLIR_CAPI_EXPORTED void mlirContextDetachDiagnosticHandler(MlirContext context, MlirDiagnosticHandlerID id)
Detaches an attached diagnostic handler from the context given its identifier.
uint64_t MlirDiagnosticHandlerID
Opaque identifier of a diagnostic handler, useful to detach a handler.
MLIR_CAPI_EXPORTED MlirLocation mlirDiagnosticGetLocation(MlirDiagnostic diagnostic)
Returns the location at which the diagnostic is reported.
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
static bool mlirAffineMapIsNull(MlirAffineMap affineMap)
Checks whether an affine map is null.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
static bool mlirModuleIsNull(MlirModule module)
Checks whether a module is null.
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
static bool mlirContextIsNull(MlirContext context)
Checks whether a context is null.
static bool mlirBlockIsNull(MlirBlock block)
Checks whether a block is null.
static bool mlirLocationIsNull(MlirLocation location)
Checks if the location is null.
MLIR_CAPI_EXPORTED void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData)
Prints a location by sending chunks of the string representation and forwarding `userData` to `callback`.
static bool mlirDialectRegistryIsNull(MlirDialectRegistry registry)
Checks if the dialect registry is null.
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
static bool mlirTypeIDIsNull(MlirTypeID typeID)
Checks whether a type id is null.
Include the generated interface declarations.
static py::object mlirApiObjectToCapsule(py::handle apiObject)
Helper to convert a presumed MLIR API object to a capsule, accepting either an explicit Capsule (whic...
An opaque reference to a diagnostic, always owned by the diagnostics engine (context).
A logical result value, essentially a boolean with named states.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"))
static handle cast(MlirAffineMap v, return_value_policy, handle)
bool load(handle src, bool)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"))
static handle cast(MlirAttribute v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"))
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"))
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"))
static handle cast(MlirDialectRegistry v, return_value_policy, handle)
bool load(handle src, bool)
bool load(handle src, bool)
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, _("MlirFrozenRewritePatternSet"))
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"))
bool load(handle src, bool)
static handle cast(MlirLocation v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"))
static handle cast(MlirModule v, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"))
bool load(handle src, bool)
static handle cast(MlirOperation v, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"))
PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"))
bool load(handle src, bool)
static handle cast(MlirTypeID v, return_value_policy, handle)
static handle cast(MlirType t, return_value_policy, handle)
bool load(handle src, bool)
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"))
bool load(handle src, bool)
static handle cast(MlirValue v, return_value_policy, handle)
PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"))