MLIR 19.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include "IRModule.h"
13 #include "mlir-c/Pass.h"
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 namespace {
21 
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25  PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26  PyPassManager(PyPassManager &&other) noexcept
27  : passManager(other.passManager) {
28  other.passManager.ptr = nullptr;
29  }
30  ~PyPassManager() {
31  if (!mlirPassManagerIsNull(passManager))
32  mlirPassManagerDestroy(passManager);
33  }
34  MlirPassManager get() { return passManager; }
35 
36  void release() { passManager.ptr = nullptr; }
37  pybind11::object getCapsule() {
38  return py::reinterpret_steal<py::object>(
40  }
41 
42  static pybind11::object createFromCapsule(pybind11::object capsule) {
43  MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44  if (mlirPassManagerIsNull(rawPm))
45  throw py::error_already_set();
46  return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
47  }
48 
49 private:
50  MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
57  //----------------------------------------------------------------------------
58  // Mapping of the top-level PassManager
59  //----------------------------------------------------------------------------
60  py::class_<PyPassManager>(m, "PassManager", py::module_local())
61  .def(py::init<>([](const std::string &anchorOp,
62  DefaultingPyMlirContext context) {
63  MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64  context->get(),
65  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66  return new PyPassManager(passManager);
67  }),
68  "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69  "Create a new PassManager for the current (or provided) Context.")
70  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71  &PyPassManager::getCapsule)
72  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73  .def("_testing_release", &PyPassManager::release,
74  "Releases (leaks) the backing pass manager (testing)")
75  .def(
76  "enable_ir_printing",
77  [](PyPassManager &passManager) {
78  mlirPassManagerEnableIRPrinting(passManager.get());
79  },
80  "Enable mlir-print-ir-after-all.")
81  .def(
82  "enable_verifier",
83  [](PyPassManager &passManager, bool enable) {
84  mlirPassManagerEnableVerifier(passManager.get(), enable);
85  },
86  "enable"_a, "Enable / disable verify-each.")
87  .def_static(
88  "parse",
89  [](const std::string &pipeline, DefaultingPyMlirContext context) {
90  MlirPassManager passManager = mlirPassManagerCreate(context->get());
91  PyPrintAccumulator errorMsg;
94  mlirStringRefCreate(pipeline.data(), pipeline.size()),
95  errorMsg.getCallback(), errorMsg.getUserData());
96  if (mlirLogicalResultIsFailure(status))
97  throw py::value_error(std::string(errorMsg.join()));
98  return new PyPassManager(passManager);
99  },
100  "pipeline"_a, "context"_a = py::none(),
101  "Parse a textual pass-pipeline and return a top-level PassManager "
102  "that can be applied on a Module. Throw a ValueError if the pipeline "
103  "can't be parsed")
104  .def(
105  "add",
106  [](PyPassManager &passManager, const std::string &pipeline) {
107  PyPrintAccumulator errorMsg;
109  mlirPassManagerGetAsOpPassManager(passManager.get()),
110  mlirStringRefCreate(pipeline.data(), pipeline.size()),
111  errorMsg.getCallback(), errorMsg.getUserData());
112  if (mlirLogicalResultIsFailure(status))
113  throw py::value_error(std::string(errorMsg.join()));
114  },
115  "pipeline"_a,
116  "Add textual pipeline elements to the pass manager. Throws a "
117  "ValueError if the pipeline can't be parsed.")
118  .def(
119  "run",
120  [](PyPassManager &passManager, PyOperationBase &op,
121  bool invalidateOps) {
122  if (invalidateOps) {
123  op.getOperation().getContext()->clearOperationsInside(op);
124  }
125  // Actually run the pass manager.
126  PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
128  passManager.get(), op.getOperation().get());
129  if (mlirLogicalResultIsFailure(status))
130  throw MLIRError("Failure while executing pass pipeline",
131  errors.take());
132  },
133  "operation"_a, "invalidate_ops"_a = true,
134  "Run the pass manager on the provided operation, raising an "
135  "MLIRError on failure.")
136  .def(
137  "__str__",
138  [](PyPassManager &self) {
139  MlirPassManager passManager = self.get();
140  PyPrintAccumulator printAccum;
143  printAccum.getCallback(), printAccum.getUserData());
144  return printAccum.join();
145  },
146  "Print the textual representation for this PassManager, suitable to "
147  "be passed to `parse` for round-tripping.");
148 }
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:51
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:89
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
Definition: Pass.cpp:83
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:74
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager)
Enable mlir-print-ir-after-all.
Definition: Pass.cpp:47
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:96
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:109
static PyObject * mlirPythonPassManagerToCapsule(MlirPassManager pm)
Creates a capsule object encapsulating the raw C-API MlirPassManager.
Definition: Interop.h:290
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
Definition: Interop.h:299
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Used in function arguments when None should resolve to the context instance set by the current context manager.
Definition: IRModule.h:284
ReferrentTy * get() const
Definition: PybindUtils.h:47
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:563
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
Definition: Pass.h:65
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populatePassManagerSubmodule(pybind11::module &m)
Populates the `mlir.passmanager` Python submodule with the `PassManager` bindings.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Accumulates into a python string from a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1275
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:419