MLIR  20.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Pass.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"

#include <utility>
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 namespace {
21 
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25  PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26  PyPassManager(PyPassManager &&other) noexcept
27  : passManager(other.passManager) {
28  other.passManager.ptr = nullptr;
29  }
30  ~PyPassManager() {
31  if (!mlirPassManagerIsNull(passManager))
32  mlirPassManagerDestroy(passManager);
33  }
34  MlirPassManager get() { return passManager; }
35 
36  void release() { passManager.ptr = nullptr; }
37  pybind11::object getCapsule() {
38  return py::reinterpret_steal<py::object>(
40  }
41 
42  static pybind11::object createFromCapsule(pybind11::object capsule) {
43  MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44  if (mlirPassManagerIsNull(rawPm))
45  throw py::error_already_set();
46  return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
47  }
48 
49 private:
50  MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
57  //----------------------------------------------------------------------------
58  // Mapping of the top-level PassManager
59  //----------------------------------------------------------------------------
60  py::class_<PyPassManager>(m, "PassManager", py::module_local())
61  .def(py::init<>([](const std::string &anchorOp,
62  DefaultingPyMlirContext context) {
63  MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64  context->get(),
65  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66  return new PyPassManager(passManager);
67  }),
68  "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69  "Create a new PassManager for the current (or provided) Context.")
70  .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71  &PyPassManager::getCapsule)
72  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73  .def("_testing_release", &PyPassManager::release,
74  "Releases (leaks) the backing pass manager (testing)")
75  .def(
76  "enable_ir_printing",
77  [](PyPassManager &passManager, bool printBeforeAll,
78  bool printAfterAll, bool printModuleScope, bool printAfterChange,
79  bool printAfterFailure) {
81  passManager.get(), printBeforeAll, printAfterAll,
82  printModuleScope, printAfterChange, printAfterFailure);
83  },
84  "print_before_all"_a = false, "print_after_all"_a = true,
85  "print_module_scope"_a = false, "print_after_change"_a = false,
86  "print_after_failure"_a = false,
87  "Enable IR printing, default as mlir-print-ir-after-all.")
88  .def(
89  "enable_verifier",
90  [](PyPassManager &passManager, bool enable) {
91  mlirPassManagerEnableVerifier(passManager.get(), enable);
92  },
93  "enable"_a, "Enable / disable verify-each.")
94  .def_static(
95  "parse",
96  [](const std::string &pipeline, DefaultingPyMlirContext context) {
97  MlirPassManager passManager = mlirPassManagerCreate(context->get());
98  PyPrintAccumulator errorMsg;
101  mlirStringRefCreate(pipeline.data(), pipeline.size()),
102  errorMsg.getCallback(), errorMsg.getUserData());
103  if (mlirLogicalResultIsFailure(status))
104  throw py::value_error(std::string(errorMsg.join()));
105  return new PyPassManager(passManager);
106  },
107  "pipeline"_a, "context"_a = py::none(),
108  "Parse a textual pass-pipeline and return a top-level PassManager "
109  "that can be applied on a Module. Throw a ValueError if the pipeline "
110  "can't be parsed")
111  .def(
112  "add",
113  [](PyPassManager &passManager, const std::string &pipeline) {
114  PyPrintAccumulator errorMsg;
116  mlirPassManagerGetAsOpPassManager(passManager.get()),
117  mlirStringRefCreate(pipeline.data(), pipeline.size()),
118  errorMsg.getCallback(), errorMsg.getUserData());
119  if (mlirLogicalResultIsFailure(status))
120  throw py::value_error(std::string(errorMsg.join()));
121  },
122  "pipeline"_a,
123  "Add textual pipeline elements to the pass manager. Throws a "
124  "ValueError if the pipeline can't be parsed.")
125  .def(
126  "run",
127  [](PyPassManager &passManager, PyOperationBase &op,
128  bool invalidateOps) {
129  if (invalidateOps) {
130  op.getOperation().getContext()->clearOperationsInside(op);
131  }
132  // Actually run the pass manager.
133  PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
135  passManager.get(), op.getOperation().get());
136  if (mlirLogicalResultIsFailure(status))
137  throw MLIRError("Failure while executing pass pipeline",
138  errors.take());
139  },
140  "operation"_a, "invalidate_ops"_a = true,
141  "Run the pass manager on the provided operation, raising an "
142  "MLIRError on failure.")
143  .def(
144  "__str__",
145  [](PyPassManager &self) {
146  MlirPassManager passManager = self.get();
147  PyPrintAccumulator printAccum;
150  printAccum.getCallback(), printAccum.getUserData());
151  return printAccum.join();
152  },
153  "Print the textual representation for this PassManager, suitable to "
154  "be passed to `parse` for round-tripping.");
155 }
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:64
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:102
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation to `callback`, forwarding `userData` with each chunk.
Definition: Pass.cpp:96
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:87
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure)
Enable IR printing.
Definition: Pass.cpp:47
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding Python object from a capsule wrapping the C-API pointer.
Definition: Interop.h:110
static PyObject * mlirPythonPassManagerToCapsule(MlirPassManager pm)
Creates a capsule object encapsulating the raw C-API MlirPassManager.
Definition: Interop.h:311
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
Definition: Interop.h:320
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
DefaultingPyMlirContext
Used in function arguments when None should resolve to the context instance set by the current context manager.
Definition: IRModule.h:292
ReferrentTy * get() const
Definition: PybindUtils.h:47
PyOperationBase
Base class for PyOperation and PyOpView which exposes the primary, user-visible methods for manipulating an operation.
Definition: IRModule.h:571
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
Definition: Pass.h:65
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populatePassManagerSubmodule(pybind11::module &m)
Populate the `mlir.passmanager` Python submodule with the `PassManager` class bindings.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
PyPrintAccumulator
Accumulates, into a Python string, the output of a method that accepts an MlirStringCallback.
Definition: PybindUtils.h:102
pybind11::str join()
Definition: PybindUtils.h:117
MlirStringCallback getCallback()
Definition: PybindUtils.h:107
MLIRError
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1284
PyMlirContext::ErrorCapture
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:427