MLIR  22.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Pass.h"

#include "Globals.h"
#include "IRModule.h"
#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on

namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
using namespace mlir::python;
23 
24 namespace {
25 
26 /// Owning Wrapper around a PassManager.
27 class PyPassManager {
28 public:
29  PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
30  PyPassManager(PyPassManager &&other) noexcept
31  : passManager(other.passManager) {
32  other.passManager.ptr = nullptr;
33  }
34  ~PyPassManager() {
35  if (!mlirPassManagerIsNull(passManager))
36  mlirPassManagerDestroy(passManager);
37  }
38  MlirPassManager get() { return passManager; }
39 
40  void release() { passManager.ptr = nullptr; }
41  nb::object getCapsule() {
42  return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
43  }
44 
45  static nb::object createFromCapsule(const nb::object &capsule) {
46  MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
47  if (mlirPassManagerIsNull(rawPm))
48  throw nb::python_error();
49  return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
50  }
51 
52 private:
53  MlirPassManager passManager;
54 };
55 
56 } // namespace
57 
58 /// Create the `mlir.passmanager` here.
59 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
60  //----------------------------------------------------------------------------
61  // Mapping of enumerated types
62  //----------------------------------------------------------------------------
63  nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
64  .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
65  .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
66 
67  //----------------------------------------------------------------------------
68  // Mapping of MlirExternalPass
69  //----------------------------------------------------------------------------
70  nb::class_<MlirExternalPass>(m, "ExternalPass")
71  .def("signal_pass_failure",
72  [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
73 
74  //----------------------------------------------------------------------------
75  // Mapping of the top-level PassManager
76  //----------------------------------------------------------------------------
77  nb::class_<PyPassManager>(m, "PassManager")
78  .def(
79  "__init__",
80  [](PyPassManager &self, const std::string &anchorOp,
81  DefaultingPyMlirContext context) {
82  MlirPassManager passManager = mlirPassManagerCreateOnOperation(
83  context->get(),
84  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
85  new (&self) PyPassManager(passManager);
86  },
87  "anchor_op"_a = nb::str("any"), "context"_a = nb::none(),
88  // clang-format off
89  nb::sig("def __init__(self, anchor_op: str = 'any', context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> None"),
90  // clang-format on
91  "Create a new PassManager for the current (or provided) Context.")
92  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
93  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
94  .def("_testing_release", &PyPassManager::release,
95  "Releases (leaks) the backing pass manager (testing)")
96  .def(
97  "enable_ir_printing",
98  [](PyPassManager &passManager, bool printBeforeAll,
99  bool printAfterAll, bool printModuleScope, bool printAfterChange,
100  bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
101  std::optional<int64_t> largeResourceLimit, bool enableDebugInfo,
102  bool printGenericOpForm,
103  std::optional<std::string> optionalTreePrintingPath) {
104  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
105  if (largeElementsLimit) {
107  *largeElementsLimit);
109  *largeElementsLimit);
110  }
111  if (largeResourceLimit)
113  *largeResourceLimit);
114  if (enableDebugInfo)
115  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
116  /*prettyForm=*/false);
117  if (printGenericOpForm)
119  std::string treePrintingPath = "";
120  if (optionalTreePrintingPath.has_value())
121  treePrintingPath = optionalTreePrintingPath.value();
123  passManager.get(), printBeforeAll, printAfterAll,
124  printModuleScope, printAfterChange, printAfterFailure, flags,
125  mlirStringRefCreate(treePrintingPath.data(),
126  treePrintingPath.size()));
128  },
129  "print_before_all"_a = false, "print_after_all"_a = true,
130  "print_module_scope"_a = false, "print_after_change"_a = false,
131  "print_after_failure"_a = false,
132  "large_elements_limit"_a = nb::none(),
133  "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
134  "print_generic_op_form"_a = false,
135  "tree_printing_dir_path"_a = nb::none(),
136  "Enable IR printing, default as mlir-print-ir-after-all.")
137  .def(
138  "enable_verifier",
139  [](PyPassManager &passManager, bool enable) {
140  mlirPassManagerEnableVerifier(passManager.get(), enable);
141  },
142  "enable"_a, "Enable / disable verify-each.")
143  .def(
144  "enable_timing",
145  [](PyPassManager &passManager) {
146  mlirPassManagerEnableTiming(passManager.get());
147  },
148  "Enable pass timing.")
149  .def(
150  "enable_statistics",
151  [](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
152  mlirPassManagerEnableStatistics(passManager.get(), displayMode);
153  },
154  "displayMode"_a =
156  "Enable pass statistics.")
157  .def_static(
158  "parse",
159  [](const std::string &pipeline, DefaultingPyMlirContext context) {
160  MlirPassManager passManager = mlirPassManagerCreate(context->get());
161  PyPrintAccumulator errorMsg;
164  mlirStringRefCreate(pipeline.data(), pipeline.size()),
165  errorMsg.getCallback(), errorMsg.getUserData());
166  if (mlirLogicalResultIsFailure(status))
167  throw nb::value_error(errorMsg.join().c_str());
168  return new PyPassManager(passManager);
169  },
170  "pipeline"_a, "context"_a = nb::none(),
171  // clang-format off
172  nb::sig("def parse(pipeline: str, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> PassManager"),
173  // clang-format on
174  "Parse a textual pass-pipeline and return a top-level PassManager "
175  "that can be applied on a Module. Throw a ValueError if the pipeline "
176  "can't be parsed")
177  .def(
178  "add",
179  [](PyPassManager &passManager, const std::string &pipeline) {
180  PyPrintAccumulator errorMsg;
182  mlirPassManagerGetAsOpPassManager(passManager.get()),
183  mlirStringRefCreate(pipeline.data(), pipeline.size()),
184  errorMsg.getCallback(), errorMsg.getUserData());
185  if (mlirLogicalResultIsFailure(status))
186  throw nb::value_error(errorMsg.join().c_str());
187  },
188  "pipeline"_a,
189  "Add textual pipeline elements to the pass manager. Throws a "
190  "ValueError if the pipeline can't be parsed.")
191  .def(
192  "add",
193  [](PyPassManager &passManager, const nb::callable &run,
194  std::optional<std::string> &name, const std::string &argument,
195  const std::string &description, const std::string &opName) {
196  if (!name.has_value()) {
197  name = nb::cast<std::string>(
198  nb::borrow<nb::str>(run.attr("__name__")));
199  }
200  MlirTypeID passID = PyGlobals::get().allocateTypeID();
201  MlirExternalPassCallbacks callbacks;
202  callbacks.construct = [](void *obj) {
203  (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
204  };
205  callbacks.destruct = [](void *obj) {
206  (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
207  };
208  callbacks.initialize = nullptr;
209  callbacks.clone = [](void *) -> void * {
210  throw std::runtime_error("Cloning Python passes not supported");
211  };
212  callbacks.run = [](MlirOperation op, MlirExternalPass pass,
213  void *userData) {
214  nb::handle(static_cast<PyObject *>(userData))(op, pass);
215  };
216  auto externalPass = mlirCreateExternalPass(
217  passID, mlirStringRefCreate(name->data(), name->length()),
218  mlirStringRefCreate(argument.data(), argument.length()),
219  mlirStringRefCreate(description.data(), description.length()),
220  mlirStringRefCreate(opName.data(), opName.size()),
221  /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
222  callbacks, /*userData*/ run.ptr());
223  mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
224  },
225  "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
226  "description"_a.none() = "", "op_name"_a.none() = "",
227  "Add a python-defined pass to the pass manager.")
228  .def(
229  "run",
230  [](PyPassManager &passManager, PyOperationBase &op) {
231  // Actually run the pass manager.
234  passManager.get(), op.getOperation().get());
235  if (mlirLogicalResultIsFailure(status))
236  throw MLIRError("Failure while executing pass pipeline",
237  errors.take());
238  },
239  "operation"_a,
240  // clang-format off
241  nb::sig("def run(self, operation: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ") -> None"),
242  // clang-format on
243  "Run the pass manager on the provided operation, raising an "
244  "MLIRError on failure.")
245  .def(
246  "__str__",
247  [](PyPassManager &self) {
248  MlirPassManager passManager = self.get();
249  PyPrintAccumulator printAccum;
252  printAccum.getCallback(), printAccum.getUserData());
253  return printAccum.join();
254  },
255  "Print the textual representation for this PassManager, suitable to "
256  "be passed to `parse` for round-tripping.");
257 }
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData.
Definition: Pass.cpp:219
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:25
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:75
void mlirPassManagerEnableStatistics(MlirPassManager passManager, MlirPassDisplayMode displayMode)
Enable pass statistics.
Definition: Pass.cpp:83
void mlirPassManagerEnableTiming(MlirPassManager passManager)
Enable pass timing.
Definition: Pass.cpp:79
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:34
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:131
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:39
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:43
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:234
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
Definition: Pass.cpp:125
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:107
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:29
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:116
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
Definition: Pass.cpp:48
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding P...
Definition: Interop.h:110
static PyObject * mlirPythonPassManagerToCapsule(MlirPassManager pm)
Creates a capsule object encapsulating the raw C-API MlirPassManager.
Definition: Interop.h:311
#define MAKE_MLIR_PYTHON_QUALNAME(local)
Definition: Interop.h:57
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
Definition: Interop.h:320
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:292
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:273
ReferrentTy * get() const
Definition: NanobindUtils.h:60
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulat...
Definition: IRModule.h:552
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
MlirOperation get() const
Definition: IRModule.h:638
MlirPassDisplayMode
Enumerated type of pass display modes.
Definition: Pass.h:97
@ MLIR_PASS_DISPLAY_MODE_LIST
Definition: Pass.h:98
@ MLIR_PASS_DISPLAY_MODE_PIPELINE
Definition: Pass.h:99
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
Definition: Pass.h:65
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, intptr_t largeResourceLimit)
Enables the elision of large resources strings by omitting them from the dialect_resources section.
Definition: IR.cpp:215
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:225
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: IR.cpp:210
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:206
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:220
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:202
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populatePassManagerSubmodule(nanobind::module_ &m)
Populates the `mlir.passmanager` Python submodule with the PassDisplayMode enum, ExternalPass, and PassManager bindings.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Structure of external MlirPass callbacks.
Definition: Pass.h:165
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:186
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:182
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:178
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor.
Definition: Pass.h:172
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:168
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Accumulates into a python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1318
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:408