19 #include "mlir/Config/mlir-config.h"
20 #include "nanobind/nanobind.h"
24 using namespace nb::literals;
29 class PyPatternRewriter {
31 PyPatternRewriter(MlirPatternRewriter rewriter)
41 auto parent = PyOperation::forOperation(ctx, owner);
48 void replaceOp(MlirOperation op, MlirOperation newOp) {
52 void replaceOp(MlirOperation op,
const std::vector<MlirValue> &values) {
59 MlirRewriterBase base;
63 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
64 static nb::object objectFromPDLValue(MlirPDLValue value) {
74 throw std::runtime_error(
"unsupported PDL value type");
77 static std::vector<nb::object> objectsFromPDLValues(
size_t nValues,
78 MlirPDLValue *values) {
79 std::vector<nb::object> args;
80 args.reserve(nValues);
81 for (
size_t i = 0; i < nValues; ++i)
82 args.push_back(objectFromPDLValue(values[i]));
99 class PyPDLPatternModule {
101 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
102 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
103 : module(other.module) {
104 other.module.ptr =
nullptr;
106 ~PyPDLPatternModule() {
107 if (module.ptr !=
nullptr)
110 MlirPDLPatternModule
get() {
return module; }
112 void registerRewriteFunction(
const std::string &name,
113 const nb::callable &fn) {
116 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
117 size_t nValues, MlirPDLValue *values,
119 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
120 return logicalResultFromObject(
121 f(PyPatternRewriter(rewriter), results,
122 objectsFromPDLValues(nValues, values)));
127 void registerConstraintFunction(
const std::string &name,
128 const nb::callable &fn) {
131 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
132 size_t nValues, MlirPDLValue *values,
134 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
135 return logicalResultFromObject(
136 f(PyPatternRewriter(rewriter), results,
137 objectsFromPDLValues(nValues, values)));
143 MlirPDLPatternModule module;
148 class PyFrozenRewritePatternSet {
150 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
151 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
153 other.set.ptr =
nullptr;
155 ~PyFrozenRewritePatternSet() {
156 if (set.ptr !=
nullptr)
159 MlirFrozenRewritePatternSet
get() {
return set; }
161 nb::object getCapsule() {
162 return nb::steal<nb::object>(
166 static nb::object createFromCapsule(
const nb::object &capsule) {
167 MlirFrozenRewritePatternSet rawPm =
169 if (rawPm.ptr ==
nullptr)
170 throw nb::python_error();
171 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
175 MlirFrozenRewritePatternSet set;
178 class PyRewritePatternSet {
180 PyRewritePatternSet(MlirContext ctx)
182 ~PyRewritePatternSet() {
188 const nb::callable &matchAndRewrite) {
190 callbacks.
construct = [](
void *userData) {
191 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
193 callbacks.
destruct = [](
void *userData) {
194 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
197 MlirPatternRewriter rewriter,
199 nb::handle f(
static_cast<PyObject *
>(userData));
203 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
205 nb::object res = f(opView, PyPatternRewriter(rewriter));
206 return logicalResultFromObject(res);
209 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
215 PyFrozenRewritePatternSet freeze() {
216 MlirRewritePatternSet s = set;
222 MlirRewritePatternSet set;
234 class_<PyPatternRewriter>(m,
"PatternRewriter")
235 .def_prop_ro(
"ip", &PyPatternRewriter::getInsertionPoint,
236 "The current insertion point of the PatternRewriter.")
239 [](PyPatternRewriter &
self, MlirOperation op,
240 MlirOperation newOp) {
self.replaceOp(op, newOp); },
241 "Replace an operation with a new operation.", nb::arg(
"op"),
249 [](PyPatternRewriter &
self, MlirOperation op,
250 const std::vector<MlirValue> &values) {
251 self.replaceOp(op, values);
253 "Replace an operation with a list of values.", nb::arg(
"op"),
259 .def(
"erase_op", &PyPatternRewriter::eraseOp,
"Erase an operation.",
269 nb::class_<PyRewritePatternSet>(m,
"RewritePatternSet")
273 new (&
self) PyRewritePatternSet(context.
get()->
get());
275 "context"_a = nb::none())
278 [](PyRewritePatternSet &
self, nb::handle root,
const nb::callable &fn,
281 nb::cast<std::string>(root.attr(
"OPERATION_NAME"));
285 "root"_a,
"fn"_a,
"benefit"_a = 1,
286 "Add a new rewrite pattern on the given root operation with the "
287 "callable as the matching and rewriting function and the given "
289 .def(
"freeze", &PyRewritePatternSet::freeze,
290 "Freeze the pattern set into a frozen one.");
295 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
296 nb::class_<MlirPDLResultList>(m,
"PDLResultList")
299 [](MlirPDLResultList results,
const PyValue &value) {
308 [](MlirPDLResultList results,
const PyOperation &op) {
317 [](MlirPDLResultList results,
const PyType &type) {
326 [](MlirPDLResultList results,
const PyAttribute &attr) {
333 nb::class_<PyPDLPatternModule>(m,
"PDLModule")
336 [](PyPDLPatternModule &
self, MlirModule module) {
343 "module"_a,
"Create a PDL module from the given module.")
346 [](PyPDLPatternModule &
self,
PyModule &module) {
347 new (&
self) PyPDLPatternModule(
353 "module"_a,
"Create a PDL module from the given module.")
356 [](PyPDLPatternModule &
self) {
360 nb::keep_alive<0, 1>())
362 "register_rewrite_function",
363 [](PyPDLPatternModule &
self,
const std::string &name,
364 const nb::callable &fn) {
365 self.registerRewriteFunction(name, fn);
367 nb::keep_alive<1, 3>())
369 "register_constraint_function",
370 [](PyPDLPatternModule &
self,
const std::string &name,
371 const nb::callable &fn) {
372 self.registerConstraintFunction(name, fn);
374 nb::keep_alive<1, 3>());
376 nb::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet")
378 &PyFrozenRewritePatternSet::getCapsule)
380 &PyFrozenRewritePatternSet::createFromCapsule);
382 "apply_patterns_and_fold_greedily",
383 [](
PyModule &module, PyFrozenRewritePatternSet &set) {
387 throw std::runtime_error(
"pattern application failed to converge");
391 nb::sig(
"def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME(
"ir.Module")
", set: FrozenRewritePatternSet) -> None"),
393 "Applys the given patterns to the given module greedily while folding "
396 "apply_patterns_and_fold_greedily",
397 [](
PyModule &module, MlirFrozenRewritePatternSet set) {
401 throw std::runtime_error(
402 "pattern application failed to converge");
406 nb::sig(
"def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME(
"ir.Module")
", set: FrozenRewritePatternSet) -> None"),
408 "Applys the given patterns to the given module greedily while "
412 "apply_patterns_and_fold_greedily",
417 throw std::runtime_error(
418 "pattern application failed to converge");
422 nb::sig(
"def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME(
"ir._OperationBase")
", set: FrozenRewritePatternSet) -> None"),
424 "Applys the given patterns to the given op greedily while folding "
427 "apply_patterns_and_fold_greedily",
432 throw std::runtime_error(
433 "pattern application failed to converge");
437 nb::sig(
"def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME(
"ir._OperationBase")
", set: FrozenRewritePatternSet) -> None"),
439 "Applys the given patterns to the given op greedily while folding "
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
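Extracts an MlirFrozenRewritePatternSet from a capsule as produced by mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the expected type, the returned set has a null `ptr` and a Python error has already been set.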
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding Python object from a capsule wrapping the C-API pointer.
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Used in function arguments when None should resolve to the current context manager set instance.
ReferrentTy * get() const
Wrapper around the generic MlirAttribute.
Wrapper around an MlirBlock.
An insertion point maintains a pointer to a Block and a reference operation.
MlirContext get()
Accesses the underlying MlirContext.
MlirModule get()
Gets the backing MlirModule.
Base class for PyOperation and PyOpView that exposes the primary, user-visible methods for manipulating an operation.
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
MlirOperation get() const
Wrapper around the generic MlirType.
Wrapper around the generic MlirValue.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value)
Cast the MlirPDLValue to an MlirValue.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLRewriteFunction rewriteFn, void *userData)
Register a rewrite function into the given PDL pattern module.
MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op)
MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value)
Cast the MlirPDLValue to an MlirType.
MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value)
Cast the MlirPDLValue to an MlirOperation.
MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, MlirAttribute value)
Push the MlirAttribute into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackOperation(MlirPDLResultList results, MlirOperation value)
Push the MlirOperation into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op)
MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value)
Push the MlirType into the given MlirPDLResultList.
MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value)
Cast the MlirPDLValue to an MlirAttribute.
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLConstraintFunction constraintFn, void *userData)
Register a constraint function into the given PDL pattern module.
MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value)
Push the MlirValue into the given MlirPDLResultList.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
static bool mlirValueIsNull(MlirValue value)
Returns whether the value is null.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
static bool mlirOperationIsNull(MlirOperation op)
Checks whether the underlying operation is null.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
void populateRewriteSubmodule(nanobind::module_ &m)
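Registers the rewrite bindings on the nanobind module `m`: PatternRewriter, RewritePatternSet, FrozenRewritePatternSet, the PDL classes (when PDL-in-pattern-match is enabled), and apply_patterns_and_fold_greedily.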
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
A logical result value, essentially a boolean with named states.
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function that matches against code rooted at the specified operation and, if the match succeeds, performs the rewrite; it returns a logical result indicating success or failure.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A pointer to a sized fragment of a string, not necessarily null-terminated.