MLIR  20.0.0git
Rewrite.cpp
Go to the documentation of this file.
1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Rewrite.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Rewrite.h"
14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15 #include "mlir/Config/mlir-config.h"
16 
17 namespace nb = nanobind;
18 using namespace mlir;
19 using namespace nb::literals;
20 using namespace mlir::python;
21 
22 namespace {
23 
24 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
25 /// Owning Wrapper around a PDLPatternModule.
26 class PyPDLPatternModule {
27 public:
28  PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
29  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
30  : module(other.module) {
31  other.module.ptr = nullptr;
32  }
33  ~PyPDLPatternModule() {
34  if (module.ptr != nullptr)
36  }
37  MlirPDLPatternModule get() { return module; }
38 
39 private:
40  MlirPDLPatternModule module;
41 };
42 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
43 
44 /// Owning Wrapper around a FrozenRewritePatternSet.
45 class PyFrozenRewritePatternSet {
46 public:
47  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
48  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
49  : set(other.set) {
50  other.set.ptr = nullptr;
51  }
52  ~PyFrozenRewritePatternSet() {
53  if (set.ptr != nullptr)
55  }
56  MlirFrozenRewritePatternSet get() { return set; }
57 
58  nb::object getCapsule() {
59  return nb::steal<nb::object>(
61  }
62 
63  static nb::object createFromCapsule(nb::object capsule) {
64  MlirFrozenRewritePatternSet rawPm =
66  if (rawPm.ptr == nullptr)
67  throw nb::python_error();
68  return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
69  }
70 
71 private:
72  MlirFrozenRewritePatternSet set;
73 };
74 
75 } // namespace
76 
77 /// Create the `mlir.rewrite` here.
78 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
79  //----------------------------------------------------------------------------
80  // Mapping of the top-level PassManager
81  //----------------------------------------------------------------------------
82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83  nb::class_<PyPDLPatternModule>(m, "PDLModule")
84  .def(
85  "__init__",
86  [](PyPDLPatternModule &self, MlirModule module) {
87  new (&self)
88  PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
89  },
90  "module"_a, "Create a PDL module from the given module.")
91  .def("freeze", [](PyPDLPatternModule &self) {
92  return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
94  });
95 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
96  nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
97  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
98  &PyFrozenRewritePatternSet::getCapsule)
100  &PyFrozenRewritePatternSet::createFromCapsule);
101  m.def(
102  "apply_patterns_and_fold_greedily",
103  [](MlirModule module, MlirFrozenRewritePatternSet set) {
104  auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105  if (mlirLogicalResultIsFailure(status))
106  // FIXME: Not sure this is the right error to throw here.
107  throw nb::value_error("pattern application failed to converge");
108  },
109  "module"_a, "set"_a,
110  "Applys the given patterns to the given module greedily while folding "
111  "results.");
112 }
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op)
Definition: Rewrite.cpp:283
MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition: Rewrite.cpp:289
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op)
FrozenRewritePatternSet API.
Definition: Rewrite.cpp:277
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatternSetToCapsule.
Definition: Interop.h:302
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Definition: Interop.h:293
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populateRewriteSubmodule(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...