MLIR 23.0.0git
Rewrite.h
Go to the documentation of this file.
1//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
10#define MLIR_BINDINGS_PYTHON_REWRITE_H
11
12#include "mlir-c/Rewrite.h"
14
15#include <nanobind/nanobind.h>
16
17namespace mlir {
18namespace python {
20
21/// CRTP Base class for rewriter wrappers.
22template <typename DerivedTy>
24public:
26 : base(rewriter),
27 ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
28
30 MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
31 MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
32
33 if (mlirOperationIsNull(op)) {
34 MlirOperation owner = mlirBlockGetParentOperation(block);
35 auto parent = PyOperation::forOperation(ctx, owner);
36 return PyInsertionPoint(PyBlock(parent, block));
37 }
38
40 }
41
42 static void bind(nanobind::module_ &m) {
43 nanobind::class_<DerivedTy>(m, DerivedTy::pyClassName)
44 .def_prop_ro("ip", &PyRewriterBase::getInsertionPoint,
45 "The current insertion point of the PatternRewriter.")
46 .def(
47 "replace_op",
48 [](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
50 self.base, op.getOperation(), newOp.getOperation());
51 },
52 "Replace an operation with a new operation.", nanobind::arg("op"),
53 nanobind::arg("new_op"))
54 .def(
55 "replace_op",
56 [](DerivedTy &self, PyOperationBase &op,
57 const std::vector<PyValue> &values) {
58 std::vector<MlirValue> values_(values.size());
59 std::copy(values.begin(), values.end(), values_.begin());
61 self.base, op.getOperation(), values_.size(), values_.data());
62 },
63 "Replace an operation with a list of values.", nanobind::arg("op"),
64 nanobind::arg("values"))
65 .def(
66 "erase_op",
67 [](DerivedTy &self, PyOperationBase &op) {
69 },
70 "Erase an operation.", nanobind::arg("op"));
71 }
72
73private:
76};
77
79} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
80} // namespace python
81} // namespace mlir
82
83#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
An insertion point maintains a pointer to a Block and a reference operation.
Definition IRCore.h:833
Base class for PyOperation and PyOpView which exposes the primary, user-visible methods for manipulating it.
Definition IRCore.h:578
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values)
Replace the results of the given (original) operation with the specified list of values (replacements...
Definition Rewrite.cpp:136
MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter)
Returns the operation right after the current insertion point of the rewriter.
Definition Rewrite.cpp:77
MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op)
Erases an operation that is known to have no uses.
Definition Rewrite.cpp:150
MLIR_CAPI_EXPORTED MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter)
RewriterBase API inherited from OpBuilder.
Definition Rewrite.cpp:31
MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp)
Replace the results of the given (original) operation with the specified new op (replacement).
Definition Rewrite.cpp:144
MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter)
Return the block the current insertion point belongs to.
Definition Rewrite.cpp:68
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock)
Returns the closest surrounding operation that contains this block.
Definition IR.cpp:990
#define MLIR_PYTHON_API_EXPORTED
Definition Support.h:49
void populateRewriteSubmodule(nb::module_ &m)
Create the mlir.rewrite submodule here.
Definition Rewrite.cpp:467
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:198
Include the generated interface declarations.