14 #include "mlir/Config/mlir-config.h"
18 using namespace py::literals;
23 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
25 class PyPDLPatternModule {
27 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
28 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
29 : module(other.module) {
30 other.module.ptr =
nullptr;
32 ~PyPDLPatternModule() {
33 if (module.ptr !=
nullptr)
36 MlirPDLPatternModule
get() {
return module; }
39 MlirPDLPatternModule module;
44 class PyFrozenRewritePatternSet {
46 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
47 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
49 other.set.ptr =
nullptr;
51 ~PyFrozenRewritePatternSet() {
52 if (set.ptr !=
nullptr)
55 MlirFrozenRewritePatternSet
get() {
return set; }
57 pybind11::object getCapsule() {
58 return py::reinterpret_steal<py::object>(
62 static pybind11::object createFromCapsule(pybind11::object capsule) {
63 MlirFrozenRewritePatternSet rawPm =
65 if (rawPm.ptr ==
nullptr)
66 throw py::error_already_set();
67 return py::cast(PyFrozenRewritePatternSet(rawPm),
68 py::return_value_policy::move);
72 MlirFrozenRewritePatternSet set;
82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83 py::class_<PyPDLPatternModule>(m,
"PDLModule", py::module_local())
84 .def(py::init<>([](MlirModule module) {
87 "module"_a,
"Create a PDL module from the given module.")
88 .def(
"freeze", [](PyPDLPatternModule &
self) {
93 py::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet",
96 &PyFrozenRewritePatternSet::getCapsule)
98 &PyFrozenRewritePatternSet::createFromCapsule);
100 "apply_patterns_and_fold_greedily",
101 [](MlirModule module, MlirFrozenRewritePatternSet set) {
105 throw py::value_error(
"pattern application failed to converge");
108 "Applys the given patterns to the given module greedily while folding "
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
void populateRewriteSubmodule(pybind11::module &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...