17#include "mlir/Config/mlir-config.h"
18#include "nanobind/nanobind.h"
23using namespace nb::literals;
71 MlirConversionTarget
get() {
return target; }
74 MlirConversionTarget
target;
81 : typeConverter(typeConverter), owner(
false) {}
90 [](MlirType type, MlirType *converted,
92 nb::handle f = nb::handle(
static_cast<PyObject *
>(userData));
94 nb::object res = f(
PyType(ctx, type).maybeDownCast());
98 *converted = nb::cast<PyType>(res).get();
104 MlirTypeConverter
get() {
return typeConverter; }
107 MlirTypeConverter typeConverter;
120 MlirConversionPattern pattern;
123#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
124struct PyMlirPDLResultList : MlirPDLResultList {};
126static nb::object objectFromPDLValue(MlirPDLValue value) {
127 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
129 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
131 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
136 throw std::runtime_error(
"unsupported PDL value type");
139static std::vector<nb::object> objectsFromPDLValues(
size_t nValues,
140 MlirPDLValue *values) {
141 std::vector<nb::object> args;
142 args.reserve(nValues);
143 for (
size_t i = 0; i < nValues; ++i)
144 args.push_back(objectFromPDLValue(values[i]));
161class PyPDLPatternModule {
163 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
164 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
165 : module(other.module) {
166 other.module.ptr =
nullptr;
168 ~PyPDLPatternModule() {
169 if (module.ptr !=
nullptr)
170 mlirPDLPatternModuleDestroy(module);
172 MlirPDLPatternModule
get() {
return module; }
174 void registerRewriteFunction(
const std::string &name,
175 const nb::callable &fn) {
176 mlirPDLPatternModuleRegisterRewriteFunction(
178 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
179 size_t nValues, MlirPDLValue *values,
181 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
182 return logicalResultFromObject(
183 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
184 objectsFromPDLValues(nValues, values)));
189 void registerConstraintFunction(
const std::string &name,
190 const nb::callable &fn) {
191 mlirPDLPatternModuleRegisterConstraintFunction(
193 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
194 size_t nValues, MlirPDLValue *values,
196 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
197 return logicalResultFromObject(
198 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
199 objectsFromPDLValues(nValues, values)));
205 MlirPDLPatternModule module;
215 other.set.ptr =
nullptr;
218 if (set.ptr !=
nullptr)
221 MlirFrozenRewritePatternSet
get() {
return set; }
224 return nb::steal<nb::object>(
229 MlirFrozenRewritePatternSet rawPm =
231 if (rawPm.ptr ==
nullptr)
232 throw nb::python_error();
237 MlirFrozenRewritePatternSet set;
250 const nb::callable &matchAndRewrite) {
252 callbacks.
construct = [](
void *userData) {
253 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
255 callbacks.
destruct = [](
void *userData) {
256 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
259 MlirPatternRewriter rewriter,
261 nb::handle f(
static_cast<PyObject *
>(userData));
268 return logicalResultFromObject(res);
271 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
278 const nb::callable &matchAndRewrite,
281 callbacks.
construct = [](
void *userData) {
282 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
284 callbacks.
destruct = [](
void *userData) {
285 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
288 [](MlirConversionPattern pattern, MlirOperation op,
intptr_t nOperands,
289 MlirValue *operands, MlirConversionPatternRewriter rewriter,
291 nb::handle f(
static_cast<PyObject *
>(userData));
297 std::vector<MlirValue> operandsVec(operands, operands + nOperands);
298 nb::object adaptorCls =
302 .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
304 nb::object res = f(opView, adaptorCls(operandsVec, opView),
307 return logicalResultFromObject(res);
310 rootName, benefit, ctx, typeConverter.
get(), callbacks,
311 matchAndRewrite.ptr(),
319 MlirRewritePatternSet s = set;
325 MlirRewritePatternSet set;
350 :
config(std::move(other.config)) {}
354 MlirGreedyRewriteDriverConfig
get() {
355 return MlirGreedyRewriteDriverConfig{config.get()};
368 useTopDownTraversal);
420 std::shared_ptr<void>
config;
421 static void customDeleter(
void *c) {
439 MlirConversionConfig
get() {
return MlirConversionConfig{config.get()}; }
460 std::shared_ptr<void>
config;
461 static void customDeleter(
void *c) {
469 nb::enum_<PyGreedyRewriteStrictness>(m,
"GreedyRewriteStrictness")
471 .value(
"EXISTING_AND_NEW_OPS",
475 nb::enum_<PyGreedySimplifyRegionLevel>(m,
"GreedySimplifyRegionLevel")
480 nb::enum_<PyDialectConversionFoldingMode>(m,
"DialectConversionFoldingMode")
494 nb::class_<PyRewritePatternSet>(m,
"RewritePatternSet")
500 "context"_a = nb::none())
506 if (root.is_type()) {
507 opName = nb::cast<std::string>(root.attr(
"OPERATION_NAME"));
508 }
else if (nb::isinstance<nb::str>(root)) {
509 opName = nb::cast<std::string>(root);
511 throw nb::type_error(
512 "the root argument must be a type or a string");
517 "root"_a,
"fn"_a,
"benefit"_a = 1,
519 nb::sig(
"def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME(
"ir.Operation")
", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
522 Add a new rewrite pattern on the specified root operation, using the provided callable
523 for matching and rewriting, and assign it the given benefit.
526 root: The root operation to which this pattern applies.
527 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
528 an operation name string (e.g., ``"arith.addi"``).
529 fn: The callable to use for matching and rewriting,
530 which takes an operation and a pattern rewriter as arguments.
531 The match is considered successful iff the callable returns
532 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
533 If possible, the operation is cast to its corresponding OpView subclass
534 before being passed to the callable.
535 benefit: The benefit of the pattern, defaulting to 1.)")
541 nb::cast<std::string>(root.attr(
"OPERATION_NAME"));
546 "root"_a,
"fn"_a,
"type_converter"_a,
"benefit"_a = 1,
548 Add a new conversion pattern on the specified root operation,
549 using the provided callable for matching and rewriting,
550 and assign it the given benefit.
553 root: The root operation to which this pattern applies.
554 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
555 an operation name string (e.g., ``"arith.addi"``).
556 fn: The callable to use for matching and rewriting,
557 which takes an operation, its adaptor,
558 the type converter and a pattern rewriter as arguments.
559 The match is considered successful iff the callable returns
560 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
561 If possible, the operation is cast to its corresponding OpView subclass
562 before being passed to the callable.
563 type_converter: The type converter to convert types in the IR.
564 benefit: The benefit of the pattern, defaulting to 1.)")
566 "Freeze the pattern set into a frozen one.");
568 nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
569 m,
"ConversionPatternRewriter");
571 nb::class_<PyConversionTarget>(m,
"ConversionTarget")
577 "context"_a = nb::none())
581 for (
auto op : ops) {
583 nb::cast<std::string>(op.attr(
"OPERATION_NAME"));
584 self.addLegalOp(opName);
587 "ops"_a,
"Mark the given operations as legal.")
591 for (
auto op : ops) {
593 nb::cast<std::string>(op.attr(
"OPERATION_NAME"));
594 self.addIllegalOp(opName);
597 "ops"_a,
"Mark the given operations as illegal.")
601 for (
auto dialect : dialects) {
602 std::string dialectName =
603 nb::cast<std::string>(dialect.attr(
"DIALECT_NAMESPACE"));
604 self.addLegalDialect(dialectName);
607 "dialects"_a,
"Mark the given dialects as legal.")
609 "add_illegal_dialect",
611 for (
auto dialect : dialects) {
612 std::string dialectName =
613 nb::cast<std::string>(dialect.attr(
"DIALECT_NAMESPACE"));
614 self.addIllegalDialect(dialectName);
617 "dialects"_a,
"Mark the given dialect as illegal.");
619 nb::class_<PyTypeConverter>(m,
"TypeConverter")
620 .def(nb::init<>(),
"Create a new TypeConverter.")
622 nb::keep_alive<0, 1>(),
"Register a type conversion function.");
627#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
628 nb::class_<PyMlirPDLResultList>(m,
"PDLResultList")
630 [](PyMlirPDLResultList results,
const PyValue &value) {
631 mlirPDLResultListPushBackValue(results, value);
634 [](PyMlirPDLResultList results,
const PyOperation &op) {
635 mlirPDLResultListPushBackOperation(results, op);
638 [](PyMlirPDLResultList results,
const PyType &type) {
639 mlirPDLResultListPushBackType(results, type);
641 .def(
"append", [](PyMlirPDLResultList results,
const PyAttribute &attr) {
642 mlirPDLResultListPushBackAttribute(results, attr);
644 nb::class_<PyPDLPatternModule>(m,
"PDLModule")
647 [](PyPDLPatternModule &self,
PyModule &module) {
648 new (&self) PyPDLPatternModule(
649 mlirPDLPatternModuleFromModule(module.
get()));
651 "module"_a,
"Create a PDL module from the given module.")
654 [](PyPDLPatternModule &self,
PyModule &module) {
655 new (&self) PyPDLPatternModule(
656 mlirPDLPatternModuleFromModule(module.
get()));
658 "module"_a,
"Create a PDL module from the given module.")
661 [](PyPDLPatternModule &self) {
663 mlirRewritePatternSetFromPDLPatternModule(self.get())));
665 nb::keep_alive<0, 1>())
667 "register_rewrite_function",
668 [](PyPDLPatternModule &self,
const std::string &name,
669 const nb::callable &fn) {
670 self.registerRewriteFunction(name, fn);
672 nb::keep_alive<1, 3>())
674 "register_constraint_function",
675 [](PyPDLPatternModule &self,
const std::string &name,
676 const nb::callable &fn) {
677 self.registerConstraintFunction(name, fn);
679 nb::keep_alive<1, 3>());
682 nb::class_<PyGreedyRewriteConfig>(m,
"GreedyRewriteConfig")
683 .def(nb::init<>(),
"Create a greedy rewrite driver config with defaults")
686 "Maximum number of iterations")
687 .def_prop_rw(
"max_num_rewrites",
690 "Maximum number of rewrites per iteration")
691 .def_prop_rw(
"use_top_down_traversal",
694 "Whether to use top-down traversal")
697 "Enable or disable folding")
700 "Rewrite strictness level")
701 .def_prop_rw(
"region_simplification_level",
704 "Region simplification level")
705 .def_prop_rw(
"enable_constant_cse",
708 "Enable or disable constant CSE");
710 nb::class_<PyConversionConfig>(m,
"ConversionConfig")
711 .def(nb::init<>(),
"Create a conversion config with defaults")
714 "folding behavior during dialect conversion")
715 .def_prop_rw(
"build_materializations",
718 "Whether the dialect conversion attempts to build "
719 "source/target materializations");
721 nb::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet")
727 "apply_patterns_and_fold_greedily",
729 std::optional<PyGreedyRewriteConfig>
config) {
735 throw std::runtime_error(
"pattern application failed to converge");
737 "module"_a,
"set"_a,
"config"_a = nb::none(),
738 "Applys the given patterns to the given module greedily while folding "
741 "apply_patterns_and_fold_greedily",
743 std::optional<PyGreedyRewriteConfig>
config) {
749 throw std::runtime_error(
750 "pattern application failed to converge");
752 "op"_a,
"set"_a,
"config"_a = nb::none(),
753 "Applys the given patterns to the given op greedily while folding "
756 "walk_and_apply_patterns",
761 "Applies the given patterns to the given op by a fast walk-based "
764 "apply_partial_conversion",
767 std::optional<PyConversionConfig>
config) {
773 throw std::runtime_error(
"partial conversion failed");
775 "op"_a,
"target"_a,
"set"_a,
"config"_a = nb::none(),
776 "Applies a partial conversion on the given operation.")
778 "apply_full_conversion",
781 std::optional<PyConversionConfig>
config) {
787 throw std::runtime_error(
"full conversion failed");
789 "op"_a,
"target"_a,
"set"_a,
"config"_a = nb::none(),
790 "Applies a full conversion on the given operation.");
true
Given two iterators into the same block, return "true" if a is before `b.
MlirIdentifier mlirOperationGetName(MlirOperation op)
MlirContext mlirOperationGetContext(MlirOperation op)
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
false
Parses a map_entries map type from a string format back into its numeric value.
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Wrapper around the generic MlirAttribute.
void enableBuildMaterializations(bool enabled)
PyDialectConversionFoldingMode getFoldingMode()
void setFoldingMode(PyDialectConversionFoldingMode mode)
MlirConversionConfig get()
bool isBuildMaterializationsEnabled()
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
PyConversionPattern(MlirConversionPattern pattern)
PyTypeConverter getTypeConverter()
void addLegalOp(const std::string &opName)
void addIllegalDialect(const std::string &dialectName)
void addIllegalOp(const std::string &opName)
MlirConversionTarget get()
void addLegalDialect(const std::string &dialectName)
PyConversionTarget(MlirContext context)
Owning Wrapper around a FrozenRewritePatternSet.
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set)
MlirFrozenRewritePatternSet get()
static nb::object createFromCapsule(const nb::object &capsule)
~PyFrozenRewritePatternSet()
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
std::optional< nanobind::object > lookupOpAdaptorClass(std::string_view operationName)
Looks up a registered operation adaptor class by operation name.
static PyGlobals & get()
Most code should get the globals via this static accessor.
int64_t getMaxIterations()
MlirGreedyRewriteDriverConfig get()
PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
void setUseTopDownTraversal(bool useTopDownTraversal)
int64_t getMaxNumRewrites()
PyGreedySimplifyRegionLevel getRegionSimplificationLevel()
void enableConstantCSE(bool enable)
void setMaxIterations(int64_t maxIterations)
PyGreedyRewriteStrictness getStrictness()
void enableFolding(bool enable)
void setMaxNumRewrites(int64_t maxNumRewrites)
bool getUseTopDownTraversal()
bool isConstantCSEEnabled()
void setStrictness(PyGreedyRewriteStrictness strictness)
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level)
PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
MlirContext get()
Accesses the underlying MlirContext.
MlirModule get()
Gets the backing MlirModule.
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
PyOperation & getOperation() override
Each must provide access to the raw Operation.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
static constexpr const char * pyClassName
PyPatternRewriter(MlirPatternRewriter rewriter)
void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite)
void addConversion(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite, PyTypeConverter &typeConverter)
PyFrozenRewritePatternSet freeze()
PyRewritePatternSet(MlirContext ctx)
static void bind(nanobind::module_ &m)
PyRewriterBase(MlirRewriterBase rewriter)
PyTypeConverter(MlirTypeConverter typeConverter)
void addConversion(const nb::callable &convert)
Wrapper around the generic MlirType.
MLIR_CAPI_EXPORTED void mlirConversionTargetDestroy(MlirConversionTarget target)
Destroy the given ConversionTarget.
MlirDialectConversionFoldingMode
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER
MLIR_CAPI_EXPORTED MlirRewritePattern mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern)
Cast the ConversionPattern to a RewritePattern.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a partial conversion on the given operation.
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
MLIR_CAPI_EXPORTED MlirDialectConversionFoldingMode mlirConversionConfigGetFoldingMode(MlirConversionConfig config)
Get the folding mode for the given ConversionConfig.
MLIR_CAPI_EXPORTED bool mlirConversionConfigIsBuildMaterializationsEnabled(MlirConversionConfig config)
Check if building materializations during conversion is enabled.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a full conversion on the given operation.
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
MLIR_CAPI_EXPORTED void mlirConversionConfigEnableBuildMaterializations(MlirConversionConfig config, bool enable)
Enable or disable building materializations during conversion.
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as illegal.
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate(void)
TypeConverter API.
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as legal.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirTypeConverter mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern)
Get the type converter used by this conversion pattern.
MLIR_CAPI_EXPORTED MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(MlirConversionPatternRewriter rewriter)
ConversionPatternRewriter API.
MLIR_CAPI_EXPORTED MlirConversionTarget mlirConversionTargetCreate(MlirContext context)
ConversionTarget API.
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as illegal.
MLIR_CAPI_EXPORTED void mlirConversionConfigSetFoldingMode(MlirConversionConfig config, MlirDialectConversionFoldingMode mode)
Set the folding mode for the given ConversionConfig.
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
MLIR_CAPI_EXPORTED void mlirTypeConverterAddConversion(MlirTypeConverter typeConverter, MlirTypeConverterConversionCallback convertType, void *userData)
Add a type conversion function to the given TypeConverter.
MLIR_CAPI_EXPORTED void mlirTypeConverterDestroy(MlirTypeConverter typeConverter)
Destroy the given TypeConverter.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
MLIR_CAPI_EXPORTED void mlirConversionConfigDestroy(MlirConversionConfig config)
Destroy the given ConversionConfig.
MLIR_CAPI_EXPORTED MlirConversionConfig mlirConversionConfigCreate(void)
ConversionConfig API.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g. dead args elimination).
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g. block merging).
MLIR_CAPI_EXPORTED MlirConversionPattern mlirOpConversionPatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a conversion pattern that matches the operation with the given rootName, corresponding to mlir...
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions wrt. which ops are processed.
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as legal.
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
struct MlirLogicalResult MlirLogicalResult
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
PyGreedySimplifyRegionLevel
void populateRewriteSubmodule(nb::module_ &m)
Create the mlir.rewrite here.
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
PyDialectConversionFoldingMode
PyGreedyRewriteStrictness
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
MlirLogicalResult(* matchAndRewrite)(MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, MlirValue *operands, MlirConversionPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the conver...
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A logical result value, essentially a boolean with named states.
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrit...
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A pointer to a sized fragment of a string, not necessarily null-terminated.