15 #include "mlir/Config/mlir-config.h"
19 using namespace nb::literals;
24 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
26 class PyPDLPatternModule {
28 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
29 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
30 : module(other.module) {
31 other.module.ptr =
nullptr;
33 ~PyPDLPatternModule() {
34 if (module.ptr !=
nullptr)
37 MlirPDLPatternModule
get() {
return module; }
40 MlirPDLPatternModule module;
45 class PyFrozenRewritePatternSet {
47 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
48 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
50 other.set.ptr =
nullptr;
52 ~PyFrozenRewritePatternSet() {
53 if (set.ptr !=
nullptr)
56 MlirFrozenRewritePatternSet
get() {
return set; }
58 nb::object getCapsule() {
59 return nb::steal<nb::object>(
63 static nb::object createFromCapsule(nb::object capsule) {
64 MlirFrozenRewritePatternSet rawPm =
66 if (rawPm.ptr ==
nullptr)
67 throw nb::python_error();
68 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
72 MlirFrozenRewritePatternSet set;
82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83 nb::class_<PyPDLPatternModule>(m,
"PDLModule")
86 [](PyPDLPatternModule &
self, MlirModule module) {
90 "module"_a,
"Create a PDL module from the given module.")
91 .def(
"freeze", [](PyPDLPatternModule &
self) {
96 nb::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet")
98 &PyFrozenRewritePatternSet::getCapsule)
100 &PyFrozenRewritePatternSet::createFromCapsule);
102 "apply_patterns_and_fold_greedily",
103 [](MlirModule module, MlirFrozenRewritePatternSet set) {
107 throw nb::value_error(
"pattern application failed to converge");
110 "Applys the given patterns to the given module greedily while folding "
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
void populateRewriteSubmodule(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...