MLIR 23.0.0git
Rewrite.cpp
Go to the documentation of this file.
1//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Rewrite.h"
10
12#include "mlir-c/IR.h"
13#include "mlir-c/Rewrite.h"
14#include "mlir-c/Support.h"
17#include "mlir/Config/mlir-config.h"
18#include "nanobind/nanobind.h"
19#include <type_traits>
20
21namespace nb = nanobind;
22using namespace mlir;
23using namespace nb::literals;
25
26namespace mlir {
27namespace python {
29
30// Convert the Python object to a boolean.
31// If it evaluates to False, treat it as success;
32// otherwise, treat it as failure.
33// Note that None is considered success.
34static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
35 if (obj.is_none())
37
38 return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
40}
41
42static std::string operationNameFromObject(nb::handle root) {
43 if (root.is_type())
44 return nb::cast<std::string>(root.attr("OPERATION_NAME"));
45 if (nb::isinstance<nb::str>(root))
46 return nb::cast<std::string>(root);
47
48 throw nb::type_error("the root argument must be a type or a string");
49}
50
51static std::string dialectNameFromObject(nb::handle root) {
52 if (root.is_type())
53 return nb::cast<std::string>(root.attr("DIALECT_NAMESPACE"));
54 if (nb::isinstance<nb::str>(root))
55 return nb::cast<std::string>(root);
56
57 throw nb::type_error("the root argument must be a type or a string");
58}
59
60class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
61public:
62 static constexpr const char *pyClassName = "PatternRewriter";
63
64 PyPatternRewriter(MlirPatternRewriter rewriter)
66};
67
68//===----------------------------------------------------------------------===//
69// PyRewritePatternSet
70//===----------------------------------------------------------------------===//
71
73 : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
74
75PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
76 : patterns(patterns), owned(false) {}
77
79 if (owned && patterns.ptr)
81}
82
83MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
84
85bool PyRewritePatternSet::isOwned() const { return owned; }
86
87void PyRewritePatternSet::add(nb::handle root,
88 const nb::callable &matchAndRewrite,
89 unsigned benefit) {
90 std::string opName = operationNameFromObject(root);
91 MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
92
94 callbacks.construct = [](void *userData) {
95 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
96 };
97 callbacks.destruct = [](void *userData) {
98 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
99 };
100 callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
101 MlirPatternRewriter rewriter,
102 void *userData) -> MlirLogicalResult {
103 nb::handle f(static_cast<PyObject *>(userData));
104
105 PyMlirContextRef context =
107 nb::object opView = PyOperation::forOperation(context, op)->createOpView();
108
109 nb::object res = f(opView, PyPatternRewriter(rewriter));
110 return logicalResultFromObject(res);
111 };
112
113 MlirRewritePattern pattern = mlirOpRewritePatternCreate(
114 rootName, benefit, mlirRewritePatternSetGetContext(patterns), callbacks,
115 matchAndRewrite.ptr(),
116 /* nGeneratedNames */ 0,
117 /* generatedNames */ nullptr);
118 mlirRewritePatternSetAdd(patterns, pattern);
119}
120
121//===----------------------------------------------------------------------===//
122// PyConversionPatternRewriter
123//===----------------------------------------------------------------------===//
124
126public:
131
132 MlirConversionPatternRewriter rewriter;
133};
134
136public:
137 PyConversionTarget(MlirContext context)
138 : target(mlirConversionTargetCreate(context)) {}
140
141 void addLegalOp(const std::string &opName) {
143 target, mlirStringRefCreate(opName.data(), opName.size()));
144 }
145
146 void addIllegalOp(const std::string &opName) {
148 target, mlirStringRefCreate(opName.data(), opName.size()));
149 }
150
151 void addLegalDialect(const std::string &dialectName) {
153 target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
154 }
155
156 void addIllegalDialect(const std::string &dialectName) {
158 target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
159 }
160
161 MlirConversionTarget get() { return target; }
162
163private:
164 MlirConversionTarget target;
165};
166
168public:
169 PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {}
170 PyTypeConverter(MlirTypeConverter typeConverter)
171 : typeConverter(typeConverter), owner(false) {}
173 if (owner)
174 mlirTypeConverterDestroy(typeConverter);
175 }
176
177 void addConversion(const nb::callable &convert) {
179 typeConverter,
180 [](MlirType type, MlirType *converted,
181 void *userData) -> MlirLogicalResult {
182 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
184 nb::object res = f(PyType(ctx, type).maybeDownCast());
185 if (res.is_none())
187
188 *converted = nb::cast<PyType>(res).get();
190 },
191 convert.ptr());
192 }
193
194 nb::typed<nb::object, std::optional<PyType>> convertType(PyType &type) {
195 MlirType converted = mlirTypeConverterConvertType(typeConverter, type);
196 if (mlirTypeIsNull(converted))
197 return nb::none();
199 converted)
200 .maybeDownCast();
201 }
202
203 MlirTypeConverter get() { return typeConverter; }
204
205private:
206 MlirTypeConverter typeConverter;
207 bool owner;
208};
209
211public:
212 PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {}
213
217
218private:
219 MlirConversionPattern pattern;
220};
221
223 const nb::callable &matchAndRewrite,
224 PyTypeConverter &typeConverter,
225 unsigned benefit) {
226 std::string opName = operationNameFromObject(root);
227 MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
228
230 callbacks.construct = [](void *userData) {
231 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
232 };
233 callbacks.destruct = [](void *userData) {
234 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
235 };
236 callbacks.matchAndRewrite =
237 [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
238 MlirValue *operands, MlirConversionPatternRewriter rewriter,
239 void *userData) -> MlirLogicalResult {
240 nb::handle f(static_cast<PyObject *>(userData));
241
242 PyMlirContextRef ctx =
244 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
245
246 std::vector<MlirValue> operandsVec(operands, operands + nOperands);
247 nb::object adaptorCls =
251 return std::string_view(ref.data, ref.length);
252 }())
253 .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
254
255 nb::object res = f(opView, adaptorCls(operandsVec, opView),
256 PyConversionPattern(pattern).getTypeConverter(),
258 return logicalResultFromObject(res);
259 };
260 MlirConversionPattern pattern = mlirOpConversionPatternCreate(
261 rootName, benefit, mlirRewritePatternSetGetContext(patterns),
262 typeConverter.get(), callbacks, matchAndRewrite.ptr(),
263 /* nGeneratedNames */ 0,
264 /* generatedNames */ nullptr);
267}
268
269#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
270struct PyMlirPDLResultList : MlirPDLResultList {};
271
272static nb::object objectFromPDLValue(MlirPDLValue value) {
273 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
274 return nb::cast(v);
275 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
276 return nb::cast(v);
277 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
278 return nb::cast(v);
279 if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
280 return nb::cast(v);
281
282 throw std::runtime_error("unsupported PDL value type");
283}
284
285static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
286 MlirPDLValue *values) {
287 std::vector<nb::object> args;
288 args.reserve(nValues);
289 for (size_t i = 0; i < nValues; ++i)
290 args.push_back(objectFromPDLValue(values[i]));
291 return args;
292}
293
294/// Owning Wrapper around a PDLPatternModule.
295class PyPDLPatternModule {
296public:
297 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
298 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
299 : module(other.module) {
300 other.module.ptr = nullptr;
301 }
302 ~PyPDLPatternModule() {
303 if (module.ptr != nullptr)
304 mlirPDLPatternModuleDestroy(module);
305 }
306 MlirPDLPatternModule get() { return module; }
307
308 void registerRewriteFunction(const std::string &name,
309 const nb::callable &fn) {
310 mlirPDLPatternModuleRegisterRewriteFunction(
311 get(), mlirStringRefCreate(name.data(), name.size()),
312 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
313 size_t nValues, MlirPDLValue *values,
314 void *userData) -> MlirLogicalResult {
315 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
316 return logicalResultFromObject(
317 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
318 objectsFromPDLValues(nValues, values)));
319 },
320 fn.ptr());
321 }
322
323 void registerConstraintFunction(const std::string &name,
324 const nb::callable &fn) {
325 mlirPDLPatternModuleRegisterConstraintFunction(
326 get(), mlirStringRefCreate(name.data(), name.size()),
327 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
328 size_t nValues, MlirPDLValue *values,
329 void *userData) -> MlirLogicalResult {
330 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
331 return logicalResultFromObject(
332 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
333 objectsFromPDLValues(nValues, values)));
334 },
335 fn.ptr());
336 }
337
338private:
339 MlirPDLPatternModule module;
340};
341#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
342
343/// Owning Wrapper around a FrozenRewritePatternSet.
345public:
346 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
348 : set(other.set) {
349 other.set.ptr = nullptr;
350 }
352 if (set.ptr != nullptr)
354 }
355 MlirFrozenRewritePatternSet get() { return set; }
356
357 nb::object getCapsule() {
358 return nb::steal<nb::object>(
360 }
361
362 static nb::object createFromCapsule(const nb::object &capsule) {
363 MlirFrozenRewritePatternSet rawPm =
365 if (rawPm.ptr == nullptr)
366 throw nb::python_error();
367 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
368 }
369
370private:
371 MlirFrozenRewritePatternSet set;
372};
373
374void PyRewritePatternSet::bind(nb::module_ &m) {
375 nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
376 .def(
377 "__init__",
379 new (&self) PyRewritePatternSet(context.get()->get());
380 },
381 "context"_a = nb::none())
382 .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
383 nb::arg("benefit") = 1,
384 R"(Add a new rewrite pattern on the specified root operation, using
385 the provided callable for matching and rewriting, and assign it
386 the given benefit.
387
388 Args:
389 root: The root operation to which this pattern applies. This may
390 be either an OpView subclass or an operation name.
391 fn: The callable to use for matching and rewriting, which takes
392 an operation and a pattern rewriter. The match is considered
393 successful iff the callable returns a falsy value.
394 benefit: The benefit of the pattern, defaulting to 1.)")
395 .def("add_conversion", &PyRewritePatternSet::addConversion,
396 nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"),
397 nb::arg("benefit") = 1,
398 R"(
399 Add a new conversion pattern on the specified root operation,
400 using the provided callable for matching and rewriting,
401 and assign it the given benefit.
402
403 Args:
404 root: The root operation to which this pattern applies.
405 This may be either an OpView subclass or an operation name.
406 fn: The callable to use for matching and rewriting, which takes an
407 operation, its adaptor, the type converter and a pattern
408 rewriter. The match is considered successful iff the callable
409 returns a falsy value.
410 type_converter: The type converter to convert types in the IR.
411 benefit: The benefit of the pattern, defaulting to 1.)")
412 .def(
413 "freeze",
414 [](PyRewritePatternSet &self) {
415 if (!self.isOwned())
416 throw std::runtime_error(
417 "cannot freeze a non-owning pattern set");
418 MlirRewritePatternSet s = self.get();
420 },
421 "Freeze the pattern set into a frozen one.");
422}
424enum class PyGreedyRewriteStrictness : std::underlying_type_t<
429};
431enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
437
438/// Owning Wrapper around a GreedyRewriteDriverConfig.
440public:
443 PyGreedyRewriteConfig::customDeleter) {}
445 : config(std::move(other.config)) {}
446 PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
447 : config(other.config) {}
448
449 MlirGreedyRewriteDriverConfig get() {
450 return MlirGreedyRewriteDriverConfig{config.get()};
451 }
453 void setMaxIterations(int64_t maxIterations) {
455 }
456
457 void setMaxNumRewrites(int64_t maxNumRewrites) {
459 }
460
461 void setUseTopDownTraversal(bool useTopDownTraversal) {
463 useTopDownTraversal);
464 }
466 void enableFolding(bool enable) {
468 }
472 get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
474
475 void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
477 get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
478 }
479
480 void enableConstantCSE(bool enable) {
483
484 int64_t getMaxIterations() {
486 }
490 }
491
492 bool getUseTopDownTraversal() {
494 }
495
496 bool isFoldingEnabled() {
504
505 PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
506 return static_cast<PyGreedySimplifyRegionLevel>(
510 bool isConstantCSEEnabled() {
512 }
514private:
515 std::shared_ptr<void> config;
516 static void customDeleter(void *c) {
517 mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
519};
520
521enum class PyDialectConversionFoldingMode : std::underlying_type_t<
526};
528class PyConversionConfig {
529public:
532 PyConversionConfig::customDeleter) {}
533
534 MlirConversionConfig get() { return MlirConversionConfig{config.get()}; }
535
540
541 PyDialectConversionFoldingMode getFoldingMode() {
544 }
545
546 void enableBuildMaterializations(bool enabled) {
548 }
549
550 bool isBuildMaterializationsEnabled() {
552 }
553
554private:
555 std::shared_ptr<void> config;
556 static void customDeleter(void *c) {
557 mlirConversionConfigDestroy(MlirConversionConfig{c});
558 }
559};
560
561/// Create the `mlir.rewrite` here.
562void populateRewriteSubmodule(nb::module_ &m) {
563 // Enum definitions
564 nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
565 .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP)
566 .value("EXISTING_AND_NEW_OPS",
567 PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS)
568 .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS);
569
570 nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
571 .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
572 .value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL)
573 .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
574
575 nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode")
576 .value("NEVER", PyDialectConversionFoldingMode::Never)
577 .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
578 .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);
579
580 //----------------------------------------------------------------------------
581 // Mapping of the PatternRewriter
582 //----------------------------------------------------------------------------
583
584 PyPatternRewriter::bind(m);
585
586 //----------------------------------------------------------------------------
587 // Mapping of the RewritePatternSet
588 //----------------------------------------------------------------------------
590
591 nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
592 m, "ConversionPatternRewriter")
593 .def("convert_region_types",
594 [](PyConversionPatternRewriter &self, PyRegion &region,
595 PyTypeConverter &typeConverter) {
597 self.rewriter, region.get(), typeConverter.get());
598 });
599
600 nb::class_<PyConversionTarget>(m, "ConversionTarget")
601 .def(
602 "__init__",
603 [](PyConversionTarget &self, DefaultingPyMlirContext context) {
604 new (&self) PyConversionTarget(context.get()->get());
605 },
606 "context"_a = nb::none())
607 .def(
608 "add_legal_op",
609 [](PyConversionTarget &self, const nb::args &ops) {
610 for (auto op : ops) {
612 }
613 },
614 "ops"_a, "Mark the given operations as legal.")
615 .def(
616 "add_illegal_op",
617 [](PyConversionTarget &self, const nb::args &ops) {
618 for (auto op : ops) {
620 }
621 },
622 "ops"_a, "Mark the given operations as illegal.")
623 .def(
624 "add_legal_dialect",
625 [](PyConversionTarget &self, const nb::args &dialects) {
626 for (auto dialect : dialects) {
628 }
629 },
630 "dialects"_a, "Mark the given dialects as legal.")
631 .def(
632 "add_illegal_dialect",
633 [](PyConversionTarget &self, const nb::args &dialects) {
634 for (auto dialect : dialects) {
636 }
637 },
638 "dialects"_a, "Mark the given dialect as illegal.");
639
640 nb::class_<PyTypeConverter>(m, "TypeConverter")
641 .def(nb::init<>(), "Create a new TypeConverter.")
642 .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
643 nb::keep_alive<0, 1>(), "Register a type conversion function.")
644 .def("convert_type", &PyTypeConverter::convertType, "type"_a,
645 "Convert the given type. Returns None if conversion fails.");
646
647 //----------------------------------------------------------------------------
648 // Mapping of the PDLResultList and PDLModule
649 //----------------------------------------------------------------------------
650#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
651 nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
652 .def("append",
653 [](PyMlirPDLResultList results, const PyValue &value) {
654 mlirPDLResultListPushBackValue(results, value);
655 })
656 .def("append",
657 [](PyMlirPDLResultList results, const PyOperation &op) {
658 mlirPDLResultListPushBackOperation(results, op);
659 })
660 .def("append",
661 [](PyMlirPDLResultList results, const PyType &type) {
662 mlirPDLResultListPushBackType(results, type);
663 })
664 .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
665 mlirPDLResultListPushBackAttribute(results, attr);
666 });
667 nb::class_<PyPDLPatternModule>(m, "PDLModule")
668 .def(
669 "__init__",
670 [](PyPDLPatternModule &self, PyModule &module) {
671 new (&self) PyPDLPatternModule(
672 mlirPDLPatternModuleFromModule(module.get()));
673 },
674 "module"_a, "Create a PDL module from the given module.")
675 .def(
676 "__init__",
677 [](PyPDLPatternModule &self, PyModule &module) {
678 new (&self) PyPDLPatternModule(
679 mlirPDLPatternModuleFromModule(module.get()));
680 },
681 "module"_a, "Create a PDL module from the given module.")
682 .def(
683 "freeze",
684 [](PyPDLPatternModule &self) {
685 return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
686 mlirRewritePatternSetFromPDLPatternModule(self.get())));
687 },
688 nb::keep_alive<0, 1>())
689 .def(
690 "register_rewrite_function",
691 [](PyPDLPatternModule &self, const std::string &name,
692 const nb::callable &fn) {
693 self.registerRewriteFunction(name, fn);
694 },
695 nb::keep_alive<1, 3>())
696 .def(
697 "register_constraint_function",
698 [](PyPDLPatternModule &self, const std::string &name,
699 const nb::callable &fn) {
700 self.registerConstraintFunction(name, fn);
701 },
702 nb::keep_alive<1, 3>());
703#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
704
705 nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
706 .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
707 .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
709 "Maximum number of iterations")
710 .def_prop_rw("max_num_rewrites",
713 "Maximum number of rewrites per iteration")
714 .def_prop_rw("use_top_down_traversal",
717 "Whether to use top-down traversal")
718 .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
720 "Enable or disable folding")
721 .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
723 "Rewrite strictness level")
724 .def_prop_rw("region_simplification_level",
727 "Region simplification level")
728 .def_prop_rw("enable_constant_cse",
731 "Enable or disable constant CSE");
732
733 nb::class_<PyConversionConfig>(m, "ConversionConfig")
734 .def(nb::init<>(), "Create a conversion config with defaults")
735 .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode,
737 "folding behavior during dialect conversion")
738 .def_prop_rw("build_materializations",
741 "Whether the dialect conversion attempts to build "
742 "source/target materializations");
743
744 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
745 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
749 m.def(
750 "apply_patterns_and_fold_greedily",
751 [](PyModule &module, PyFrozenRewritePatternSet &set,
752 std::optional<PyGreedyRewriteConfig> config) {
754 module.get(), set.get(),
755 config.has_value() ? config->get()
757 if (mlirLogicalResultIsFailure(status))
758 throw std::runtime_error("pattern application failed to converge");
759 },
760 "module"_a, "set"_a, "config"_a = nb::none(),
761 "Applys the given patterns to the given module greedily while folding "
762 "results.")
763 .def(
764 "apply_patterns_and_fold_greedily",
765 [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
766 std::optional<PyGreedyRewriteConfig> config) {
768 op.getOperation(), set.get(),
769 config.has_value() ? config->get()
771 if (mlirLogicalResultIsFailure(status))
772 throw std::runtime_error(
773 "pattern application failed to converge");
774 },
775 "op"_a, "set"_a, "config"_a = nb::none(),
776 "Applys the given patterns to the given op greedily while folding "
777 "results.")
778 .def(
779 "walk_and_apply_patterns",
780 [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
782 },
783 "op"_a, "set"_a,
784 "Applies the given patterns to the given op by a fast walk-based "
785 "driver.")
786 .def(
787 "apply_partial_conversion",
788 [](PyOperationBase &op, PyConversionTarget &target,
789 PyFrozenRewritePatternSet &set,
790 std::optional<PyConversionConfig> config) {
791 if (!config)
792 config.emplace(PyConversionConfig());
793 PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
795 op.getOperation(), target.get(), set.get(), config->get());
796 if (mlirLogicalResultIsFailure(status))
797 throw MLIRError("partial conversion failed", errors.take());
798 },
799 "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
800 "Applies a partial conversion on the given operation.")
801 .def(
802 "apply_full_conversion",
803 [](PyOperationBase &op, PyConversionTarget &target,
804 PyFrozenRewritePatternSet &set,
805 std::optional<PyConversionConfig> config) {
806 if (!config)
807 config.emplace(PyConversionConfig());
808 PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
810 op.getOperation(), target.get(), set.get(), config->get());
811 if (mlirLogicalResultIsFailure(status))
812 throw MLIRError("full conversion failed", errors.take());
813 },
814 "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
815 "Applies a full conversion on the given operation.");
816}
817} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
818} // namespace python
819} // namespace mlir
true
Given two iterators into the same block, return "true" if `a` is before `b`.
MlirIdentifier mlirOperationGetName(MlirOperation op)
Definition IR.cpp:668
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced by mlirPythonFrozenRewritePatternSetToCapsule.
Definition Interop.h:302
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Definition Interop.h:293
false
Parses a map_entries map type from a string format back into its numeric value.
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:297
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:278
void setFoldingMode(PyDialectConversionFoldingMode mode)
Definition Rewrite.cpp:513
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
Definition Rewrite.cpp:127
void addIllegalDialect(const std::string &dialectName)
Definition Rewrite.cpp:156
void addLegalDialect(const std::string &dialectName)
Definition Rewrite.cpp:151
Owning Wrapper around a FrozenRewritePatternSet.
Definition Rewrite.cpp:344
static nb::object createFromCapsule(const nb::object &capsule)
Definition Rewrite.cpp:362
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
Definition Rewrite.cpp:347
std::optional< nanobind::object > lookupOpAdaptorClass(std::string_view operationName)
Looks up a registered operation adaptor class by operation name.
Definition Globals.cpp:238
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:59
Owning Wrapper around a GreedyRewriteDriverConfig.
Definition Rewrite.cpp:416
void setStrictness(PyGreedyRewriteStrictness strictness)
Definition Rewrite.cpp:447
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level)
Definition Rewrite.cpp:452
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:211
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:547
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
void add(nanobind::handle root, const nanobind::callable &matchAndRewrite, unsigned benefit)
Add a new rewrite pattern to the pattern set.
Definition Rewrite.cpp:87
void addConversion(nanobind::handle root, const nanobind::callable &matchAndRewrite, PyTypeConverter &typeConverter, unsigned benefit)
Add a new conversion pattern to the pattern set.
Definition Rewrite.cpp:222
PyRewritePatternSet(MlirContext ctx)
Create an owned pattern set.
Definition Rewrite.cpp:72
nb::typed< nb::object, std::optional< PyType > > convertType(PyType &type)
Definition Rewrite.cpp:194
Wrapper around the generic MlirType.
Definition IRCore.h:874
nanobind::typed< nanobind::object, PyType > maybeDownCast()
Definition IRCore.cpp:1942
MLIR_CAPI_EXPORTED void mlirConversionTargetDestroy(MlirConversionTarget target)
Destroy the given ConversionTarget.
Definition Rewrite.cpp:545
MLIR_CAPI_EXPORTED MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set)
Get the context associated with a MlirRewritePatternSet.
Definition Rewrite.cpp:731
MlirDialectConversionFoldingMode
Definition Rewrite.h:452
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS
Definition Rewrite.h:455
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS
Definition Rewrite.h:454
@ MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER
Definition Rewrite.h:453
MLIR_CAPI_EXPORTED MlirRewritePattern mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern)
Cast the ConversionPattern to a RewritePattern.
Definition Rewrite.cpp:668
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
Definition Rewrite.cpp:365
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a partial conversion on the given operation.
Definition Rewrite.cpp:447
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
Definition Rewrite.cpp:405
MLIR_CAPI_EXPORTED MlirDialectConversionFoldingMode mlirConversionConfigGetFoldingMode(MlirConversionConfig config)
Get the folding mode for the given ConversionConfig.
Definition Rewrite.cpp:492
MLIR_CAPI_EXPORTED bool mlirConversionConfigIsBuildMaterializationsEnabled(MlirConversionConfig config)
Check if building materializations during conversion is enabled.
Definition Rewrite.cpp:508
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config)
Apply a full conversion on the given operation.
Definition Rewrite.cpp:454
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
Definition Rewrite.cpp:385
MLIR_CAPI_EXPORTED void mlirConversionConfigEnableBuildMaterializations(MlirConversionConfig config, bool enable)
Enable or disable building materializations during conversion.
Definition Rewrite.cpp:503
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
Definition Rewrite.cpp:301
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
Definition Rewrite.cpp:420
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
Definition Rewrite.cpp:739
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as illegal.
Definition Rewrite.cpp:566
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
Definition Rewrite.cpp:441
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
Definition Rewrite.cpp:348
MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate(void)
TypeConverter API.
Definition Rewrite.cpp:575
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as legal.
Definition Rewrite.cpp:549
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
Definition Rewrite.cpp:426
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
Definition Rewrite.cpp:305
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
Definition Rewrite.cpp:320
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:310
MLIR_CAPI_EXPORTED MlirTypeConverter mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern)
Get the type converter used by this conversion pattern.
Definition Rewrite.cpp:663
MLIR_CAPI_EXPORTED MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(MlirConversionPatternRewriter rewriter)
ConversionPatternRewriter API.
Definition Rewrite.cpp:525
MLIR_CAPI_EXPORTED MlirConversionTarget mlirConversionTargetCreate(MlirContext context)
ConversionTarget API.
Definition Rewrite.cpp:541
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::OpRewritePattern.
Definition Rewrite.cpp:709
MLIR_CAPI_EXPORTED void mlirConversionTargetAddIllegalOp(MlirConversionTarget target, MlirStringRef opName)
Register the given operations as illegal.
Definition Rewrite.cpp:555
MLIR_CAPI_EXPORTED void mlirConversionConfigSetFoldingMode(MlirConversionConfig config, MlirDialectConversionFoldingMode mode)
Set the folding mode for the given ConversionConfig.
Definition Rewrite.cpp:474
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:370
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:390
MLIR_CAPI_EXPORTED MlirType mlirTypeConverterConvertType(MlirTypeConverter typeConverter, MlirType type)
Convert the given type using the given TypeConverter.
Definition Rewrite.cpp:600
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:330
MLIR_CAPI_EXPORTED void mlirTypeConverterAddConversion(MlirTypeConverter typeConverter, MlirTypeConverterConversionCallback convertType, void *userData)
Add a type conversion function to the given TypeConverter.
Definition Rewrite.cpp:583
MLIR_CAPI_EXPORTED void mlirTypeConverterDestroy(MlirTypeConverter typeConverter)
Destroy the given TypeConverter.
Definition Rewrite.cpp:579
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
Definition Rewrite.cpp:735
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
Definition Rewrite.cpp:517
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition Rewrite.cpp:434
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
Definition Rewrite.cpp:380
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
Definition Rewrite.cpp:727
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:315
MLIR_CAPI_EXPORTED void mlirConversionConfigDestroy(MlirConversionConfig config)
Destroy the given ConversionConfig.
Definition Rewrite.cpp:470
MLIR_CAPI_EXPORTED MlirConversionConfig mlirConversionConfigCreate(void)
ConversionConfig API.
Definition Rewrite.cpp:466
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
Definition Rewrite.cpp:325
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
Definition Rewrite.cpp:283
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
Definition Rewrite.cpp:277
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
Definition Rewrite.h:51
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
Definition Rewrite.h:53
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g. dead args elimination).
Definition Rewrite.h:55
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g. block merging).
Definition Rewrite.h:57
MLIR_CAPI_EXPORTED MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes(MlirConversionPatternRewriter rewriter, MlirRegion region, MlirTypeConverter typeConverter)
Apply a signature conversion to each block in the given region.
Definition Rewrite.cpp:530
MLIR_CAPI_EXPORTED MlirConversionPattern mlirOpConversionPatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a conversion pattern that matches the operation with the given rootName, corresponding to mlir::OpConversionPattern.
Definition Rewrite.cpp:649
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
Definition Rewrite.h:41
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
Definition Rewrite.h:45
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
Definition Rewrite.h:47
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions on which ops are processed.
Definition Rewrite.h:43
MLIR_CAPI_EXPORTED void mlirConversionTargetAddLegalDialect(MlirConversionTarget target, MlirStringRef dialectName)
Register the operations of the given dialect as legal.
Definition Rewrite.cpp:561
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:375
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1336
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1160
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1253
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:143
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:124
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:137
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
static std::string dialectNameFromObject(nb::handle root)
Definition Rewrite.cpp:51
void populateRewriteSubmodule(nb::module_ &m)
Create the mlir.rewrite here.
Definition Rewrite.cpp:539
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:197
static std::string operationNameFromObject(nb::handle root)
Definition Rewrite.cpp:42
static MlirLogicalResult logicalResultFromObject(const nb::object &obj)
Definition Rewrite.cpp:34
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
ConversionPattern API.
Definition Rewrite.h:563
MlirLogicalResult(* matchAndRewrite)(MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, MlirValue *operands, MlirConversionPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the conversion rewrite if the match is successful.
Definition Rewrite.h:573
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Rewrite.h:566
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Rewrite.h:569
A logical result value, essentially a boolean with named states.
Definition Support.h:121
RewritePattern API.
Definition Rewrite.h:600
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrite if the match is successful.
Definition Rewrite.h:610
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Rewrite.h:603
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Rewrite.h:606
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80