19#include "mlir/Config/mlir-config.h"
20#include "nanobind/nanobind.h"
24using namespace nb::literals;
29class PyPatternRewriter {
31 PyPatternRewriter(MlirPatternRewriter rewriter)
35 PyInsertionPoint getInsertionPoint()
const {
39 if (mlirOperationIsNull(op)) {
42 return PyInsertionPoint(PyBlock(parent, block));
48 void replaceOp(MlirOperation op, MlirOperation newOp) {
52 void replaceOp(MlirOperation op,
const std::vector<MlirValue> &values) {
59 MlirRewriterBase base;
63#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
64static nb::object objectFromPDLValue(MlirPDLValue value) {
65 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
67 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
69 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
74 throw std::runtime_error(
"unsupported PDL value type");
77static std::vector<nb::object> objectsFromPDLValues(
size_t nValues,
78 MlirPDLValue *values) {
79 std::vector<nb::object> args;
80 args.reserve(nValues);
81 for (
size_t i = 0; i < nValues; ++i)
82 args.push_back(objectFromPDLValue(values[i]));
99class PyPDLPatternModule {
101 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
102 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
103 : module(other.module) {
104 other.module.ptr =
nullptr;
106 ~PyPDLPatternModule() {
107 if (module.ptr !=
nullptr)
108 mlirPDLPatternModuleDestroy(module);
110 MlirPDLPatternModule
get() {
return module; }
112 void registerRewriteFunction(
const std::string &name,
113 const nb::callable &fn) {
114 mlirPDLPatternModuleRegisterRewriteFunction(
116 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
117 size_t nValues, MlirPDLValue *values,
119 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
120 return logicalResultFromObject(
121 f(PyPatternRewriter(rewriter), results,
122 objectsFromPDLValues(nValues, values)));
127 void registerConstraintFunction(
const std::string &name,
128 const nb::callable &fn) {
129 mlirPDLPatternModuleRegisterConstraintFunction(
131 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
132 size_t nValues, MlirPDLValue *values,
134 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
135 return logicalResultFromObject(
136 f(PyPatternRewriter(rewriter), results,
137 objectsFromPDLValues(nValues, values)));
143 MlirPDLPatternModule module;
148class PyFrozenRewritePatternSet {
150 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
151 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
153 other.set.ptr =
nullptr;
155 ~PyFrozenRewritePatternSet() {
156 if (set.ptr !=
nullptr)
159 MlirFrozenRewritePatternSet
get() {
return set; }
161 nb::object getCapsule() {
162 return nb::steal<nb::object>(
166 static nb::object createFromCapsule(
const nb::object &capsule) {
167 MlirFrozenRewritePatternSet rawPm =
169 if (rawPm.ptr ==
nullptr)
170 throw nb::python_error();
171 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
175 MlirFrozenRewritePatternSet set;
178class PyRewritePatternSet {
180 PyRewritePatternSet(MlirContext ctx)
182 ~PyRewritePatternSet() {
188 const nb::callable &matchAndRewrite) {
189 MlirRewritePatternCallbacks callbacks;
190 callbacks.
construct = [](
void *userData) {
191 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
193 callbacks.
destruct = [](
void *userData) {
194 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
197 MlirPatternRewriter rewriter,
199 nb::handle f(
static_cast<PyObject *
>(userData));
205 nb::object res = f(opView, PyPatternRewriter(rewriter));
206 return logicalResultFromObject(res);
209 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
215 PyFrozenRewritePatternSet freeze() {
216 MlirRewritePatternSet s = set;
222 MlirRewritePatternSet set;
234 class_<PyPatternRewriter>(m,
"PatternRewriter")
235 .def_prop_ro(
"ip", &PyPatternRewriter::getInsertionPoint,
236 "The current insertion point of the PatternRewriter.")
239 [](PyPatternRewriter &self, MlirOperation op,
240 MlirOperation newOp) { self.replaceOp(op, newOp); },
241 "Replace an operation with a new operation.", nb::arg(
"op"),
249 [](PyPatternRewriter &self, MlirOperation op,
250 const std::vector<MlirValue> &values) {
251 self.replaceOp(op, values);
253 "Replace an operation with a list of values.", nb::arg(
"op"),
259 .def(
"erase_op", &PyPatternRewriter::eraseOp,
"Erase an operation.",
269 nb::class_<PyRewritePatternSet>(m,
"RewritePatternSet")
273 new (&self) PyRewritePatternSet(context.
get()->
get());
275 "context"_a = nb::none())
278 [](PyRewritePatternSet &self, nb::handle root,
const nb::callable &fn,
281 nb::cast<std::string>(root.attr(
"OPERATION_NAME"));
285 "root"_a,
"fn"_a,
"benefit"_a = 1,
286 "Add a new rewrite pattern on the given root operation with the "
287 "callable as the matching and rewriting function and the given "
289 .def(
"freeze", &PyRewritePatternSet::freeze,
290 "Freeze the pattern set into a frozen one.");
295#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
296 nb::class_<MlirPDLResultList>(m,
"PDLResultList")
299 [](MlirPDLResultList results,
const PyValue &value) {
300 mlirPDLResultListPushBackValue(results, value);
308 [](MlirPDLResultList results,
const PyOperation &op) {
309 mlirPDLResultListPushBackOperation(results, op);
317 [](MlirPDLResultList results,
const PyType &type) {
318 mlirPDLResultListPushBackType(results, type);
326 [](MlirPDLResultList results,
const PyAttribute &attr) {
327 mlirPDLResultListPushBackAttribute(results, attr);
333 nb::class_<PyPDLPatternModule>(m,
"PDLModule")
336 [](PyPDLPatternModule &self, MlirModule module) {
338 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
343 "module"_a,
"Create a PDL module from the given module.")
346 [](PyPDLPatternModule &self,
PyModule &module) {
347 new (&self) PyPDLPatternModule(
348 mlirPDLPatternModuleFromModule(module.
get()));
353 "module"_a,
"Create a PDL module from the given module.")
356 [](PyPDLPatternModule &self) {
358 mlirRewritePatternSetFromPDLPatternModule(self.get())));
360 nb::keep_alive<0, 1>())
362 "register_rewrite_function",
363 [](PyPDLPatternModule &self,
const std::string &name,
364 const nb::callable &fn) {
365 self.registerRewriteFunction(name, fn);
367 nb::keep_alive<1, 3>())
369 "register_constraint_function",
370 [](PyPDLPatternModule &self,
const std::string &name,
371 const nb::callable &fn) {
372 self.registerConstraintFunction(name, fn);
374 nb::keep_alive<1, 3>());
376 nb::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet")
378 &PyFrozenRewritePatternSet::getCapsule)
380 &PyFrozenRewritePatternSet::createFromCapsule);
382 "apply_patterns_and_fold_greedily",
383 [](
PyModule &module, PyFrozenRewritePatternSet &set) {
387 throw std::runtime_error(
"pattern application failed to converge");
391 nb::sig(
"def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME(
"ir.Module")
", set: FrozenRewritePatternSet) -> None"),
393 "Applys the given patterns to the given module greedily while folding "
396 "apply_patterns_and_fold_greedily",
397 [](
PyModule &module, MlirFrozenRewritePatternSet set) {
401 throw std::runtime_error(
402 "pattern application failed to converge");
406 nb::sig(
"def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME(
"ir.Module")
", set: FrozenRewritePatternSet) -> None"),
408 "Applys the given patterns to the given module greedily while "
412 "apply_patterns_and_fold_greedily",
417 throw std::runtime_error(
418 "pattern application failed to converge");
422 nb::sig(
"def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME(
"ir._OperationBase")
", set: FrozenRewritePatternSet) -> None"),
424 "Applys the given patterns to the given op greedily while folding "
427 "apply_patterns_and_fold_greedily",
432 throw std::runtime_error(
433 "pattern application failed to converge");
437 nb::sig(
"def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME(
"ir._OperationBase")
", set: FrozenRewritePatternSet) -> None"),
439 "Applys the given patterns to the given op greedily while folding "
MlirContext mlirOperationGetContext(MlirOperation op)
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Used in function arguments when None should resolve to the current context manager set instance.
ReferrentTy * get() const
Wrapper around the generic MlirAttribute.
MlirContext get()
Accesses the underlying MlirContext.
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
MlirModule get()
Gets the backing MlirModule.
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Wrapper around the generic MlirType.
Wrapper around the generic MlirValue.
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements...
MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter)
Returns the operation right after the current insertion point of the rewriter.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
MLIR_CAPI_EXPORTED MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
RewriterBase API inherited from OpBuilder.
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
struct MlirLogicalResult MlirLogicalResult
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
struct MlirStringRef MlirStringRef
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
void populateRewriteSubmodule(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrit...
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.