20using namespace nb::literals;
29 PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
30 PyPassManager(PyPassManager &&other) noexcept
31 : passManager(other.passManager) {
32 other.passManager.ptr =
nullptr;
38 MlirPassManager
get() {
return passManager; }
40 void release() { passManager.ptr =
nullptr; }
41 nb::object getCapsule() {
45 static nb::object createFromCapsule(
const nb::object &capsule) {
48 throw nb::python_error();
49 return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
53 MlirPassManager passManager;
63 nb::enum_<MlirPassDisplayMode>(m,
"PassDisplayMode")
70 nb::class_<MlirExternalPass>(m,
"ExternalPass")
71 .def(
"signal_pass_failure",
77 nb::class_<PyPassManager>(m,
"PassManager")
80 [](PyPassManager &self,
const std::string &anchorOp,
85 new (&self) PyPassManager(passManager);
87 "anchor_op"_a = nb::str(
"any"),
"context"_a = nb::none(),
89 nb::sig(
"def __init__(self, anchor_op: str = 'any', context: " MAKE_MLIR_PYTHON_QUALNAME(
"ir.Context")
" | None = None) -> None"),
91 "Create a new PassManager for the current (or provided) Context.")
94 .def(
"_testing_release", &PyPassManager::release,
95 "Releases (leaks) the backing pass manager (testing)")
98 [](PyPassManager &passManager,
bool printBeforeAll,
99 bool printAfterAll,
bool printModuleScope,
bool printAfterChange,
100 bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
101 std::optional<int64_t> largeResourceLimit,
bool enableDebugInfo,
102 bool printGenericOpForm,
103 std::optional<std::string> optionalTreePrintingPath) {
105 if (largeElementsLimit) {
107 *largeElementsLimit);
109 *largeElementsLimit);
111 if (largeResourceLimit)
113 *largeResourceLimit);
117 if (printGenericOpForm)
119 std::string treePrintingPath =
"";
120 if (optionalTreePrintingPath.has_value())
121 treePrintingPath = optionalTreePrintingPath.value();
123 passManager.get(), printBeforeAll, printAfterAll,
124 printModuleScope, printAfterChange, printAfterFailure, flags,
126 treePrintingPath.size()));
129 "print_before_all"_a =
false,
"print_after_all"_a =
true,
130 "print_module_scope"_a =
false,
"print_after_change"_a =
false,
131 "print_after_failure"_a =
false,
132 "large_elements_limit"_a = nb::none(),
133 "large_resource_limit"_a = nb::none(),
"enable_debug_info"_a =
false,
134 "print_generic_op_form"_a =
false,
135 "tree_printing_dir_path"_a = nb::none(),
136 "Enable IR printing, default as mlir-print-ir-after-all.")
139 [](PyPassManager &passManager,
bool enable) {
142 "enable"_a,
"Enable / disable verify-each.")
145 [](PyPassManager &passManager) {
148 "Enable pass timing.")
156 "Enable pass statistics.")
167 throw nb::value_error(errorMsg.
join().c_str());
168 return new PyPassManager(passManager);
170 "pipeline"_a,
"context"_a = nb::none(),
174 "Parse a textual pass-pipeline and return a top-level PassManager "
175 "that can be applied on a Module. Throw a ValueError if the pipeline "
179 [](PyPassManager &passManager,
const std::string &pipeline) {
186 throw nb::value_error(errorMsg.
join().c_str());
189 "Add textual pipeline elements to the pass manager. Throws a "
190 "ValueError if the pipeline can't be parsed.")
193 [](PyPassManager &passManager,
const nb::callable &run,
194 std::optional<std::string> &name,
const std::string &argument,
195 const std::string &description,
const std::string &opName) {
196 if (!name.has_value()) {
197 name = nb::cast<std::string>(
198 nb::borrow<nb::str>(run.attr(
"__name__")));
201 MlirExternalPassCallbacks callbacks;
202 callbacks.construct = [](
void *obj) {
203 (
void)nb::handle(
static_cast<PyObject *
>(obj)).inc_ref();
205 callbacks.destruct = [](
void *obj) {
206 (
void)nb::handle(
static_cast<PyObject *
>(obj)).dec_ref();
208 callbacks.initialize =
nullptr;
209 callbacks.clone = [](
void *) ->
void * {
210 throw std::runtime_error(
"Cloning Python passes not supported");
212 callbacks.run = [](MlirOperation op, MlirExternalPass pass,
214 nb::handle(
static_cast<PyObject *
>(userData))(op, pass);
222 callbacks, run.ptr());
225 "run"_a,
"name"_a.none() = nb::none(),
"argument"_a.none() =
"",
226 "description"_a.none() =
"",
"op_name"_a.none() =
"",
227 "Add a python-defined pass to the pass manager.")
236 throw MLIRError(
"Failure while executing pass pipeline",
243 "Run the pass manager on the provided operation, raising an "
244 "MLIRError on failure.")
247 [](PyPassManager &self) {
248 MlirPassManager passManager = self.get();
253 return printAccum.
join();
255 "Print the textual representation for this PassManager, suitable to "
256 "be passed to `parse` for round-tripping.");
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
MlirOpPrintingFlags mlirOpPrintingFlagsCreate()
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
#define MAKE_MLIR_PYTHON_QUALNAME(local)
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
static PyObject * mlirPythonPassManagerToCapsule(MlirPassManager pm)
Creates a capsule object encapsulating the raw C-API MlirPassManager.
PyMlirContextRef & getContext()
Accesses the context reference.
Used in function arguments when None should resolve to the current context manager set instance.
ReferrentTy * get() const
static PyGlobals & get()
Most code should get the globals via this static accessor.
MlirTypeID allocateTypeID()
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
MlirOperation get() const
MLIR_CAPI_EXPORTED void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
MLIR_CAPI_EXPORTED void mlirPassManagerEnableStatistics(MlirPassManager passManager, MlirPassDisplayMode displayMode)
Enable pass statistics.
MLIR_CAPI_EXPORTED MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
MLIR_CAPI_EXPORTED void mlirPassManagerEnableTiming(MlirPassManager passManager)
Enable pass timing.
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
MlirPassDisplayMode
Enumerated type of pass display modes.
@ MLIR_PASS_DISPLAY_MODE_LIST
@ MLIR_PASS_DISPLAY_MODE_PIPELINE
MLIR_CAPI_EXPORTED void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
void populatePassManagerSubmodule(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A logical result value, essentially a boolean with named states.
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
RAII object that captures any error diagnostics emitted to the provided context.