17#include "mlir/Config/mlir-config.h"
18#include "nanobind/nanobind.h"
23using namespace nb::literals;
44 return nb::cast<std::string>(root.attr(
"OPERATION_NAME"));
45 if (nb::isinstance<nb::str>(root))
46 return nb::cast<std::string>(root);
48 throw nb::type_error(
"the root argument must be a type or a string");
53 return nb::cast<std::string>(root.attr(
"DIALECT_NAMESPACE"));
54 if (nb::isinstance<nb::str>(root))
55 return nb::cast<std::string>(root);
57 throw nb::type_error(
"the root argument must be a type or a string");
76 : patterns(patterns), owned(
false) {}
79 if (owned && patterns.ptr)
88 const nb::callable &matchAndRewrite,
94 callbacks.
construct = [](
void *userData) {
95 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
97 callbacks.
destruct = [](
void *userData) {
98 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
101 MlirPatternRewriter rewriter,
103 nb::handle f(
static_cast<PyObject *
>(userData));
115 matchAndRewrite.ptr(),
161 MlirConversionTarget
get() {
return target; }
164 MlirConversionTarget
target;
171 : typeConverter(typeConverter), owner(
false) {}
180 [](MlirType type, MlirType *converted,
182 nb::handle f = nb::handle(
static_cast<PyObject *
>(userData));
184 nb::object res = f(
PyType(ctx, type).maybeDownCast());
188 *converted = nb::cast<PyType>(res).get();
203 MlirTypeConverter
get() {
return typeConverter; }
206 MlirTypeConverter typeConverter;
219 MlirConversionPattern pattern;
223 const nb::callable &matchAndRewrite,
230 callbacks.
construct = [](
void *userData) {
231 nb::handle(
static_cast<PyObject *
>(userData)).inc_ref();
233 callbacks.
destruct = [](
void *userData) {
234 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
237 [](MlirConversionPattern pattern, MlirOperation op,
intptr_t nOperands,
238 MlirValue *operands, MlirConversionPatternRewriter rewriter,
240 nb::handle f(
static_cast<PyObject *
>(userData));
246 std::vector<MlirValue> operandsVec(operands, operands + nOperands);
247 nb::object adaptorCls =
251 return std::string_view(ref.
data, ref.
length);
253 .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
255 nb::object res = f(opView, adaptorCls(operandsVec, opView),
262 typeConverter.
get(), callbacks, matchAndRewrite.ptr(),
269#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
270struct PyMlirPDLResultList : MlirPDLResultList {};
272static nb::object objectFromPDLValue(MlirPDLValue value) {
273 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
275 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
277 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
282 throw std::runtime_error(
"unsupported PDL value type");
285static std::vector<nb::object> objectsFromPDLValues(
size_t nValues,
286 MlirPDLValue *values) {
287 std::vector<nb::object> args;
288 args.reserve(nValues);
289 for (
size_t i = 0; i < nValues; ++i)
290 args.push_back(objectFromPDLValue(values[i]));
295class PyPDLPatternModule {
297 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
298 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
299 : module(other.module) {
300 other.module.ptr =
nullptr;
302 ~PyPDLPatternModule() {
303 if (module.ptr !=
nullptr)
304 mlirPDLPatternModuleDestroy(module);
306 MlirPDLPatternModule
get() {
return module; }
308 void registerRewriteFunction(
const std::string &name,
309 const nb::callable &fn) {
310 mlirPDLPatternModuleRegisterRewriteFunction(
312 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
313 size_t nValues, MlirPDLValue *values,
315 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
316 return logicalResultFromObject(
317 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
318 objectsFromPDLValues(nValues, values)));
323 void registerConstraintFunction(
const std::string &name,
324 const nb::callable &fn) {
325 mlirPDLPatternModuleRegisterConstraintFunction(
327 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
328 size_t nValues, MlirPDLValue *values,
330 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
331 return logicalResultFromObject(
332 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
333 objectsFromPDLValues(nValues, values)));
339 MlirPDLPatternModule module;
349 other.set.ptr =
nullptr;
352 if (set.ptr !=
nullptr)
355 MlirFrozenRewritePatternSet
get() {
return set; }
358 return nb::steal<nb::object>(
363 MlirFrozenRewritePatternSet rawPm =
365 if (rawPm.ptr ==
nullptr)
366 throw nb::python_error();
371 MlirFrozenRewritePatternSet set;
375 nb::class_<PyRewritePatternSet>(m,
"RewritePatternSet")
381 "context"_a = nb::none())
383 nb::arg(
"benefit") = 1,
384 R
"(Add a new rewrite pattern on the specified root operation, using
385 the provided callable for matching and rewriting, and assign it
389 root: The root operation to which this pattern applies. This may
390 be either an OpView subclass or an operation name.
391 fn: The callable to use for matching and rewriting, which takes
392 an operation and a pattern rewriter. The match is considered
393 successful iff the callable returns a falsy value.
394 benefit: The benefit of the pattern, defaulting to 1.)")
396 nb::arg(
"root"), nb::arg(
"fn"), nb::arg(
"type_converter"),
397 nb::arg(
"benefit") = 1,
399 Add a new conversion pattern on the specified root operation,
400 using the provided callable for matching and rewriting,
401 and assign it the given benefit.
404 root: The root operation to which this pattern applies.
405 This may be either an OpView subclass or an operation name.
406 fn: The callable to use for matching and rewriting, which takes an
407 operation, its adaptor, the type converter and a pattern
408 rewriter. The match is considered successful iff the callable
409 returns a falsy value.
410 type_converter: The type converter to convert types in the IR.
411 benefit: The benefit of the pattern, defaulting to 1.)")
416 throw std::runtime_error(
417 "cannot freeze a non-owning pattern set");
418 MlirRewritePatternSet s = self.
get();
421 "Freeze the pattern set into a frozen one.");
443 PyGreedyRewriteConfig::customDeleter) {}
445 : config(std::move(other.config)) {}
447 : config(other.config) {}
449 MlirGreedyRewriteDriverConfig
get() {
450 return MlirGreedyRewriteDriverConfig{config.get()};
463 useTopDownTraversal);
480 void enableConstantCSE(
bool enable) {
492 bool getUseTopDownTraversal() {
496 bool isFoldingEnabled() {
510 bool isConstantCSEEnabled() {
515 std::shared_ptr<void> config;
516 static void customDeleter(
void *c) {
532 PyConversionConfig::customDeleter) {}
534 MlirConversionConfig
get() {
return MlirConversionConfig{config.get()}; }
546 void enableBuildMaterializations(
bool enabled) {
550 bool isBuildMaterializationsEnabled() {
555 std::shared_ptr<void> config;
556 static void customDeleter(
void *c) {
564 nb::enum_<PyGreedyRewriteStrictness>(m,
"GreedyRewriteStrictness")
565 .value(
"ANY_OP", PyGreedyRewriteStrictness::ANY_OP)
566 .value(
"EXISTING_AND_NEW_OPS",
567 PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS)
568 .value(
"EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS);
570 nb::enum_<PyGreedySimplifyRegionLevel>(m,
"GreedySimplifyRegionLevel")
571 .value(
"DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
572 .value(
"NORMAL", PyGreedySimplifyRegionLevel::NORMAL)
573 .value(
"AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
575 nb::enum_<PyDialectConversionFoldingMode>(m,
"DialectConversionFoldingMode")
576 .value(
"NEVER", PyDialectConversionFoldingMode::Never)
577 .value(
"BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
578 .value(
"AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);
584 PyPatternRewriter::bind(m);
591 nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
592 m,
"ConversionPatternRewriter")
593 .def(
"convert_region_types",
594 [](PyConversionPatternRewriter &self, PyRegion ®ion,
595 PyTypeConverter &typeConverter) {
600 nb::class_<PyConversionTarget>(m,
"ConversionTarget")
603 [](PyConversionTarget &self, DefaultingPyMlirContext context) {
604 new (&self) PyConversionTarget(context.
get()->
get());
606 "context"_a = nb::none())
609 [](PyConversionTarget &self,
const nb::args &ops) {
610 for (
auto op : ops) {
614 "ops"_a,
"Mark the given operations as legal.")
617 [](PyConversionTarget &self,
const nb::args &ops) {
618 for (
auto op : ops) {
622 "ops"_a,
"Mark the given operations as illegal.")
625 [](PyConversionTarget &self,
const nb::args &dialects) {
626 for (
auto dialect : dialects) {
630 "dialects"_a,
"Mark the given dialects as legal.")
632 "add_illegal_dialect",
633 [](PyConversionTarget &self,
const nb::args &dialects) {
634 for (
auto dialect : dialects) {
638 "dialects"_a,
"Mark the given dialect as illegal.");
640 nb::class_<PyTypeConverter>(m,
"TypeConverter")
641 .def(nb::init<>(),
"Create a new TypeConverter.")
643 nb::keep_alive<0, 1>(),
"Register a type conversion function.")
645 "Convert the given type. Returns None if conversion fails.");
650#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
651 nb::class_<PyMlirPDLResultList>(m,
"PDLResultList")
653 [](PyMlirPDLResultList results,
const PyValue &value) {
654 mlirPDLResultListPushBackValue(results, value);
657 [](PyMlirPDLResultList results,
const PyOperation &op) {
658 mlirPDLResultListPushBackOperation(results, op);
661 [](PyMlirPDLResultList results,
const PyType &type) {
662 mlirPDLResultListPushBackType(results, type);
664 .def(
"append", [](PyMlirPDLResultList results,
const PyAttribute &attr) {
665 mlirPDLResultListPushBackAttribute(results, attr);
667 nb::class_<PyPDLPatternModule>(m,
"PDLModule")
670 [](PyPDLPatternModule &self, PyModule &module) {
671 new (&self) PyPDLPatternModule(
672 mlirPDLPatternModuleFromModule(module.
get()));
674 "module"_a,
"Create a PDL module from the given module.")
677 [](PyPDLPatternModule &self, PyModule &module) {
678 new (&self) PyPDLPatternModule(
679 mlirPDLPatternModuleFromModule(module.
get()));
681 "module"_a,
"Create a PDL module from the given module.")
684 [](PyPDLPatternModule &self) {
686 mlirRewritePatternSetFromPDLPatternModule(self.get())));
688 nb::keep_alive<0, 1>())
690 "register_rewrite_function",
691 [](PyPDLPatternModule &self,
const std::string &name,
692 const nb::callable &fn) {
693 self.registerRewriteFunction(name, fn);
695 nb::keep_alive<1, 3>())
697 "register_constraint_function",
698 [](PyPDLPatternModule &self,
const std::string &name,
699 const nb::callable &fn) {
700 self.registerConstraintFunction(name, fn);
702 nb::keep_alive<1, 3>());
705 nb::class_<PyGreedyRewriteConfig>(m,
"GreedyRewriteConfig")
706 .def(nb::init<>(),
"Create a greedy rewrite driver config with defaults")
709 "Maximum number of iterations")
710 .def_prop_rw(
"max_num_rewrites",
713 "Maximum number of rewrites per iteration")
714 .def_prop_rw(
"use_top_down_traversal",
717 "Whether to use top-down traversal")
720 "Enable or disable folding")
723 "Rewrite strictness level")
724 .def_prop_rw(
"region_simplification_level",
727 "Region simplification level")
728 .def_prop_rw(
"enable_constant_cse",
731 "Enable or disable constant CSE");
733 nb::class_<PyConversionConfig>(m,
"ConversionConfig")
734 .def(nb::init<>(),
"Create a conversion config with defaults")
737 "folding behavior during dialect conversion")
738 .def_prop_rw(
"build_materializations",
741 "Whether the dialect conversion attempts to build "
742 "source/target materializations");
744 nb::class_<PyFrozenRewritePatternSet>(m,
"FrozenRewritePatternSet")
750 "apply_patterns_and_fold_greedily",
751 [](PyModule &module, PyFrozenRewritePatternSet &set,
752 std::optional<PyGreedyRewriteConfig> config) {
755 config.has_value() ? config->get()
758 throw std::runtime_error(
"pattern application failed to converge");
760 "module"_a,
"set"_a,
"config"_a = nb::none(),
761 "Applys the given patterns to the given module greedily while folding "
764 "apply_patterns_and_fold_greedily",
765 [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
766 std::optional<PyGreedyRewriteConfig> config) {
769 config.has_value() ? config->get()
772 throw std::runtime_error(
773 "pattern application failed to converge");
775 "op"_a,
"set"_a,
"config"_a = nb::none(),
776 "Applys the given patterns to the given op greedily while folding "
779 "walk_and_apply_patterns",
780 [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
784 "Applies the given patterns to the given op by a fast walk-based "
787 "apply_partial_conversion",
788 [](PyOperationBase &op, PyConversionTarget &
target,
789 PyFrozenRewritePatternSet &set,
790 std::optional<PyConversionConfig> config) {
792 config.emplace(PyConversionConfig());
797 throw MLIRError(
"partial conversion failed", errors.take());
799 "op"_a,
"target"_a,
"set"_a,
"config"_a = nb::none(),
800 "Applies a partial conversion on the given operation.")
802 "apply_full_conversion",
803 [](PyOperationBase &op, PyConversionTarget &
target,
804 PyFrozenRewritePatternSet &set,
805 std::optional<PyConversionConfig> config) {
807 config.emplace(PyConversionConfig());
812 throw MLIRError(
"full conversion failed", errors.take());
814 "op"_a,
"target"_a,
"set"_a,
"config"_a = nb::none(),
815 "Applies a full conversion on the given operation.");
true
Given two iterators into the same block, return "true" if `a` is before `b`.
MlirIdentifier mlirOperationGetName(MlirOperation op)
MlirContext mlirOperationGetContext(MlirOperation op)
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
false
Parses a map_entries map type from a string format back into its numeric value.
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Used in function arguments when None should resolve to the context instance set by the current context manager.
void enableBuildMaterializations(bool enabled)
PyDialectConversionFoldingMode getFoldingMode()
void setFoldingMode(PyDialectConversionFoldingMode mode)
MlirConversionConfig get()
bool isBuildMaterializationsEnabled()
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
MlirConversionPatternRewriter rewriter
PyConversionPattern(MlirConversionPattern pattern)
PyTypeConverter getTypeConverter()
void addLegalOp(const std::string &opName)
void addIllegalDialect(const std::string &dialectName)
void addIllegalOp(const std::string &opName)
MlirConversionTarget get()
void addLegalDialect(const std::string &dialectName)
PyConversionTarget(MlirContext context)
Owning wrapper around a FrozenRewritePatternSet.
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set)
MlirFrozenRewritePatternSet get()
static nb::object createFromCapsule(const nb::object &capsule)
~PyFrozenRewritePatternSet()
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
std::optional< nanobind::object > lookupOpAdaptorClass(std::string_view operationName)
Looks up a registered operation adaptor class by operation name.
static PyGlobals & get()
Most code should get the globals via this static accessor.
Owning wrapper around a GreedyRewriteDriverConfig.
int64_t getMaxIterations()
void setUseTopDownTraversal(bool useTopDownTraversal)
int64_t getMaxNumRewrites()
PyGreedySimplifyRegionLevel getRegionSimplificationLevel()
void enableConstantCSE(bool enable)
void setMaxIterations(int64_t maxIterations)
PyGreedyRewriteStrictness getStrictness()
void enableFolding(bool enable)
void setMaxNumRewrites(int64_t maxNumRewrites)
bool getUseTopDownTraversal()
bool isConstantCSEEnabled()
void setStrictness(PyGreedyRewriteStrictness strictness)
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level)
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
MlirContext get()
Accesses the underlying MlirContext.
MlirModule get()
Gets the backing MlirModule.
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
MlirOperation get() const
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
static constexpr const char * pyClassName
PyPatternRewriter(MlirPatternRewriter rewriter)
void add(nanobind::handle root, const nanobind::callable &matchAndRewrite, unsigned benefit)
Add a new rewrite pattern to the pattern set.
static void bind(nanobind::module_ &m)
void addConversion(nanobind::handle root, const nanobind::callable &matchAndRewrite, PyTypeConverter &typeConverter, unsigned benefit)
Add a new conversion pattern to the pattern set.
PyRewritePatternSet(MlirContext ctx)
Create an owned pattern set.
MlirRewritePatternSet get() const
PyRewriterBase(MlirRewriterBase rewriter)
nb::typed< nb::object, std::optional< PyType > > convertType(PyType &type)
PyTypeConverter(MlirTypeConverter typeConverter)
void addConversion(const nb::callable &convert)
Wrapper around the generic MlirType.
nanobind::typed< nanobind::object, PyType > maybeDownCast()
MLIR_CAPI_EXPORTED void mlirConversionTargetDestroy(MlirConversionTarget target)
Destroy the given ConversionTarget.
MLIR_CAPI_EXPORTED MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set)
Get the context associated with a MlirRewritePatternSet.
MlirDialectConversionFoldingMode
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER
MLIR_CAPI_EXPORTED MlirRewritePattern mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern)
Cast the ConversionPattern to a RewritePattern.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a partial conversion on the given operation.
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
MLIR_CAPI_EXPORTED MlirDialectConversionFoldingMode mlirConversionConfigGetFoldingMode(MlirConversionConfig config)
Get the folding mode for the given ConversionConfig.
MLIR_CAPI_EXPORTED bool mlirConversionConfigIsBuildMaterializationsEnabled(MlirConversionConfig config)
Check if building materializations during conversion is enabled.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a full conversion on the given operation.
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
MLIR_CAPI_EXPORTED void mlirConversionConfigEnableBuildMaterializations(MlirConversionConfig config, bool enable)
Enable or disable building materializations during conversion.
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as illegal.
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate(void)
TypeConverter API.
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as legal.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirTypeConverter mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern)
Get the type converter used by this conversion pattern.
MLIR_CAPI_EXPORTED MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(MlirConversionPatternRewriter rewriter)
ConversionPatternRewriter API.
MLIR_CAPI_EXPORTED MlirConversionTarget mlirConversionTargetCreate(MlirContext context)
ConversionTarget API.
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as illegal.
MLIR_CAPI_EXPORTED void mlirConversionConfigSetFoldingMode(MlirConversionConfig config, MlirDialectConversionFoldingMode mode)
Set the folding mode for the given ConversionConfig.
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
MLIR_CAPI_EXPORTED MlirType mlirTypeConverterConvertType(MlirTypeConverter typeConverter, MlirType type)
Convert the given type using the given TypeConverter.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
MLIR_CAPI_EXPORTED void mlirTypeConverterAddConversion(MlirTypeConverter typeConverter, MlirTypeConverterConversionCallback convertType, void *userData)
Add a type conversion function to the given TypeConverter.
MLIR_CAPI_EXPORTED void mlirTypeConverterDestroy(MlirTypeConverter typeConverter)
Destroy the given TypeConverter.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
MLIR_CAPI_EXPORTED void mlirConversionConfigDestroy(MlirConversionConfig config)
Destroy the given ConversionConfig.
MLIR_CAPI_EXPORTED MlirConversionConfig mlirConversionConfigCreate(void)
ConversionConfig API.
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g., dead-argument elimination).
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g., block merging).
MLIR_CAPI_EXPORTED MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes(MlirConversionPatternRewriter rewriter, MlirRegion region, MlirTypeConverter typeConverter)
Apply a signature conversion to each block in the given region.
MLIR_CAPI_EXPORTED MlirConversionPattern mlirOpConversionPatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a conversion pattern that matches the operation with the given rootName, corresponding to mlir...
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions on which ops are processed.
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as legal.
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
struct MlirLogicalResult MlirLogicalResult
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
PyGreedySimplifyRegionLevel
static std::string dialectNameFromObject(nb::handle root)
void populateRewriteSubmodule(nb::module_ &m)
Create the mlir.rewrite here.
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
PyDialectConversionFoldingMode
static std::string operationNameFromObject(nb::handle root)
static MlirLogicalResult logicalResultFromObject(const nb::object &obj)
PyGreedyRewriteStrictness
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
MlirLogicalResult(* matchAndRewrite)(MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, MlirValue *operands, MlirConversionPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the conver...
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A logical result value, essentially a boolean with named states.
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrit...
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.