MLIR  20.0.0git
Rewrite.cpp
Go to the documentation of this file.
1 //===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Rewrite.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Rewrite.h"
#include "mlir/Config/mlir-config.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace py::literals;
19 using namespace mlir::python;
20 
21 namespace {
22 
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning wrapper around an MlirPDLPatternModule.
///
/// The wrapped handle is released with mlirPDLPatternModuleDestroy when the
/// wrapper is destroyed, unless ownership has been moved into another wrapper
/// (a moved-from wrapper holds a null handle and releases nothing). Copying is
/// disabled so that a handle can never be destroyed twice.
class PyPDLPatternModule {
public:
  // Intentionally not `explicit`: the "PDLModule" pybind11 factory returns a
  // raw MlirPDLPatternModule and relies on it converting to this wrapper.
  PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
      : module(other.module) {
    // Leave the source empty so its destructor does not free the handle.
    other.module.ptr = nullptr;
  }
  PyPDLPatternModule(const PyPDLPatternModule &) = delete;
  PyPDLPatternModule &operator=(const PyPDLPatternModule &) = delete;
  ~PyPDLPatternModule() {
    if (module.ptr != nullptr)
      mlirPDLPatternModuleDestroy(module);
  }

  /// Returns the wrapped handle; ownership stays with this wrapper.
  MlirPDLPatternModule get() const { return module; }

private:
  MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
42 
43 /// Owning Wrapper around a FrozenRewritePatternSet.
44 class PyFrozenRewritePatternSet {
45 public:
46  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
47  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
48  : set(other.set) {
49  other.set.ptr = nullptr;
50  }
51  ~PyFrozenRewritePatternSet() {
52  if (set.ptr != nullptr)
54  }
55  MlirFrozenRewritePatternSet get() { return set; }
56 
57  pybind11::object getCapsule() {
58  return py::reinterpret_steal<py::object>(
60  }
61 
62  static pybind11::object createFromCapsule(pybind11::object capsule) {
63  MlirFrozenRewritePatternSet rawPm =
65  if (rawPm.ptr == nullptr)
66  throw py::error_already_set();
67  return py::cast(PyFrozenRewritePatternSet(rawPm),
68  py::return_value_policy::move);
69  }
70 
71 private:
72  MlirFrozenRewritePatternSet set;
73 };
74 
75 } // namespace
76 
77 /// Create the `mlir.rewrite` here.
78 void mlir::python::populateRewriteSubmodule(py::module &m) {
79  //----------------------------------------------------------------------------
80  // Mapping of the top-level PassManager
81  //----------------------------------------------------------------------------
82 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83  py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
84  .def(py::init<>([](MlirModule module) {
85  return mlirPDLPatternModuleFromModule(module);
86  }),
87  "module"_a, "Create a PDL module from the given module.")
88  .def("freeze", [](PyPDLPatternModule &self) {
89  return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
91  });
92 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
93  py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
94  py::module_local())
95  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
96  &PyFrozenRewritePatternSet::getCapsule)
98  &PyFrozenRewritePatternSet::createFromCapsule);
99  m.def(
100  "apply_patterns_and_fold_greedily",
101  [](MlirModule module, MlirFrozenRewritePatternSet set) {
102  auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
103  if (mlirLogicalResultIsFailure(status))
104  // FIXME: Not sure this is the right error to throw here.
105  throw py::value_error("pattern application failed to converge");
106  },
107  "module"_a, "set"_a,
108  "Applys the given patterns to the given module greedily while folding "
109  "results.");
110 }
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op)
Definition: Rewrite.cpp:283
MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition: Rewrite.cpp:289
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op)
FrozenRewritePatternSet API.
Definition: Rewrite.cpp:277
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatternSetToCapsule.
Definition: Interop.h:302
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Definition: Interop.h:293
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populateRewriteSubmodule(pybind11::module &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...