MLIR 23.0.0git
Rewrite.cpp
Go to the documentation of this file.
1//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Rewrite.h"
10
12#include "mlir-c/IR.h"
13#include "mlir-c/Rewrite.h"
14#include "mlir-c/Support.h"
17#include "mlir/Config/mlir-config.h"
18#include "nanobind/nanobind.h"
19#include <type_traits>
20
21namespace nb = nanobind;
22using namespace mlir;
23using namespace nb::literals;
25
26namespace mlir {
27namespace python {
29
30class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
31public:
32 static constexpr const char *pyClassName = "PatternRewriter";
33
34 PyPatternRewriter(MlirPatternRewriter rewriter)
36};
37
44
46public:
47 PyConversionTarget(MlirContext context)
48 : target(mlirConversionTargetCreate(context)) {}
50
51 void addLegalOp(const std::string &opName) {
53 target, mlirStringRefCreate(opName.data(), opName.size()));
54 }
55
56 void addIllegalOp(const std::string &opName) {
58 target, mlirStringRefCreate(opName.data(), opName.size()));
59 }
60
61 void addLegalDialect(const std::string &dialectName) {
63 target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
64 }
65
66 void addIllegalDialect(const std::string &dialectName) {
68 target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
69 }
70
71 MlirConversionTarget get() { return target; }
72
73private:
74 MlirConversionTarget target;
75};
76
78public:
79 PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {}
80 PyTypeConverter(MlirTypeConverter typeConverter)
81 : typeConverter(typeConverter), owner(false) {}
83 if (owner)
84 mlirTypeConverterDestroy(typeConverter);
85 }
86
87 void addConversion(const nb::callable &convert) {
89 typeConverter,
90 [](MlirType type, MlirType *converted,
91 void *userData) -> MlirLogicalResult {
92 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
94 nb::object res = f(PyType(ctx, type).maybeDownCast());
95 if (res.is_none())
97
98 *converted = nb::cast<PyType>(res).get();
100 },
101 convert.ptr());
102 }
103
104 MlirTypeConverter get() { return typeConverter; }
105
106private:
107 MlirTypeConverter typeConverter;
108 bool owner;
109};
110
112public:
113 PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {}
114
118
119private:
120 MlirConversionPattern pattern;
121};
122
123#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
124struct PyMlirPDLResultList : MlirPDLResultList {};
125
126static nb::object objectFromPDLValue(MlirPDLValue value) {
127 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
128 return nb::cast(v);
129 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
130 return nb::cast(v);
131 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
132 return nb::cast(v);
133 if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
134 return nb::cast(v);
135
136 throw std::runtime_error("unsupported PDL value type");
137}
138
139static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
140 MlirPDLValue *values) {
141 std::vector<nb::object> args;
142 args.reserve(nValues);
143 for (size_t i = 0; i < nValues; ++i)
144 args.push_back(objectFromPDLValue(values[i]));
145 return args;
146}
147
148// Convert the Python object to a boolean.
149// If it evaluates to False, treat it as success;
150// otherwise, treat it as failure.
151// Note that None is considered success.
152static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
153 if (obj.is_none())
155
156 return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
158}
159
160/// Owning Wrapper around a PDLPatternModule.
161class PyPDLPatternModule {
162public:
163 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
164 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
165 : module(other.module) {
166 other.module.ptr = nullptr;
167 }
168 ~PyPDLPatternModule() {
169 if (module.ptr != nullptr)
170 mlirPDLPatternModuleDestroy(module);
171 }
172 MlirPDLPatternModule get() { return module; }
173
174 void registerRewriteFunction(const std::string &name,
175 const nb::callable &fn) {
176 mlirPDLPatternModuleRegisterRewriteFunction(
177 get(), mlirStringRefCreate(name.data(), name.size()),
178 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
179 size_t nValues, MlirPDLValue *values,
180 void *userData) -> MlirLogicalResult {
181 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
182 return logicalResultFromObject(
183 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
184 objectsFromPDLValues(nValues, values)));
185 },
186 fn.ptr());
187 }
188
189 void registerConstraintFunction(const std::string &name,
190 const nb::callable &fn) {
191 mlirPDLPatternModuleRegisterConstraintFunction(
192 get(), mlirStringRefCreate(name.data(), name.size()),
193 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
194 size_t nValues, MlirPDLValue *values,
195 void *userData) -> MlirLogicalResult {
196 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
197 return logicalResultFromObject(
198 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
199 objectsFromPDLValues(nValues, values)));
200 },
201 fn.ptr());
202 }
203
204private:
205 MlirPDLPatternModule module;
206};
207#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
208
209/// Owning Wrapper around a FrozenRewritePatternSet.
211public:
212 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
214 : set(other.set) {
215 other.set.ptr = nullptr;
216 }
218 if (set.ptr != nullptr)
220 }
221 MlirFrozenRewritePatternSet get() { return set; }
222
223 nb::object getCapsule() {
224 return nb::steal<nb::object>(
226 }
227
228 static nb::object createFromCapsule(const nb::object &capsule) {
229 MlirFrozenRewritePatternSet rawPm =
231 if (rawPm.ptr == nullptr)
232 throw nb::python_error();
233 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
234 }
235
236private:
237 MlirFrozenRewritePatternSet set;
238};
239
241public:
242 PyRewritePatternSet(MlirContext ctx)
243 : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
245 if (set.ptr)
247 }
248
249 void add(MlirStringRef rootName, unsigned benefit,
250 const nb::callable &matchAndRewrite) {
252 callbacks.construct = [](void *userData) {
253 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
254 };
255 callbacks.destruct = [](void *userData) {
256 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
257 };
258 callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
259 MlirPatternRewriter rewriter,
260 void *userData) -> MlirLogicalResult {
261 nb::handle f(static_cast<PyObject *>(userData));
262
263 PyMlirContextRef ctx =
265 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
266
267 nb::object res = f(opView, PyPatternRewriter(rewriter));
268 return logicalResultFromObject(res);
269 };
270 MlirRewritePattern pattern = mlirOpRewritePatternCreate(
271 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
272 /* nGeneratedNames */ 0,
273 /* generatedNames */ nullptr);
274 mlirRewritePatternSetAdd(set, pattern);
275 }
276
277 void addConversion(MlirStringRef rootName, unsigned benefit,
278 const nb::callable &matchAndRewrite,
279 PyTypeConverter &typeConverter) {
281 callbacks.construct = [](void *userData) {
282 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
283 };
284 callbacks.destruct = [](void *userData) {
285 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
286 };
287 callbacks.matchAndRewrite =
288 [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
289 MlirValue *operands, MlirConversionPatternRewriter rewriter,
290 void *userData) -> MlirLogicalResult {
291 nb::handle f(static_cast<PyObject *>(userData));
292
293 PyMlirContextRef ctx =
295 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
296
297 std::vector<MlirValue> operandsVec(operands, operands + nOperands);
298 nb::object adaptorCls =
302 .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
303
304 nb::object res = f(opView, adaptorCls(operandsVec, opView),
305 PyConversionPattern(pattern).getTypeConverter(),
307 return logicalResultFromObject(res);
308 };
309 MlirConversionPattern pattern = mlirOpConversionPatternCreate(
310 rootName, benefit, ctx, typeConverter.get(), callbacks,
311 matchAndRewrite.ptr(),
312 /* nGeneratedNames */ 0,
313 /* generatedNames */ nullptr);
316 }
317
319 MlirRewritePatternSet s = set;
320 set.ptr = nullptr;
321 return mlirFreezeRewritePattern(s);
322 }
323
324private:
325 MlirRewritePatternSet set;
326 MlirContext ctx;
327};
328
335
342
343/// Owning Wrapper around a GreedyRewriteDriverConfig.
345public:
350 : config(std::move(other.config)) {}
352 : config(other.config) {}
353
354 MlirGreedyRewriteDriverConfig get() {
355 return MlirGreedyRewriteDriverConfig{config.get()};
356 }
357
358 void setMaxIterations(int64_t maxIterations) {
360 }
361
362 void setMaxNumRewrites(int64_t maxNumRewrites) {
364 }
365
366 void setUseTopDownTraversal(bool useTopDownTraversal) {
368 useTopDownTraversal);
369 }
370
371 void enableFolding(bool enable) {
373 }
374
379
384
388
392
396
400
404
409
414
418
419private:
420 std::shared_ptr<void> config;
421 static void customDeleter(void *c) {
422 mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
423 }
424};
425
432
434public:
437 PyConversionConfig::customDeleter) {}
438
439 MlirConversionConfig get() { return MlirConversionConfig{config.get()}; }
440
445
450
454
458
459private:
460 std::shared_ptr<void> config;
461 static void customDeleter(void *c) {
462 mlirConversionConfigDestroy(MlirConversionConfig{c});
463 }
464};
465
466/// Create the `mlir.rewrite` here.
467void populateRewriteSubmodule(nb::module_ &m) {
468 // Enum definitions
469 nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
470 .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP)
471 .value("EXISTING_AND_NEW_OPS",
473 .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS);
474
475 nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
476 .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
478 .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
479
480 nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode")
482 .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
483 .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);
484
485 //----------------------------------------------------------------------------
486 // Mapping of the PatternRewriter
487 //----------------------------------------------------------------------------
488
490
491 //----------------------------------------------------------------------------
492 // Mapping of the RewritePatternSet
493 //----------------------------------------------------------------------------
494 nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
495 .def(
496 "__init__",
498 new (&self) PyRewritePatternSet(context.get()->get());
499 },
500 "context"_a = nb::none())
501 .def(
502 "add",
503 [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
504 unsigned benefit) {
505 std::string opName;
506 if (root.is_type()) {
507 opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
508 } else if (nb::isinstance<nb::str>(root)) {
509 opName = nb::cast<std::string>(root);
510 } else {
511 throw nb::type_error(
512 "the root argument must be a type or a string");
513 }
514 self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
515 fn);
516 },
517 "root"_a, "fn"_a, "benefit"_a = 1,
518 // clang-format off
519 nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
520 // clang-format on
521 R"(
522 Add a new rewrite pattern on the specified root operation, using the provided callable
523 for matching and rewriting, and assign it the given benefit.
524
525 Args:
526 root: The root operation to which this pattern applies.
527 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
528 an operation name string (e.g., ``"arith.addi"``).
529 fn: The callable to use for matching and rewriting,
530 which takes an operation and a pattern rewriter as arguments.
531 The match is considered successful iff the callable returns
532 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
533 If possible, the operation is cast to its corresponding OpView subclass
534 before being passed to the callable.
535 benefit: The benefit of the pattern, defaulting to 1.)")
536 .def(
537 "add_conversion",
538 [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
539 PyTypeConverter &typeConverter, unsigned benefit) {
540 std::string opName =
541 nb::cast<std::string>(root.attr("OPERATION_NAME"));
542 self.addConversion(
543 mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
544 typeConverter);
545 },
546 "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
547 R"(
548 Add a new conversion pattern on the specified root operation,
549 using the provided callable for matching and rewriting,
550 and assign it the given benefit.
551
552 Args:
553 root: The root operation to which this pattern applies.
554 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
555 an operation name string (e.g., ``"arith.addi"``).
556 fn: The callable to use for matching and rewriting,
557 which takes an operation, its adaptor,
558 the type converter and a pattern rewriter as arguments.
559 The match is considered successful iff the callable returns
560 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
561 If possible, the operation is cast to its corresponding OpView subclass
562 before being passed to the callable.
563 type_converter: The type converter to convert types in the IR.
564 benefit: The benefit of the pattern, defaulting to 1.)")
565 .def("freeze", &PyRewritePatternSet::freeze,
566 "Freeze the pattern set into a frozen one.");
567
568 nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
569 m, "ConversionPatternRewriter");
570
571 nb::class_<PyConversionTarget>(m, "ConversionTarget")
572 .def(
573 "__init__",
575 new (&self) PyConversionTarget(context.get()->get());
576 },
577 "context"_a = nb::none())
578 .def(
579 "add_legal_op",
580 [](PyConversionTarget &self, const nb::args &ops) {
581 for (auto op : ops) {
582 std::string opName =
583 nb::cast<std::string>(op.attr("OPERATION_NAME"));
584 self.addLegalOp(opName);
585 }
586 },
587 "ops"_a, "Mark the given operations as legal.")
588 .def(
589 "add_illegal_op",
590 [](PyConversionTarget &self, const nb::args &ops) {
591 for (auto op : ops) {
592 std::string opName =
593 nb::cast<std::string>(op.attr("OPERATION_NAME"));
594 self.addIllegalOp(opName);
595 }
596 },
597 "ops"_a, "Mark the given operations as illegal.")
598 .def(
599 "add_legal_dialect",
600 [](PyConversionTarget &self, const nb::args &dialects) {
601 for (auto dialect : dialects) {
602 std::string dialectName =
603 nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
604 self.addLegalDialect(dialectName);
605 }
606 },
607 "dialects"_a, "Mark the given dialects as legal.")
608 .def(
609 "add_illegal_dialect",
610 [](PyConversionTarget &self, const nb::args &dialects) {
611 for (auto dialect : dialects) {
612 std::string dialectName =
613 nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
614 self.addIllegalDialect(dialectName);
615 }
616 },
617 "dialects"_a, "Mark the given dialect as illegal.");
618
619 nb::class_<PyTypeConverter>(m, "TypeConverter")
620 .def(nb::init<>(), "Create a new TypeConverter.")
621 .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
622 nb::keep_alive<0, 1>(), "Register a type conversion function.");
623
624 //----------------------------------------------------------------------------
625 // Mapping of the PDLResultList and PDLModule
626 //----------------------------------------------------------------------------
627#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
628 nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
629 .def("append",
630 [](PyMlirPDLResultList results, const PyValue &value) {
631 mlirPDLResultListPushBackValue(results, value);
632 })
633 .def("append",
634 [](PyMlirPDLResultList results, const PyOperation &op) {
635 mlirPDLResultListPushBackOperation(results, op);
636 })
637 .def("append",
638 [](PyMlirPDLResultList results, const PyType &type) {
639 mlirPDLResultListPushBackType(results, type);
640 })
641 .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
642 mlirPDLResultListPushBackAttribute(results, attr);
643 });
644 nb::class_<PyPDLPatternModule>(m, "PDLModule")
645 .def(
646 "__init__",
647 [](PyPDLPatternModule &self, PyModule &module) {
648 new (&self) PyPDLPatternModule(
649 mlirPDLPatternModuleFromModule(module.get()));
650 },
651 "module"_a, "Create a PDL module from the given module.")
652 .def(
653 "__init__",
654 [](PyPDLPatternModule &self, PyModule &module) {
655 new (&self) PyPDLPatternModule(
656 mlirPDLPatternModuleFromModule(module.get()));
657 },
658 "module"_a, "Create a PDL module from the given module.")
659 .def(
660 "freeze",
661 [](PyPDLPatternModule &self) {
663 mlirRewritePatternSetFromPDLPatternModule(self.get())));
664 },
665 nb::keep_alive<0, 1>())
666 .def(
667 "register_rewrite_function",
668 [](PyPDLPatternModule &self, const std::string &name,
669 const nb::callable &fn) {
670 self.registerRewriteFunction(name, fn);
671 },
672 nb::keep_alive<1, 3>())
673 .def(
674 "register_constraint_function",
675 [](PyPDLPatternModule &self, const std::string &name,
676 const nb::callable &fn) {
677 self.registerConstraintFunction(name, fn);
678 },
679 nb::keep_alive<1, 3>());
680#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
681
682 nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
683 .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
684 .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
686 "Maximum number of iterations")
687 .def_prop_rw("max_num_rewrites",
690 "Maximum number of rewrites per iteration")
691 .def_prop_rw("use_top_down_traversal",
694 "Whether to use top-down traversal")
695 .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
697 "Enable or disable folding")
698 .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
700 "Rewrite strictness level")
701 .def_prop_rw("region_simplification_level",
704 "Region simplification level")
705 .def_prop_rw("enable_constant_cse",
708 "Enable or disable constant CSE");
709
710 nb::class_<PyConversionConfig>(m, "ConversionConfig")
711 .def(nb::init<>(), "Create a conversion config with defaults")
712 .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode,
714 "folding behavior during dialect conversion")
715 .def_prop_rw("build_materializations",
718 "Whether the dialect conversion attempts to build "
719 "source/target materializations");
720
721 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
722 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
726 m.def(
727 "apply_patterns_and_fold_greedily",
728 [](PyModule &module, PyFrozenRewritePatternSet &set,
729 std::optional<PyGreedyRewriteConfig> config) {
731 module.get(), set.get(),
732 config.has_value() ? config->get()
734 if (mlirLogicalResultIsFailure(status))
735 throw std::runtime_error("pattern application failed to converge");
736 },
737 "module"_a, "set"_a, "config"_a = nb::none(),
738 "Applys the given patterns to the given module greedily while folding "
739 "results.")
740 .def(
741 "apply_patterns_and_fold_greedily",
743 std::optional<PyGreedyRewriteConfig> config) {
745 op.getOperation(), set.get(),
746 config.has_value() ? config->get()
748 if (mlirLogicalResultIsFailure(status))
749 throw std::runtime_error(
750 "pattern application failed to converge");
751 },
752 "op"_a, "set"_a, "config"_a = nb::none(),
753 "Applys the given patterns to the given op greedily while folding "
754 "results.")
755 .def(
756 "walk_and_apply_patterns",
759 },
760 "op"_a, "set"_a,
761 "Applies the given patterns to the given op by a fast walk-based "
762 "driver.")
763 .def(
764 "apply_partial_conversion",
767 std::optional<PyConversionConfig> config) {
768 if (!config)
769 config.emplace(PyConversionConfig());
771 op.getOperation(), target.get(), set.get(), config->get());
772 if (mlirLogicalResultIsFailure(status))
773 throw std::runtime_error("partial conversion failed");
774 },
775 "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
776 "Applies a partial conversion on the given operation.")
777 .def(
778 "apply_full_conversion",
781 std::optional<PyConversionConfig> config) {
782 if (!config)
783 config.emplace(PyConversionConfig());
785 op.getOperation(), target.get(), set.get(), config->get());
786 if (mlirLogicalResultIsFailure(status))
787 throw std::runtime_error("full conversion failed");
788 },
789 "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
790 "Applies a full conversion on the given operation.");
791}
792} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
793} // namespace python
794} // namespace mlir
true
Given two iterators into the same block, return "true" if a is before `b.
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:668
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
Definition Interop.h:302
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Definition Interop.h:293
false
Parses a map_entries map type from a string format back into its numeric value.
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:279
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1006
void setFoldingMode(PyDialectConversionFoldingMode mode)
Definition Rewrite.cpp:441
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
Definition Rewrite.cpp:40
void addIllegalDialect(const std::string &dialectName)
Definition Rewrite.cpp:66
void addLegalDialect(const std::string &dialectName)
Definition Rewrite.cpp:61
Owning Wrapper around a FrozenRewritePatternSet.
Definition Rewrite.cpp:210
static nb::object createFromCapsule(const nb::object &capsule)
Definition Rewrite.cpp:228
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
Definition Rewrite.cpp:213
std::optional< nanobind::object > lookupOpAdaptorClass(std::string_view operationName)
Looks up a registered operation adaptor class by operation name.
Definition Globals.cpp:236
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
Definition Rewrite.cpp:351
void setStrictness(PyGreedyRewriteStrictness strictness)
Definition Rewrite.cpp:375
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level)
Definition Rewrite.cpp:380
PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
Definition Rewrite.cpp:349
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:212
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:548
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:578
PyOperation & getOperation() override
Each must provide access to the raw Operation.
Definition IRCore.h:635
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite)
Definition Rewrite.cpp:249
void addConversion(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite, PyTypeConverter &typeConverter)
Definition Rewrite.cpp:277
void addConversion(const nb::callable &convert)
Definition Rewrite.cpp:87
Wrapper around the generic MlirType.
Definition IRCore.h:875
MLIR_CAPI_EXPORTED void mlirConversionTargetDestroy(MlirConversionTarget target)
Destroy the given ConversionTarget.
Definition Rewrite.cpp:538
MlirDialectConversionFoldingMode
Definition Rewrite.h:452
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS
Definition Rewrite.h:455
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS
Definition Rewrite.h:454
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER
Definition Rewrite.h:453
MLIR_CAPI_EXPORTED MlirRewritePattern mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern)
Cast the ConversionPattern to a RewritePattern.
Definition Rewrite.cpp:656
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
Definition Rewrite.cpp:365
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a partial conversion on the given operation.
Definition Rewrite.cpp:447
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
Definition Rewrite.cpp:405
MLIR_CAPI_EXPORTED MlirDialectConversionFoldingMode mlirConversionConfigGetFoldingMode(MlirConversionConfig config)
Get the folding mode for the given ConversionConfig.
Definition Rewrite.cpp:492
MLIR_CAPI_EXPORTED bool mlirConversionConfigIsBuildMaterializationsEnabled(MlirConversionConfig config)
Check if building materializations during conversion is enabled.
Definition Rewrite.cpp:508
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a full conversion on the given operation.
Definition Rewrite.cpp:454
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
Definition Rewrite.cpp:385
MLIR_CAPI_EXPORTED void mlirConversionConfigEnableBuildMaterializations(MlirConversionConfig config, bool enable)
Enable or disable building materializations during conversion.
Definition Rewrite.cpp:503
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
Definition Rewrite.cpp:301
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
Definition Rewrite.cpp:420
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
Definition Rewrite.cpp:723
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as illegal.
Definition Rewrite.cpp:559
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
Definition Rewrite.cpp:441
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
Definition Rewrite.cpp:348
MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate(void)
TypeConverter API.
Definition Rewrite.cpp:568
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as legal.
Definition Rewrite.cpp:542
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
Definition Rewrite.cpp:426
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
Definition Rewrite.cpp:305
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
Definition Rewrite.cpp:320
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:310
MLIR_CAPI_EXPORTED MlirTypeConverter mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern)
Get the type converter used by this conversion pattern.
Definition Rewrite.cpp:651
MLIR_CAPI_EXPORTED MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(MlirConversionPatternRewriter rewriter)
ConversionPatternRewriter API.
Definition Rewrite.cpp:525
MLIR_CAPI_EXPORTED MlirConversionTarget mlirConversionTargetCreate(MlirContext context)
ConversionTarget API.
Definition Rewrite.cpp:534
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
Definition Rewrite.cpp:697
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as illegal.
Definition Rewrite.cpp:548
MLIR_CAPI_EXPORTED void mlirConversionConfigSetFoldingMode(MlirConversionConfig config, MlirDialectConversionFoldingMode mode)
Set the folding mode for the given ConversionConfig.
Definition Rewrite.cpp:474
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:370
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:390
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:330
MLIR_CAPI_EXPORTED void mlirTypeConverterAddConversion(MlirTypeConverter typeConverter, MlirTypeConverterConversionCallback convertType, void *userData)
Add a type conversion function to the given TypeConverter.
Definition Rewrite.cpp:576
MLIR_CAPI_EXPORTED void mlirTypeConverterDestroy(MlirTypeConverter typeConverter)
Destroy the given TypeConverter.
Definition Rewrite.cpp:572
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
Definition Rewrite.cpp:719
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
Definition Rewrite.cpp:517
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition Rewrite.cpp:434
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
Definition Rewrite.cpp:380
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
Definition Rewrite.cpp:715
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:315
MLIR_CAPI_EXPORTED void mlirConversionConfigDestroy(MlirConversionConfig config)
Destroy the given ConversionConfig.
Definition Rewrite.cpp:470
MLIR_CAPI_EXPORTED MlirConversionConfig mlirConversionConfigCreate(void)
ConversionConfig API.
Definition Rewrite.cpp:466
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
Definition Rewrite.cpp:325
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
Definition Rewrite.cpp:283
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
Definition Rewrite.cpp:277
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
Definition Rewrite.h:51
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
Definition Rewrite.h:53
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g. dead args elimination).
Definition Rewrite.h:55
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g. block merging).
Definition Rewrite.h:57
MLIR_CAPI_EXPORTED MlirConversionPattern mlirOpConversionPatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a conversion pattern that matches the operation with the given rootName, corresponding to mlir::OpConversionPattern.
Definition Rewrite.cpp:637
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
Definition Rewrite.h:41
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
Definition Rewrite.h:45
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
Definition Rewrite.h:47
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions on which ops are processed.
Definition Rewrite.h:43
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as legal.
Definition Rewrite.cpp:554
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:375
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1336
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1160
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1253
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:124
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
void populateRewriteSubmodule(nb::module_ &m)
Populate the mlir.rewrite submodule.
Definition Rewrite.cpp:467
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:198
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
ConversionPattern API.
Definition Rewrite.h:553
MlirLogicalResult(* matchAndRewrite)(MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, MlirValue *operands, MlirConversionPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the conversion if the match is successful.
Definition Rewrite.h:563
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Rewrite.h:556
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Rewrite.h:559
A logical result value, essentially a boolean with named states.
Definition Support.h:121
RewritePattern API.
Definition Rewrite.h:590
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrite if the match is successful.
Definition Rewrite.h:600
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Rewrite.h:593
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Rewrite.h:596
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78