MLIR 22.0.0git
Rewrite.cpp
Go to the documentation of this file.
1//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Rewrite.h"
10
11#include "IRModule.h"
12#include "mlir-c/IR.h"
13#include "mlir-c/Rewrite.h"
14#include "mlir-c/Support.h"
15// clang-format off
17#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
18// clang-format on
19#include "mlir/Config/mlir-config.h"
20#include "nanobind/nanobind.h"
21
22namespace nb = nanobind;
23using namespace mlir;
24using namespace nb::literals;
25using namespace mlir::python;
26
27namespace {
28
29class PyPatternRewriter {
30public:
31 PyPatternRewriter(MlirPatternRewriter rewriter)
32 : base(mlirPatternRewriterAsBase(rewriter)),
33 ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
34
35 PyInsertionPoint getInsertionPoint() const {
36 MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
37 MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
38
39 if (mlirOperationIsNull(op)) {
40 MlirOperation owner = mlirBlockGetParentOperation(block);
41 auto parent = PyOperation::forOperation(ctx, owner);
42 return PyInsertionPoint(PyBlock(parent, block));
43 }
44
45 return PyInsertionPoint(PyOperation::forOperation(ctx, op));
46 }
47
48 void replaceOp(MlirOperation op, MlirOperation newOp) {
50 }
51
52 void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
53 mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
54 }
55
56 void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
57
58private:
59 MlirRewriterBase base;
61};
62
63#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
64static nb::object objectFromPDLValue(MlirPDLValue value) {
65 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
66 return nb::cast(v);
67 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
68 return nb::cast(v);
69 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
70 return nb::cast(v);
71 if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
72 return nb::cast(v);
73
74 throw std::runtime_error("unsupported PDL value type");
75}
76
77static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
78 MlirPDLValue *values) {
79 std::vector<nb::object> args;
80 args.reserve(nValues);
81 for (size_t i = 0; i < nValues; ++i)
82 args.push_back(objectFromPDLValue(values[i]));
83 return args;
84}
85
86// Convert the Python object to a boolean.
87// If it evaluates to False, treat it as success;
88// otherwise, treat it as failure.
89// Note that None is considered success.
90static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
91 if (obj.is_none())
93
94 return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
96}
97
98/// Owning Wrapper around a PDLPatternModule.
99class PyPDLPatternModule {
100public:
101 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
102 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
103 : module(other.module) {
104 other.module.ptr = nullptr;
105 }
106 ~PyPDLPatternModule() {
107 if (module.ptr != nullptr)
108 mlirPDLPatternModuleDestroy(module);
109 }
110 MlirPDLPatternModule get() { return module; }
111
112 void registerRewriteFunction(const std::string &name,
113 const nb::callable &fn) {
114 mlirPDLPatternModuleRegisterRewriteFunction(
115 get(), mlirStringRefCreate(name.data(), name.size()),
116 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
117 size_t nValues, MlirPDLValue *values,
118 void *userData) -> MlirLogicalResult {
119 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
120 return logicalResultFromObject(
121 f(PyPatternRewriter(rewriter), results,
122 objectsFromPDLValues(nValues, values)));
123 },
124 fn.ptr());
125 }
126
127 void registerConstraintFunction(const std::string &name,
128 const nb::callable &fn) {
129 mlirPDLPatternModuleRegisterConstraintFunction(
130 get(), mlirStringRefCreate(name.data(), name.size()),
131 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
132 size_t nValues, MlirPDLValue *values,
133 void *userData) -> MlirLogicalResult {
134 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
135 return logicalResultFromObject(
136 f(PyPatternRewriter(rewriter), results,
137 objectsFromPDLValues(nValues, values)));
138 },
139 fn.ptr());
140 }
141
142private:
143 MlirPDLPatternModule module;
144};
145#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
146
147/// Owning Wrapper around a FrozenRewritePatternSet.
148class PyFrozenRewritePatternSet {
149public:
150 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
151 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
152 : set(other.set) {
153 other.set.ptr = nullptr;
154 }
155 ~PyFrozenRewritePatternSet() {
156 if (set.ptr != nullptr)
158 }
159 MlirFrozenRewritePatternSet get() { return set; }
160
161 nb::object getCapsule() {
162 return nb::steal<nb::object>(
164 }
165
166 static nb::object createFromCapsule(const nb::object &capsule) {
167 MlirFrozenRewritePatternSet rawPm =
169 if (rawPm.ptr == nullptr)
170 throw nb::python_error();
171 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
172 }
173
174private:
175 MlirFrozenRewritePatternSet set;
176};
177
178class PyRewritePatternSet {
179public:
180 PyRewritePatternSet(MlirContext ctx)
181 : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
182 ~PyRewritePatternSet() {
183 if (set.ptr)
185 }
186
187 void add(MlirStringRef rootName, unsigned benefit,
188 const nb::callable &matchAndRewrite) {
189 MlirRewritePatternCallbacks callbacks;
190 callbacks.construct = [](void *userData) {
191 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
192 };
193 callbacks.destruct = [](void *userData) {
194 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
195 };
196 callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
197 MlirPatternRewriter rewriter,
198 void *userData) -> MlirLogicalResult {
199 nb::handle f(static_cast<PyObject *>(userData));
200
201 PyMlirContextRef ctx =
203 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
204
205 nb::object res = f(opView, PyPatternRewriter(rewriter));
206 return logicalResultFromObject(res);
207 };
208 MlirRewritePattern pattern = mlirOpRewritePatternCreate(
209 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
210 /* nGeneratedNames */ 0,
211 /* generatedNames */ nullptr);
212 mlirRewritePatternSetAdd(set, pattern);
213 }
214
215 PyFrozenRewritePatternSet freeze() {
216 MlirRewritePatternSet s = set;
217 set.ptr = nullptr;
218 return mlirFreezeRewritePattern(s);
219 }
220
221private:
222 MlirRewritePatternSet set;
223 MlirContext ctx;
224};
225
226/// Owning Wrapper around a GreedyRewriteDriverConfig.
227class PyGreedyRewriteDriverConfig {
228public:
229 PyGreedyRewriteDriverConfig()
231 PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
232 : config(other.config) {
233 other.config.ptr = nullptr;
234 }
235 ~PyGreedyRewriteDriverConfig() {
236 if (config.ptr != nullptr)
238 }
239 MlirGreedyRewriteDriverConfig get() { return config; }
240
241 void setMaxIterations(int64_t maxIterations) {
243 }
244
245 void setMaxNumRewrites(int64_t maxNumRewrites) {
247 }
248
249 void setUseTopDownTraversal(bool useTopDownTraversal) {
251 useTopDownTraversal);
252 }
253
254 void enableFolding(bool enable) {
256 }
257
258 void setStrictness(MlirGreedyRewriteStrictness strictness) {
260 }
261
262 void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) {
264 }
265
266 void enableConstantCSE(bool enable) {
268 }
269
270 int64_t getMaxIterations() {
272 }
273
274 int64_t getMaxNumRewrites() {
276 }
277
278 bool getUseTopDownTraversal() {
280 }
281
282 bool isFoldingEnabled() {
284 }
285
286 MlirGreedyRewriteStrictness getStrictness() {
288 }
289
290 MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() {
292 }
293
294 bool isConstantCSEEnabled() {
296 }
297
298private:
299 MlirGreedyRewriteDriverConfig config;
300};
301
302} // namespace
303
304/// Create the `mlir.rewrite` here.
305void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
306 // Enum definitions
307 nb::enum_<MlirGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
309 .value("EXISTING_AND_NEW_OPS",
311 .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
312
313 nb::enum_<MlirGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
317
318 //----------------------------------------------------------------------------
319 // Mapping of the PatternRewriter
320 //----------------------------------------------------------------------------
321 nb::
322 class_<PyPatternRewriter>(m, "PatternRewriter")
323 .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
324 "The current insertion point of the PatternRewriter.")
325 .def(
326 "replace_op",
327 [](PyPatternRewriter &self, MlirOperation op,
328 MlirOperation newOp) { self.replaceOp(op, newOp); },
329 "Replace an operation with a new operation.", nb::arg("op"),
330 nb::arg("new_op"),
331 // clang-format off
332 nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
333 // clang-format on
334 )
335 .def(
336 "replace_op",
337 [](PyPatternRewriter &self, MlirOperation op,
338 const std::vector<MlirValue> &values) {
339 self.replaceOp(op, values);
340 },
341 "Replace an operation with a list of values.", nb::arg("op"),
342 nb::arg("values"),
343 // clang-format off
344 nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
345 // clang-format on
346 )
347 .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
348 nb::arg("op"),
349 // clang-format off
350 nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
351 // clang-format on
352 );
353
354 //----------------------------------------------------------------------------
355 // Mapping of the RewritePatternSet
356 //----------------------------------------------------------------------------
357 nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
358 .def(
359 "__init__",
360 [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
361 new (&self) PyRewritePatternSet(context.get()->get());
362 },
363 "context"_a = nb::none())
364 .def(
365 "add",
366 [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
367 unsigned benefit) {
368 std::string opName;
369 if (root.is_type()) {
370 opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
371 } else if (nb::isinstance<nb::str>(root)) {
372 opName = nb::cast<std::string>(root);
373 } else {
374 throw nb::type_error(
375 "the root argument must be a type or a string");
376 }
377 self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
378 fn);
379 },
380 "root"_a, "fn"_a, "benefit"_a = 1,
381 // clang-format off
382 nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
383 // clang-format on
384 R"(
385 Add a new rewrite pattern on the specified root operation, using the provided callable
386 for matching and rewriting, and assign it the given benefit.
387
388 Args:
389 root: The root operation to which this pattern applies.
390 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
391 an operation name string (e.g., ``"arith.addi"``).
392 fn: The callable to use for matching and rewriting,
393 which takes an operation and a pattern rewriter as arguments.
394 The match is considered successful iff the callable returns
395 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
396 If possible, the operation is cast to its corresponding OpView subclass
397 before being passed to the callable.
398 benefit: The benefit of the pattern, defaulting to 1.)")
399 .def("freeze", &PyRewritePatternSet::freeze,
400 "Freeze the pattern set into a frozen one.");
401
402 //----------------------------------------------------------------------------
403 // Mapping of the PDLResultList and PDLModule
404 //----------------------------------------------------------------------------
405#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
406 nb::class_<MlirPDLResultList>(m, "PDLResultList")
407 .def(
408 "append",
409 [](MlirPDLResultList results, const PyValue &value) {
410 mlirPDLResultListPushBackValue(results, value);
411 },
412 // clang-format off
413 nb::sig("def append(self, value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
414 // clang-format on
415 )
416 .def(
417 "append",
418 [](MlirPDLResultList results, const PyOperation &op) {
419 mlirPDLResultListPushBackOperation(results, op);
420 },
421 // clang-format off
422 nb::sig("def append(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
423 // clang-format on
424 )
425 .def(
426 "append",
427 [](MlirPDLResultList results, const PyType &type) {
428 mlirPDLResultListPushBackType(results, type);
429 },
430 // clang-format off
431 nb::sig("def append(self, type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
432 // clang-format on
433 )
434 .def(
435 "append",
436 [](MlirPDLResultList results, const PyAttribute &attr) {
437 mlirPDLResultListPushBackAttribute(results, attr);
438 },
439 // clang-format off
440 nb::sig("def append(self, attr: " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
441 // clang-format on
442 );
443 nb::class_<PyPDLPatternModule>(m, "PDLModule")
444 .def(
445 "__init__",
446 [](PyPDLPatternModule &self, MlirModule module) {
447 new (&self)
448 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
449 },
450 // clang-format off
451 nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
452 // clang-format on
453 "module"_a, "Create a PDL module from the given module.")
454 .def(
455 "__init__",
456 [](PyPDLPatternModule &self, PyModule &module) {
457 new (&self) PyPDLPatternModule(
458 mlirPDLPatternModuleFromModule(module.get()));
459 },
460 // clang-format off
461 nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
462 // clang-format on
463 "module"_a, "Create a PDL module from the given module.")
464 .def(
465 "freeze",
466 [](PyPDLPatternModule &self) {
467 return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
468 mlirRewritePatternSetFromPDLPatternModule(self.get())));
469 },
470 nb::keep_alive<0, 1>())
471 .def(
472 "register_rewrite_function",
473 [](PyPDLPatternModule &self, const std::string &name,
474 const nb::callable &fn) {
475 self.registerRewriteFunction(name, fn);
476 },
477 nb::keep_alive<1, 3>())
478 .def(
479 "register_constraint_function",
480 [](PyPDLPatternModule &self, const std::string &name,
481 const nb::callable &fn) {
482 self.registerConstraintFunction(name, fn);
483 },
484 nb::keep_alive<1, 3>());
485#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
486
487 nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
488 .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
489 .def_prop_rw("max_iterations",
490 &PyGreedyRewriteDriverConfig::getMaxIterations,
491 &PyGreedyRewriteDriverConfig::setMaxIterations,
492 "Maximum number of iterations")
493 .def_prop_rw("max_num_rewrites",
494 &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
495 &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
496 "Maximum number of rewrites per iteration")
497 .def_prop_rw("use_top_down_traversal",
498 &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
499 &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
500 "Whether to use top-down traversal")
501 .def_prop_rw("enable_folding",
502 &PyGreedyRewriteDriverConfig::isFoldingEnabled,
503 &PyGreedyRewriteDriverConfig::enableFolding,
504 "Enable or disable folding")
505 .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
506 &PyGreedyRewriteDriverConfig::setStrictness,
507 "Rewrite strictness level")
508 .def_prop_rw("region_simplification_level",
509 &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
510 &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
511 "Region simplification level")
512 .def_prop_rw("enable_constant_cse",
513 &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
514 &PyGreedyRewriteDriverConfig::enableConstantCSE,
515 "Enable or disable constant CSE");
516
517 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
518 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
519 &PyFrozenRewritePatternSet::getCapsule)
521 &PyFrozenRewritePatternSet::createFromCapsule);
522 m.def(
523 "apply_patterns_and_fold_greedily",
524 [](PyModule &module, PyFrozenRewritePatternSet &set) {
526 module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
527 if (mlirLogicalResultIsFailure(status))
528 throw std::runtime_error("pattern application failed to converge");
529 },
530 "module"_a, "set"_a,
531 // clang-format off
532 nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
533 // clang-format on
534 "Applys the given patterns to the given module greedily while folding "
535 "results.")
536 .def(
537 "apply_patterns_and_fold_greedily",
538 [](PyModule &module, MlirFrozenRewritePatternSet set) {
541 if (mlirLogicalResultIsFailure(status))
542 throw std::runtime_error(
543 "pattern application failed to converge");
544 },
545 "module"_a, "set"_a,
546 // clang-format off
547 nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
548 // clang-format on
549 "Applys the given patterns to the given module greedily while "
550 "folding "
551 "results.")
552 .def(
553 "apply_patterns_and_fold_greedily",
554 [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
556 op.getOperation(), set.get(),
558 if (mlirLogicalResultIsFailure(status))
559 throw std::runtime_error(
560 "pattern application failed to converge");
561 },
562 "op"_a, "set"_a,
563 // clang-format off
564 nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
565 // clang-format on
566 "Applys the given patterns to the given op greedily while folding "
567 "results.")
568 .def(
569 "apply_patterns_and_fold_greedily",
570 [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
573 if (mlirLogicalResultIsFailure(status))
574 throw std::runtime_error(
575 "pattern application failed to converge");
576 },
577 "op"_a, "set"_a,
578 // clang-format off
579 nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
580 // clang-format on
581 "Applys the given patterns to the given op greedily while folding "
582 "results.")
583 .def(
584 "walk_and_apply_patterns",
585 [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
586 mlirWalkAndApplyPatterns(op.getOperation(), set.get());
587 },
588 "op"_a, "set"_a,
589 // clang-format off
590 nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
591 // clang-format on
592 "Applies the given patterns to the given op by a fast walk-based "
593 "driver.");
594}
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:651
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
Definition Interop.h:302
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Definition Interop.h:293
#define add(a, b)
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRModule.h:273
ReferrentTy * get() const
Wrapper around the generic MlirAttribute.
Definition IRModule.h:1008
MlirContext get()
Accesses the underlying MlirContext.
Definition IRModule.h:204
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:568
MlirModule get()
Gets the backing MlirModule.
Definition IRModule.h:522
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRModule.h:552
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:1063
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1434
Wrapper around the generic MlirType.
Definition IRModule.h:878
Wrapper around the generic MlirValue.
Definition IRModule.h:1167
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements...
Definition Rewrite.cpp:134
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
Definition Rewrite.cpp:363
MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter)
Returns the operation right after the current insertion point of the rewriter.
Definition Rewrite.cpp:75
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
Definition Rewrite.cpp:402
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
Definition Rewrite.cpp:383
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
Definition Rewrite.cpp:299
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
Definition Rewrite.cpp:416
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
Definition Rewrite.cpp:513
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
Definition Rewrite.cpp:437
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
Definition Rewrite.cpp:346
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
Definition Rewrite.cpp:422
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
Definition Rewrite.cpp:303
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
Definition Rewrite.cpp:318
MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
Definition Rewrite.cpp:148
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:308
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
Definition Rewrite.cpp:487
MLIR_CAPI_EXPORTED MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
RewriterBase API inherited from OpBuilder.
Definition Rewrite.cpp:29
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
Definition Rewrite.cpp:142
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:368
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:388
MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
Definition Rewrite.cpp:66
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:328
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
Definition Rewrite.cpp:509
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
Definition Rewrite.cpp:446
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition Rewrite.cpp:430
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
Definition Rewrite.cpp:378
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
Definition Rewrite.cpp:505
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:313
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
Definition Rewrite.cpp:323
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
Definition Rewrite.cpp:281
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
Definition Rewrite.cpp:275
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
Definition Rewrite.h:51
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
Definition Rewrite.h:53
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g. dead args elimination).
Definition Rewrite.h:55
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g. block merging).
Definition Rewrite.h:57
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
Definition Rewrite.h:41
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
Definition Rewrite.h:45
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
Definition Rewrite.h:47
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions wrt. which ops are processed.
Definition Rewrite.h:43
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:373
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1156
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:987
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:82
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:138
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:119
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:132
struct MlirStringRef MlirStringRef
Definition Support.h:77
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:127
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRModule.h:190
void populateRewriteSubmodule(nanobind::module_ &m)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrit...
Definition Rewrite.h:449
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Rewrite.h:442
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Rewrite.h:445