MLIR 22.0.0git
Pass.cpp
Go to the documentation of this file.
1//===- Pass.cpp - Pass Management -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Pass.h"
10
11#include "mlir-c/Pass.h"
14// clang-format off
16#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
17// clang-format on
18
19namespace nb = nanobind;
20using namespace nb::literals;
21using namespace mlir;
23
24namespace mlir {
25namespace python {
27
28/// Owning Wrapper around a PassManager.
30public:
31 PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
32 PyPassManager(PyPassManager &&other) noexcept
33 : passManager(other.passManager) {
34 other.passManager.ptr = nullptr;
35 }
37 if (!mlirPassManagerIsNull(passManager))
38 mlirPassManagerDestroy(passManager);
39 }
40 MlirPassManager get() { return passManager; }
41
42 void release() { passManager.ptr = nullptr; }
43 nb::object getCapsule() {
44 return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
45 }
46
47 static nb::object createFromCapsule(const nb::object &capsule) {
48 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
49 if (mlirPassManagerIsNull(rawPm))
50 throw nb::python_error();
51 return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
52 }
53
54private:
55 MlirPassManager passManager;
56};
57
62
64
65/// Create the `mlir.passmanager` here.
66void populatePassManagerSubmodule(nb::module_ &m) {
67 //----------------------------------------------------------------------------
68 // Mapping of enumerated types
69 //----------------------------------------------------------------------------
70 nb::enum_<PyMlirPassDisplayMode>(m, "PassDisplayMode")
71 .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
72 .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
73
74 //----------------------------------------------------------------------------
75 // Mapping of MlirExternalPass
76 //----------------------------------------------------------------------------
77 nb::class_<PyMlirExternalPass>(m, "ExternalPass")
78 .def("signal_pass_failure", [](PyMlirExternalPass pass) {
80 });
81
82 //----------------------------------------------------------------------------
83 // Mapping of the top-level PassManager
84 //----------------------------------------------------------------------------
85 nb::class_<PyPassManager>(m, "PassManager")
86 .def(
87 "__init__",
88 [](PyPassManager &self, const std::string &anchorOp,
90 MlirPassManager passManager = mlirPassManagerCreateOnOperation(
91 context->get(),
92 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
93 new (&self) PyPassManager(passManager);
94 },
95 "anchor_op"_a = nb::str("any"), "context"_a = nb::none(),
96 // clang-format off
97 nb::sig("def __init__(self, anchor_op: str = 'any', context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> None"),
98 // clang-format on
99 "Create a new PassManager for the current (or provided) Context.")
102 .def("_testing_release", &PyPassManager::release,
103 "Releases (leaks) the backing pass manager (testing)")
104 .def(
105 "enable_ir_printing",
106 [](PyPassManager &passManager, bool printBeforeAll,
107 bool printAfterAll, bool printModuleScope, bool printAfterChange,
108 bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
109 std::optional<int64_t> largeResourceLimit, bool enableDebugInfo,
110 bool printGenericOpForm,
111 std::optional<std::string> optionalTreePrintingPath) {
112 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
113 if (largeElementsLimit) {
115 *largeElementsLimit);
117 *largeElementsLimit);
118 }
119 if (largeResourceLimit)
121 *largeResourceLimit);
122 if (enableDebugInfo)
123 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
124 /*prettyForm=*/false);
125 if (printGenericOpForm)
127 std::string treePrintingPath = "";
128 if (optionalTreePrintingPath.has_value())
129 treePrintingPath = optionalTreePrintingPath.value();
131 passManager.get(), printBeforeAll, printAfterAll,
132 printModuleScope, printAfterChange, printAfterFailure, flags,
133 mlirStringRefCreate(treePrintingPath.data(),
134 treePrintingPath.size()));
136 },
137 "print_before_all"_a = false, "print_after_all"_a = true,
138 "print_module_scope"_a = false, "print_after_change"_a = false,
139 "print_after_failure"_a = false,
140 "large_elements_limit"_a = nb::none(),
141 "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
142 "print_generic_op_form"_a = false,
143 "tree_printing_dir_path"_a = nb::none(),
144 "Enable IR printing, default as mlir-print-ir-after-all.")
145 .def(
146 "enable_verifier",
147 [](PyPassManager &passManager, bool enable) {
148 mlirPassManagerEnableVerifier(passManager.get(), enable);
149 },
150 "enable"_a, "Enable / disable verify-each.")
151 .def(
152 "enable_timing",
153 [](PyPassManager &passManager) {
154 mlirPassManagerEnableTiming(passManager.get());
155 },
156 "Enable pass timing.")
157 .def(
158 "enable_statistics",
159 [](PyPassManager &passManager, PyMlirPassDisplayMode displayMode) {
161 passManager.get(),
162 static_cast<MlirPassDisplayMode>(displayMode));
163 },
164 "displayMode"_a = MLIR_PASS_DISPLAY_MODE_PIPELINE,
165 "Enable pass statistics.")
166 .def_static(
167 "parse",
168 [](const std::string &pipeline, DefaultingPyMlirContext context) {
169 MlirPassManager passManager = mlirPassManagerCreate(context->get());
170 PyPrintAccumulator errorMsg;
173 mlirStringRefCreate(pipeline.data(), pipeline.size()),
174 errorMsg.getCallback(), errorMsg.getUserData());
175 if (mlirLogicalResultIsFailure(status))
176 throw nb::value_error(errorMsg.join().c_str());
177 return new PyPassManager(passManager);
178 },
179 "pipeline"_a, "context"_a = nb::none(),
180 // clang-format off
181 nb::sig("def parse(pipeline: str, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> PassManager"),
182 // clang-format on
183 "Parse a textual pass-pipeline and return a top-level PassManager "
184 "that can be applied on a Module. Throw a ValueError if the pipeline "
185 "can't be parsed")
186 .def(
187 "add",
188 [](PyPassManager &passManager, const std::string &pipeline) {
189 PyPrintAccumulator errorMsg;
192 mlirStringRefCreate(pipeline.data(), pipeline.size()),
193 errorMsg.getCallback(), errorMsg.getUserData());
194 if (mlirLogicalResultIsFailure(status))
195 throw nb::value_error(errorMsg.join().c_str());
196 },
197 "pipeline"_a,
198 "Add textual pipeline elements to the pass manager. Throws a "
199 "ValueError if the pipeline can't be parsed.")
200 .def(
201 "add",
202 [](PyPassManager &passManager, const nb::callable &run,
203 std::optional<std::string> &name, const std::string &argument,
204 const std::string &description, const std::string &opName) {
205 if (!name.has_value()) {
206 name = nb::cast<std::string>(
207 nb::borrow<nb::str>(run.attr("__name__")));
208 }
209 MlirTypeID passID = PyGlobals::get().allocateTypeID();
210 MlirExternalPassCallbacks callbacks;
211 callbacks.construct = [](void *obj) {
212 (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
213 };
214 callbacks.destruct = [](void *obj) {
215 (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
216 };
217 callbacks.initialize = nullptr;
218 callbacks.clone = [](void *) -> void * {
219 throw std::runtime_error("Cloning Python passes not supported");
220 };
221 callbacks.run = [](MlirOperation op, MlirExternalPass pass,
222 void *userData) {
223 nb::handle(static_cast<PyObject *>(userData))(
224 op, PyMlirExternalPass{pass.ptr});
225 };
226 auto externalPass = mlirCreateExternalPass(
227 passID, mlirStringRefCreate(name->data(), name->length()),
228 mlirStringRefCreate(argument.data(), argument.length()),
229 mlirStringRefCreate(description.data(), description.length()),
230 mlirStringRefCreate(opName.data(), opName.size()),
231 /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
232 callbacks, /*userData*/ run.ptr());
233 mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
234 },
235 "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
236 "description"_a.none() = "", "op_name"_a.none() = "",
237 R"(
238 Add a python-defined pass to the current pipeline of the pass manager.
239
240 Args:
241 run: A callable with signature ``(op: ir.Operation, pass_: ExternalPass) -> None``.
242 Called when the pass executes. It receives the operation to be processed and
243 the current ``ExternalPass`` instance.
244 Use ``pass_.signal_pass_failure()`` to signal failure.
245 name: The name of the pass. Defaults to ``run.__name__``.
246 argument: The command-line argument for the pass. Defaults to empty.
247 description: The description of the pass. Defaults to empty.
248 op_name: The name of the operation this pass operates on.
249 It will be a generic operation pass if not specified.)")
250 .def(
251 "run",
252 [](PyPassManager &passManager, PyOperationBase &op) {
253 // Actually run the pass manager.
256 passManager.get(), op.getOperation().get());
257 if (mlirLogicalResultIsFailure(status))
258 throw MLIRError("Failure while executing pass pipeline",
259 errors.take());
260 },
261 "operation"_a,
262 // clang-format off
263 nb::sig("def run(self, operation: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ") -> None"),
264 // clang-format on
265 "Run the pass manager on the provided operation, raising an "
266 "MLIRError on failure.")
267 .def(
268 "__str__",
269 [](PyPassManager &self) {
270 MlirPassManager passManager = self.get();
271 PyPrintAccumulator printAccum;
274 printAccum.getCallback(), printAccum.getUserData());
275 return printAccum.join();
276 },
277 "Print the textual representation for this PassManager, suitable to "
278 "be passed to `parse` for round-tripping.");
279}
280} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
281} // namespace python
282} // namespace mlir
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition Pass.cpp:219
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Definition Pass.cpp:131
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Definition Pass.cpp:125
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
Definition IR.cpp:202
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Definition IR.cpp:210
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Definition IR.cpp:220
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Definition IR.cpp:215
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Definition IR.cpp:225
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Definition IR.cpp:206
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition Interop.h:110
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition Interop.h:57
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
Definition Interop.h:320
static PyObject * mlirPythonPassManagerToCapsule(MlirPassManager pm)
Creates a capsule object encapsulating the raw C-API MlirPassManager.
Definition Interop.h:311
ReferrentTy * get() const
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRCore.h:299
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:280
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.cpp:44
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition IRCore.h:578
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
Owning Wrapper around a PassManager.
Definition Pass.cpp:29
PyPassManager(PyPassManager &&other) noexcept
Definition Pass.cpp:32
static nb::object createFromCapsule(const nb::object &capsule)
Definition Pass.cpp:47
MLIR_CAPI_EXPORTED void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition Pass.cpp:75
MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass)
Signals that the pass has failed. This is only valid to call during the `run` callback of `MlirExternalPassCallbacks`.
Definition Pass.cpp:234
MLIR_CAPI_EXPORTED void mlirPassManagerEnableStatistics(MlirPassManager passManager, MlirPassDisplayMode displayMode)
Enable pass statistics.
Definition Pass.cpp:83
MLIR_CAPI_EXPORTED MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition Pass.cpp:39
MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition Pass.cpp:43
MLIR_CAPI_EXPORTED void mlirPassManagerEnableTiming(MlirPassManager passManager)
Enable pass timing.
Definition Pass.cpp:79
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition Pass.cpp:25
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
Definition Pass.cpp:48
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition Pass.cpp:116
MlirPassDisplayMode
Enumerated type of pass display modes.
Definition Pass.h:97
MLIR_CAPI_EXPORTED void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition Pass.cpp:107
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
Definition Pass.h:65
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition Pass.cpp:29
MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition Pass.cpp:34
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:84
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:129
void populatePassManagerSubmodule(nb::module_ &m)
Create the mlir.passmanager here.
Definition Pass.cpp:66
Include the generated interface declarations.
A logical result value, essentially a boolean with named states.
Definition Support.h:118
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition IRCore.h:1325
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRCore.h:434
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRCore.h:444