MLIR 23.0.0git
Rewrite.cpp
Go to the documentation of this file.
1//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Rewrite.h"
10
11#include "mlir-c/IR.h"
12#include "mlir-c/Rewrite.h"
13#include "mlir-c/Support.h"
15// clang-format off
17#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
18// clang-format on
19#include "mlir/Config/mlir-config.h"
20#include "nanobind/nanobind.h"
21
22namespace nb = nanobind;
23using namespace mlir;
24using namespace nb::literals;
26
27namespace mlir {
28namespace python {
30
32public:
33 PyPatternRewriter(MlirPatternRewriter rewriter)
34 : base(mlirPatternRewriterAsBase(rewriter)),
35 ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
36
38 MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
39 MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
40
41 if (mlirOperationIsNull(op)) {
42 MlirOperation owner = mlirBlockGetParentOperation(block);
43 auto parent = PyOperation::forOperation(ctx, owner);
44 return PyInsertionPoint(PyBlock(parent, block));
45 }
46
48 }
49
50 void replaceOp(MlirOperation op, MlirOperation newOp) {
52 }
53
54 void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
55 mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
56 }
57
58 void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
59
60private:
61 MlirRewriterBase base;
63};
64
66
67#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
68static nb::object objectFromPDLValue(MlirPDLValue value) {
69 if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
70 return nb::cast(v);
71 if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
72 return nb::cast(v);
73 if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
74 return nb::cast(v);
75 if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
76 return nb::cast(v);
77
78 throw std::runtime_error("unsupported PDL value type");
79}
80
81static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
82 MlirPDLValue *values) {
83 std::vector<nb::object> args;
84 args.reserve(nValues);
85 for (size_t i = 0; i < nValues; ++i)
86 args.push_back(objectFromPDLValue(values[i]));
87 return args;
88}
89
90// Convert the Python object to a boolean.
91// If it evaluates to False, treat it as success;
92// otherwise, treat it as failure.
93// Note that None is considered success.
94static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
95 if (obj.is_none())
97
98 return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
100}
101
102/// Owning Wrapper around a PDLPatternModule.
103class PyPDLPatternModule {
104public:
105 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
106 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
107 : module(other.module) {
108 other.module.ptr = nullptr;
109 }
110 ~PyPDLPatternModule() {
111 if (module.ptr != nullptr)
112 mlirPDLPatternModuleDestroy(module);
113 }
114 MlirPDLPatternModule get() { return module; }
115
116 void registerRewriteFunction(const std::string &name,
117 const nb::callable &fn) {
118 mlirPDLPatternModuleRegisterRewriteFunction(
119 get(), mlirStringRefCreate(name.data(), name.size()),
120 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
121 size_t nValues, MlirPDLValue *values,
122 void *userData) -> MlirLogicalResult {
123 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
124 return logicalResultFromObject(
125 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
126 objectsFromPDLValues(nValues, values)));
127 },
128 fn.ptr());
129 }
130
131 void registerConstraintFunction(const std::string &name,
132 const nb::callable &fn) {
133 mlirPDLPatternModuleRegisterConstraintFunction(
134 get(), mlirStringRefCreate(name.data(), name.size()),
135 [](MlirPatternRewriter rewriter, MlirPDLResultList results,
136 size_t nValues, MlirPDLValue *values,
137 void *userData) -> MlirLogicalResult {
138 nb::handle f = nb::handle(static_cast<PyObject *>(userData));
139 return logicalResultFromObject(
140 f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
141 objectsFromPDLValues(nValues, values)));
142 },
143 fn.ptr());
144 }
145
146private:
147 MlirPDLPatternModule module;
148};
149#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
150
151/// Owning Wrapper around a FrozenRewritePatternSet.
153public:
154 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
156 : set(other.set) {
157 other.set.ptr = nullptr;
158 }
160 if (set.ptr != nullptr)
162 }
163 MlirFrozenRewritePatternSet get() { return set; }
164
165 nb::object getCapsule() {
166 return nb::steal<nb::object>(
168 }
169
170 static nb::object createFromCapsule(const nb::object &capsule) {
171 MlirFrozenRewritePatternSet rawPm =
173 if (rawPm.ptr == nullptr)
174 throw nb::python_error();
175 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
176 }
177
178private:
179 MlirFrozenRewritePatternSet set;
180};
181
183public:
184 PyRewritePatternSet(MlirContext ctx)
185 : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
187 if (set.ptr)
189 }
190
191 void add(MlirStringRef rootName, unsigned benefit,
192 const nb::callable &matchAndRewrite) {
194 callbacks.construct = [](void *userData) {
195 nb::handle(static_cast<PyObject *>(userData)).inc_ref();
196 };
197 callbacks.destruct = [](void *userData) {
198 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
199 };
200 callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
201 MlirPatternRewriter rewriter,
202 void *userData) -> MlirLogicalResult {
203 nb::handle f(static_cast<PyObject *>(userData));
204
205 PyMlirContextRef ctx =
207 nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
208
209 nb::object res = f(opView, PyPatternRewriter(rewriter));
210 return logicalResultFromObject(res);
211 };
212 MlirRewritePattern pattern = mlirOpRewritePatternCreate(
213 rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
214 /* nGeneratedNames */ 0,
215 /* generatedNames */ nullptr);
216 mlirRewritePatternSetAdd(set, pattern);
217 }
218
220 MlirRewritePatternSet s = set;
221 set.ptr = nullptr;
222 return mlirFreezeRewritePattern(s);
223 }
224
225private:
226 MlirRewritePatternSet set;
227 MlirContext ctx;
228};
229
236
243
244/// Owning Wrapper around a GreedyRewriteDriverConfig.
246public:
251 : config(std::move(other.config)) {}
253 : config(other.config) {}
254
255 MlirGreedyRewriteDriverConfig get() {
256 return MlirGreedyRewriteDriverConfig{config.get()};
257 }
258
259 void setMaxIterations(int64_t maxIterations) {
261 }
262
263 void setMaxNumRewrites(int64_t maxNumRewrites) {
265 }
266
267 void setUseTopDownTraversal(bool useTopDownTraversal) {
269 useTopDownTraversal);
270 }
271
272 void enableFolding(bool enable) {
274 }
275
280
285
289
293
297
301
305
310
315
319
320private:
321 std::shared_ptr<void> config;
322 static void customDeleter(void *c) {
323 mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
324 }
325};
326
327/// Create the `mlir.rewrite` here.
328void populateRewriteSubmodule(nb::module_ &m) {
329 // Enum definitions
330 nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
331 .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP)
332 .value("EXISTING_AND_NEW_OPS",
334 .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS);
335
336 nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
337 .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
339 .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
340 //----------------------------------------------------------------------------
341 // Mapping of the PatternRewriter
342 //----------------------------------------------------------------------------
343 nb::class_<PyPatternRewriter>(m, "PatternRewriter")
344 .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
345 "The current insertion point of the PatternRewriter.")
346 .def(
347 "replace_op",
349 PyOperationBase &newOp) {
350 self.replaceOp(op.getOperation(), newOp.getOperation());
351 },
352 "Replace an operation with a new operation.", nb::arg("op"),
353 nb::arg("new_op"))
354 .def(
355 "replace_op",
357 const std::vector<PyValue> &values) {
358 std::vector<MlirValue> values_(values.size());
359 std::copy(values.begin(), values.end(), values_.begin());
360 self.replaceOp(op.getOperation(), values_);
361 },
362 "Replace an operation with a list of values.", nb::arg("op"),
363 nb::arg("values"))
364 .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
365 nb::arg("op"));
366
367 //----------------------------------------------------------------------------
368 // Mapping of the RewritePatternSet
369 //----------------------------------------------------------------------------
370 nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
371 .def(
372 "__init__",
374 new (&self) PyRewritePatternSet(context.get()->get());
375 },
376 "context"_a = nb::none())
377 .def(
378 "add",
379 [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
380 unsigned benefit) {
381 std::string opName;
382 if (root.is_type()) {
383 opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
384 } else if (nb::isinstance<nb::str>(root)) {
385 opName = nb::cast<std::string>(root);
386 } else {
387 throw nb::type_error(
388 "the root argument must be a type or a string");
389 }
390 self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
391 fn);
392 },
393 "root"_a, "fn"_a, "benefit"_a = 1,
394 // clang-format off
395 nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
396 // clang-format on
397 R"(
398 Add a new rewrite pattern on the specified root operation, using the provided callable
399 for matching and rewriting, and assign it the given benefit.
400
401 Args:
402 root: The root operation to which this pattern applies.
403 This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
404 an operation name string (e.g., ``"arith.addi"``).
405 fn: The callable to use for matching and rewriting,
406 which takes an operation and a pattern rewriter as arguments.
407 The match is considered successful iff the callable returns
408 a value where ``bool(value)`` is ``False`` (e.g. ``None``).
409 If possible, the operation is cast to its corresponding OpView subclass
410 before being passed to the callable.
411 benefit: The benefit of the pattern, defaulting to 1.)")
412 .def("freeze", &PyRewritePatternSet::freeze,
413 "Freeze the pattern set into a frozen one.");
414
415 //----------------------------------------------------------------------------
416 // Mapping of the PDLResultList and PDLModule
417 //----------------------------------------------------------------------------
418#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
419 nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
420 .def("append",
421 [](PyMlirPDLResultList results, const PyValue &value) {
422 mlirPDLResultListPushBackValue(results, value);
423 })
424 .def("append",
425 [](PyMlirPDLResultList results, const PyOperation &op) {
426 mlirPDLResultListPushBackOperation(results, op);
427 })
428 .def("append",
429 [](PyMlirPDLResultList results, const PyType &type) {
430 mlirPDLResultListPushBackType(results, type);
431 })
432 .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
433 mlirPDLResultListPushBackAttribute(results, attr);
434 });
435 nb::class_<PyPDLPatternModule>(m, "PDLModule")
436 .def(
437 "__init__",
438 [](PyPDLPatternModule &self, PyModule &module) {
439 new (&self) PyPDLPatternModule(
440 mlirPDLPatternModuleFromModule(module.get()));
441 },
442 "module"_a, "Create a PDL module from the given module.")
443 .def(
444 "__init__",
445 [](PyPDLPatternModule &self, PyModule &module) {
446 new (&self) PyPDLPatternModule(
447 mlirPDLPatternModuleFromModule(module.get()));
448 },
449 "module"_a, "Create a PDL module from the given module.")
450 .def(
451 "freeze",
452 [](PyPDLPatternModule &self) {
454 mlirRewritePatternSetFromPDLPatternModule(self.get())));
455 },
456 nb::keep_alive<0, 1>())
457 .def(
458 "register_rewrite_function",
459 [](PyPDLPatternModule &self, const std::string &name,
460 const nb::callable &fn) {
461 self.registerRewriteFunction(name, fn);
462 },
463 nb::keep_alive<1, 3>())
464 .def(
465 "register_constraint_function",
466 [](PyPDLPatternModule &self, const std::string &name,
467 const nb::callable &fn) {
468 self.registerConstraintFunction(name, fn);
469 },
470 nb::keep_alive<1, 3>());
471#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
472
473 nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
474 .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
475 .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
477 "Maximum number of iterations")
478 .def_prop_rw("max_num_rewrites",
481 "Maximum number of rewrites per iteration")
482 .def_prop_rw("use_top_down_traversal",
485 "Whether to use top-down traversal")
486 .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
488 "Enable or disable folding")
489 .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
491 "Rewrite strictness level")
492 .def_prop_rw("region_simplification_level",
495 "Region simplification level")
496 .def_prop_rw("enable_constant_cse",
499 "Enable or disable constant CSE");
500
501 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
502 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
506 m.def(
507 "apply_patterns_and_fold_greedily",
508 [](PyModule &module, PyFrozenRewritePatternSet &set,
509 std::optional<PyGreedyRewriteConfig> config) {
511 module.get(), set.get(),
512 config.has_value() ? config->get()
514 if (mlirLogicalResultIsFailure(status))
515 throw std::runtime_error("pattern application failed to converge");
516 },
517 "module"_a, "set"_a, "config"_a = nb::none(),
518 "Applys the given patterns to the given module greedily while folding "
519 "results.")
520 .def(
521 "apply_patterns_and_fold_greedily",
523 std::optional<PyGreedyRewriteConfig> config) {
525 op.getOperation(), set.get(),
526 config.has_value() ? config->get()
528 if (mlirLogicalResultIsFailure(status))
529 throw std::runtime_error(
530 "pattern application failed to converge");
531 },
532 "op"_a, "set"_a, "config"_a = nb::none(),
533 "Applys the given patterns to the given op greedily while folding "
534 "results.")
535 .def(
536 "walk_and_apply_patterns",
539 },
540 "op"_a, "set"_a,
541 "Applies the given patterns to the given op by a fast walk-based "
542 "driver.");
543}
544} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
545} // namespace python
546} // namespace mlir
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:651
static MlirFrozenRewritePatternSet mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule)
Extracts an MlirFrozenRewritePatternSet from a capsule as produced from mlirPythonFrozenRewritePatter...
Definition Interop.h:302
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
static PyObject * mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm)
Creates a capsule object encapsulating the raw C-API MlirFrozenRewritePatternSet.
Definition Interop.h:293
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:280
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1009
Owning Wrapper around a FrozenRewritePatternSet.
Definition Rewrite.cpp:152
static nb::object createFromCapsule(const nb::object &capsule)
Definition Rewrite.cpp:170
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
Definition Rewrite.cpp:155
PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
Definition Rewrite.cpp:252
void setStrictness(PyGreedyRewriteStrictness strictness)
Definition Rewrite.cpp:276
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level)
Definition Rewrite.cpp:281
PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
Definition Rewrite.cpp:250
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:834
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:483
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:213
MlirModule get()
Gets the backing MlirModule.
Definition IRCore.h:549
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:579
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1376
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:980
void replaceOp(MlirOperation op, const std::vector< MlirValue > &values)
Definition Rewrite.cpp:54
void replaceOp(MlirOperation op, MlirOperation newOp)
Definition Rewrite.cpp:50
void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite)
Definition Rewrite.cpp:191
Wrapper around the generic MlirType.
Definition IRCore.h:876
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements...
Definition Rewrite.cpp:134
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables constant CSE.
Definition Rewrite.cpp:363
MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter)
Returns the operation right after the current insertion point of the rewriter.
Definition Rewrite.cpp:75
MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config)
Gets the region simplification level.
Definition Rewrite.cpp:402
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether folding is enabled during greedy rewriting.
Definition Rewrite.cpp:383
MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate(void)
GreedyRewriteDriverConfig API.
Definition Rewrite.cpp:299
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(MlirGreedyRewriteDriverConfig config)
Gets whether constant CSE is enabled.
Definition Rewrite.cpp:416
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern)
Add the given MlirRewritePattern into a MlirRewritePatternSet.
Definition Rewrite.cpp:513
MLIR_CAPI_EXPORTED void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns)
Applies the given patterns to the given op by a fast walk-based pattern rewrite driver.
Definition Rewrite.cpp:437
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level)
Sets the region simplification level.
Definition Rewrite.cpp:346
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config)
Definition Rewrite.cpp:422
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config)
Destroys a greedy rewrite driver configuration.
Definition Rewrite.cpp:303
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal)
Sets whether to use top-down traversal for the initial population of the worklist.
Definition Rewrite.cpp:318
MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
Definition Rewrite.cpp:148
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(MlirGreedyRewriteDriverConfig config, int64_t maxIterations)
Sets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:308
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames)
Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::O...
Definition Rewrite.cpp:487
MLIR_CAPI_EXPORTED MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
RewriterBase API inherited from OpBuilder.
Definition Rewrite.cpp:29
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
Definition Rewrite.cpp:142
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of iterations for the greedy rewrite driver.
Definition Rewrite.cpp:368
MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(MlirGreedyRewriteDriverConfig config)
Gets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:388
MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
Definition Rewrite.cpp:66
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness)
Sets the strictness level for the greedy rewrite driver.
Definition Rewrite.cpp:328
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set)
Destruct the given MlirRewritePatternSet.
Definition Rewrite.cpp:509
MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter)
PatternRewriter API.
Definition Rewrite.cpp:446
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig)
Definition Rewrite.cpp:430
MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(MlirGreedyRewriteDriverConfig config)
Gets whether top-down traversal is used for initial worklist population.
Definition Rewrite.cpp:378
MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context)
RewritePatternSet API.
Definition Rewrite.cpp:505
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites)
Sets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:313
MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(MlirGreedyRewriteDriverConfig config, bool enable)
Enables or disables folding during greedy rewriting.
Definition Rewrite.cpp:323
MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set)
Destroy the given MlirFrozenRewritePatternSet.
Definition Rewrite.cpp:281
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set)
FrozenRewritePatternSet API.
Definition Rewrite.cpp:275
MlirGreedySimplifyRegionLevel
Greedy simplify region levels.
Definition Rewrite.h:51
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED
Disable region control-flow simplification.
Definition Rewrite.h:53
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL
Run the normal simplification (e.g. dead args elimination).
Definition Rewrite.h:55
@ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
Run extra simplifications (e.g. block merging).
Definition Rewrite.h:57
MlirGreedyRewriteStrictness
Greedy rewrite strictness levels.
Definition Rewrite.h:41
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS
Only pre-existing and newly created ops are processed.
Definition Rewrite.h:45
@ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
Only pre-existing ops are processed.
Definition Rewrite.h:47
@ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP
No restrictions wrt. which ops are processed.
Definition Rewrite.h:43
MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(MlirGreedyRewriteDriverConfig config)
Gets the maximum number of rewrites within an iteration.
Definition Rewrite.cpp:373
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1156
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:987
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:84
static MlirLogicalResult mlirLogicalResultFailure(void)
Creates a logical result representing a failure.
Definition Support.h:140
struct MlirLogicalResult MlirLogicalResult
Definition Support.h:121
static MlirLogicalResult mlirLogicalResultSuccess(void)
Creates a logical result representing a success.
Definition Support.h:134
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:129
void populateRewriteSubmodule(nb::module_ &m)
Create the mlir.rewrite here.
Definition Rewrite.cpp:328
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:199
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A logical result value, essentially a boolean with named states.
Definition Support.h:118
RewritePattern API.
Definition Rewrite.h:439
MlirLogicalResult(* matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData)
The callback function to match against code rooted at the specified operation, and perform the rewrit...
Definition Rewrite.h:449
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Rewrite.h:442
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Rewrite.h:445
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:75