19#include "mlir/Config/mlir-config.h"
20#include "nanobind/nanobind.h"
24using namespace nb::literals;
41 if (mlirOperationIsNull(op)) {
50 void replaceOp(MlirOperation op, MlirOperation newOp) {
54 void replaceOp(MlirOperation op,
const std::vector<MlirValue> &values) {
61 MlirRewriterBase base;
67#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
68static nb::object objectFromPDLValue(MlirPDLValue value) {
69 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
71 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
73 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
78 throw std::runtime_error(
"unsupported PDL value type");
81static std::vector<nb::object> objectsFromPDLValues(
size_t nValues,
82 MlirPDLValue *values) {
83 std::vector<nb::object> args;
84 args.reserve(nValues);
85 for (
size_t i = 0; i < nValues; ++i)
86 args.push_back(objectFromPDLValue(values[i]));
103class PyPDLPatternModule {
105 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
106 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
107 : module(other.module) {
108 other.module.ptr =
nullptr;
110 ~PyPDLPatternModule() {
111 if (module.ptr !=
nullptr)
112 mlirPDLPatternModuleDestroy(module);
114 MlirPDLPatternModule
get() {
return module; }
116 void registerRewriteFunction(
const std::string &name,
117 const nb::callable &fn) {
118 mlirPDLPatternModuleRegisterRewriteFunction(
120 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
121 size_t nValues, MlirPDLValue *values,
123 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
124 return logicalResultFromObject(
125 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
126 objectsFromPDLValues(nValues, values)));
131 void registerConstraintFunction(
const std::string &name,
132 const nb::callable &fn) {
133 mlirPDLPatternModuleRegisterConstraintFunction(
135 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
136 size_t nValues, MlirPDLValue *values,
138 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
139 return logicalResultFromObject(
140 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
141 objectsFromPDLValues(nValues, values)));
147 MlirPDLPatternModule module;
157 other.set.ptr =
nullptr;
160 if (set.ptr !=
nullptr)
163 MlirFrozenRewritePatternSet
get() {
return set; }
166 return nb::steal<nb::object>(
171 MlirFrozenRewritePatternSet rawPm =
173 if (rawPm.ptr ==
nullptr)
174 throw nb::python_error();
179 MlirFrozenRewritePatternSet set;
192 const nb::callable &matchAndRewrite) {
194 callbacks.
construct = [](
void *userData) {
195 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
197 callbacks.
destruct = [](
void *userData) {
198 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
201 MlirPatternRewriter rewriter,
203 nb::handle f(
static_cast<PyObject *
>(userData));
210 return logicalResultFromObject(res);
213 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
220 MlirRewritePatternSet s = set;
226 MlirRewritePatternSet set;
251 :
config(std::move(other.config)) {}
255 MlirGreedyRewriteDriverConfig
get() {
256 return MlirGreedyRewriteDriverConfig{config.get()};
269 useTopDownTraversal);
321 std::shared_ptr<void>
config;
322 static void customDeleter(
void *c) {
330 nb::enum_<PyGreedyRewriteStrictness>(m,
"GreedyRewriteStrictness")
332 .value(
"EXISTING_AND_NEW_OPS",
336 nb::enum_<PyGreedySimplifyRegionLevel>(m,
"GreedySimplifyRegionLevel")
343 nb::class_<PyPatternRewriter>(m,
"PatternRewriter")
345 "The current insertion point of the PatternRewriter.")
352 "Replace an operation with a new operation.", nb::arg(
"op"),
357 const std::vector<PyValue> &values) {
358 std::vector<MlirValue> values_(values.size());
359 std::copy(values.begin(), values.end(), values_.begin());
362 "Replace an operation with a list of values.", nb::arg(
"op"),
370 nb::class_<PyRewritePatternSet>(m,
"RewritePatternSet")
376 "context"_a = nb::none())
382 if (root.is_type()) {
383 opName = nb::cast<std::string>(root.attr(
"OPERATION_NAME"));
384 }
else if (nb::isinstance<nb::str>(root)) {
385 opName = nb::cast<std::string>(root);
387 throw nb::type_error(
388 "the root argument must be a type or a string");
393 "root"_a,
"fn"_a,
"benefit"_a = 1,
395 nb::sig(
"def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME(
"ir.Operation")
", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
398 Add a new rewrite pattern on the specified root operation, using the provided callable
399 for matching and rewriting, and assign it the given benefit.
402 root: The root operation to which this pattern applies.
403 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
404 an operation name string (e.g., ``"arith.addi"``).
405 fn: The callable to use for matching and rewriting,
406 which takes an operation and a pattern rewriter as arguments.
407 The match is considered successful iff the callable returns
408 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
409 If possible, the operation is cast to its corresponding OpView subclass
410 before being passed to the callable.
411 benefit: The benefit of the pattern, defaulting to 1.)")
413 "Freeze the pattern set into a frozen one.");
418#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
419 nb::class_<PyMlirPDLResultList>(m,
"PDLResultList")
422 mlirPDLResultListPushBackValue(results, value);
426 mlirPDLResultListPushBackOperation(results, op);
430 mlirPDLResultListPushBackType(results, type);
433 mlirPDLResultListPushBackAttribute(results, attr);
435 nb::class_<PyPDLPatternModule>(m,
"PDLModule")
438 [](PyPDLPatternModule &self,
PyModule &module) {
439 new (&self) PyPDLPatternModule(
440 mlirPDLPatternModuleFromModule(module.
get()));
442 "module"_a,
"Create a PDL module from the given module.")
445 [](PyPDLPatternModule &self,
PyModule &module) {
446 new (&self) PyPDLPatternModule(
447 mlirPDLPatternModuleFromModule(module.
get()));
449 "module"_a,
"Create a PDL module from the given module.")
452 [](PyPDLPatternModule &self) {
454 mlirRewritePatternSetFromPDLPatternModule(self.get())));
456 nb::keep_alive<0, 1>())
458 "register_rewrite_function",
459 [](PyPDLPatternModule &self,
const std::string &name,
460 const nb::callable &fn) {
461 self.registerRewriteFunction(name, fn);
463 nb::keep_alive<1, 3>())
465 "register_constraint_function",
466 [](PyPDLPatternModule &self,
const std::string &name,
467 const nb::callable &fn) {
468 self.registerConstraintFunction(name, fn);
470 nb::keep_alive<1, 3>());
473 nb::class_<PyGreedyRewriteConfig>(m,
"GreedyRewriteConfig")
474 .def(nb::init<>(),
"Create a greedy rewrite driver config with defaults")
477 "Maximum number of iterations")
478 .def_prop_rw(
"max_num_rewrites",
481 "Maximum number of rewrites per iteration")
482 .def_prop_rw(
"use_top_down_traversal",
485 "Whether to use top-down traversal")
488 "Enable or disable folding")
491 "Rewrite strictness level")
492 .def_prop_rw(
"region_simplification_level",
495 "Region simplification level")
496 .def_prop_rw(
"enable_constant_cse",
499 "Enable or disable constant CSE");
501 nb::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet")
507 "apply_patterns_and_fold_greedily",
509 std::optional<PyGreedyRewriteConfig>
config) {
515 throw std::runtime_error(
"pattern application failed to converge");
517 "module"_a,
"set"_a,
"config"_a = nb::none(),
518 "Applys the given patterns to the given module greedily while folding "
521 "apply_patterns_and_fold_greedily",
523 std::optional<PyGreedyRewriteConfig>
config) {
529 throw std::runtime_error(
530 "pattern application failed to converge");
532 "op"_a,
"set"_a,
"config"_a = nb::none(),
533 "Applys the given patterns to the given op greedily while folding "
536 "walk_and_apply_patterns",
541 "Applies the given patterns to the given op by a fast walk-based "
MlirContext mlirOperationGetContext(MlirOperation op)
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatternSetToCapsule.
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Wrapper around the generic MlirAttribute.
Wrapper around an MlirBlock.
Owning Wrapper around a FrozenRewritePatternSet.
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set)
MlirFrozenRewritePatternSet get()
static nb::object createFromCapsule(const nb::object &capsule)
~PyFrozenRewritePatternSet()
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
int64_t getMaxIterations()
MlirGreedyRewriteDriverConfig get()
PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
void setUseTopDownTraversal(bool useTopDownTraversal)
int64_t getMaxNumRewrites()
PyGreedySimplifyRegionLevel getRegionSimplificationLevel()
void enableConstantCSE(bool enable)
void setMaxIterations(int64_t maxIterations)
PyGreedyRewriteStrictness getStrictness()
void enableFolding(bool enable)
void setMaxNumRewrites(int64_t maxNumRewrites)
bool getUseTopDownTraversal()
bool isConstantCSEEnabled()
void setStrictness(PyGreedyRewriteStrictness strictness)
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level)
PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
An insertion point maintains a pointer to a Block and a reference operation.
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
MlirContext get()
Accesses the underlying MlirContext.
MlirModule get()
Gets the backing MlirModule.
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
void replaceOp(MlirOperation op, const std::vector< MlirValue > &values)
void replaceOp(MlirOperation op, MlirOperation newOp)
void eraseOp(const PyOperation &op)
PyPatternRewriter(MlirPatternRewriter rewriter)
PyInsertionPoint getInsertionPoint() const
void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite)
PyFrozenRewritePatternSet freeze()
PyRewritePatternSet(MlirContext ctx)
Wrapper around the generic MlirType.
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements...
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter)
Returns the operation right after the current insertion point of the rewriter.
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
MLIR_CAPI_EXPORTED MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
RewriterBase API inherited from OpBuilder.
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g. dead args elimination).
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g. block merging).
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions wrt. which ops are processed.
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
struct MlirLogicalResult MlirLogicalResult
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
PyGreedySimplifyRegionLevel
void populateRewriteSubmodule(nb::module_ &m)
Create the mlir.rewrite here.
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
PyGreedyRewriteStrictness
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
A logical result value, essentially a boolean with named states.
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrit...
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A pointer to a sized fragment of a string, not necessarily null-terminated.