MLIR  20.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Pass.h"
10 
11 #include "IRModule.h"
12 #include "mlir-c/Pass.h"
14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15 
16 namespace nb = nanobind;
17 using namespace nb::literals;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 namespace {
22 
23 /// Owning Wrapper around a PassManager.
24 class PyPassManager {
25 public:
26  PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
27  PyPassManager(PyPassManager &&other) noexcept
28  : passManager(other.passManager) {
29  other.passManager.ptr = nullptr;
30  }
31  ~PyPassManager() {
32  if (!mlirPassManagerIsNull(passManager))
33  mlirPassManagerDestroy(passManager);
34  }
35  MlirPassManager get() { return passManager; }
36 
37  void release() { passManager.ptr = nullptr; }
38  nb::object getCapsule() {
39  return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
40  }
41 
42  static nb::object createFromCapsule(nb::object capsule) {
43  MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44  if (mlirPassManagerIsNull(rawPm))
45  throw nb::python_error();
46  return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
47  }
48 
49 private:
50  MlirPassManager passManager;
51 };
52 
53 } // namespace
54 
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
57  //----------------------------------------------------------------------------
58  // Mapping of the top-level PassManager
59  //----------------------------------------------------------------------------
60  nb::class_<PyPassManager>(m, "PassManager")
61  .def(
62  "__init__",
63  [](PyPassManager &self, const std::string &anchorOp,
64  DefaultingPyMlirContext context) {
65  MlirPassManager passManager = mlirPassManagerCreateOnOperation(
66  context->get(),
67  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
68  new (&self) PyPassManager(passManager);
69  },
70  "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
71  "Create a new PassManager for the current (or provided) Context.")
72  .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
73  .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
74  .def("_testing_release", &PyPassManager::release,
75  "Releases (leaks) the backing pass manager (testing)")
76  .def(
77  "enable_ir_printing",
78  [](PyPassManager &passManager, bool printBeforeAll,
79  bool printAfterAll, bool printModuleScope, bool printAfterChange,
80  bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
81  bool enableDebugInfo, bool printGenericOpForm,
82  std::optional<std::string> optionalTreePrintingPath) {
83  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
84  if (largeElementsLimit)
86  *largeElementsLimit);
87  if (enableDebugInfo)
88  mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
89  /*prettyForm=*/false);
90  if (printGenericOpForm)
92  std::string treePrintingPath = "";
93  if (optionalTreePrintingPath.has_value())
94  treePrintingPath = optionalTreePrintingPath.value();
96  passManager.get(), printBeforeAll, printAfterAll,
97  printModuleScope, printAfterChange, printAfterFailure, flags,
98  mlirStringRefCreate(treePrintingPath.data(),
99  treePrintingPath.size()));
101  },
102  "print_before_all"_a = false, "print_after_all"_a = true,
103  "print_module_scope"_a = false, "print_after_change"_a = false,
104  "print_after_failure"_a = false,
105  "large_elements_limit"_a.none() = nb::none(),
106  "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
107  "tree_printing_dir_path"_a.none() = nb::none(),
108  "Enable IR printing, default as mlir-print-ir-after-all.")
109  .def(
110  "enable_verifier",
111  [](PyPassManager &passManager, bool enable) {
112  mlirPassManagerEnableVerifier(passManager.get(), enable);
113  },
114  "enable"_a, "Enable / disable verify-each.")
115  .def_static(
116  "parse",
117  [](const std::string &pipeline, DefaultingPyMlirContext context) {
118  MlirPassManager passManager = mlirPassManagerCreate(context->get());
119  PyPrintAccumulator errorMsg;
122  mlirStringRefCreate(pipeline.data(), pipeline.size()),
123  errorMsg.getCallback(), errorMsg.getUserData());
124  if (mlirLogicalResultIsFailure(status))
125  throw nb::value_error(errorMsg.join().c_str());
126  return new PyPassManager(passManager);
127  },
128  "pipeline"_a, "context"_a.none() = nb::none(),
129  "Parse a textual pass-pipeline and return a top-level PassManager "
130  "that can be applied on a Module. Throw a ValueError if the pipeline "
131  "can't be parsed")
132  .def(
133  "add",
134  [](PyPassManager &passManager, const std::string &pipeline) {
135  PyPrintAccumulator errorMsg;
137  mlirPassManagerGetAsOpPassManager(passManager.get()),
138  mlirStringRefCreate(pipeline.data(), pipeline.size()),
139  errorMsg.getCallback(), errorMsg.getUserData());
140  if (mlirLogicalResultIsFailure(status))
141  throw nb::value_error(errorMsg.join().c_str());
142  },
143  "pipeline"_a,
144  "Add textual pipeline elements to the pass manager. Throws a "
145  "ValueError if the pipeline can't be parsed.")
146  .def(
147  "run",
148  [](PyPassManager &passManager, PyOperationBase &op,
149  bool invalidateOps) {
150  if (invalidateOps) {
152  }
153  // Actually run the pass manager.
156  passManager.get(), op.getOperation().get());
157  if (mlirLogicalResultIsFailure(status))
158  throw MLIRError("Failure while executing pass pipeline",
159  errors.take());
160  },
161  "operation"_a, "invalidate_ops"_a = true,
162  "Run the pass manager on the provided operation, raising an "
163  "MLIRError on failure.")
164  .def(
165  "__str__",
166  [](PyPassManager &self) {
167  MlirPassManager passManager = self.get();
168  PyPrintAccumulator printAccum;
171  printAccum.getCallback(), printAccum.getUserData());
172  return printAccum.join();
173  },
174  "Print the textual representation for this PassManager, suitable to "
175  "be passed to `parse` for round-tripping.");
176 }
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:74
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:112
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
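Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding `userData` to `callback`. Note that the callback may be called several times with consecutive chunks of the string.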
Definition: Pass.cpp:106
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:97
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
Definition: Pass.cpp:47
#define MLIR_PYTHON_CAPI_PTR_ATTR
Attribute on MLIR Python objects that expose their C-API pointer.
Definition: Interop.h:97
#define MLIR_PYTHON_CAPI_FACTORY_ATTR
Attribute on MLIR Python objects that exposes a factory function for constructing the corresponding Python object from a type-specific capsule wrapping the C-API pointer.
Definition: Interop.h:110
static PyObject * mlirPythonPassManagerToCapsule(MlirPassManager pm)
Creates a capsule object encapsulating the raw C-API MlirPassManager.
Definition: Interop.h:311
static MlirPassManager mlirPythonCapsuleToPassManager(PyObject *capsule)
Extracts an MlirPassManager from a capsule as produced from mlirPythonPassManagerToCapsule.
Definition: Interop.h:320
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:310
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:291
ReferrentTy * get() const
Definition: NanobindUtils.h:55
void clearOperationsInside(PyOperationBase &op)
Clears all operations nested inside the given op using clearOperation(MlirOperation).
Definition: IRCore.cpp:682
Base class for PyOperation and PyOpView which exposes the primary, user visible methods for manipulating it.
Definition: IRModule.h:569
virtual PyOperation & getOperation()=0
Each must provide access to the raw Operation.
MlirOperation get() const
Definition: IRModule.h:646
static bool mlirPassManagerIsNull(MlirPassManager passManager)
Checks if a PassManager is null.
Definition: Pass.h:65
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)
Always print operations in the generic form.
Definition: IR.cpp:216
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit)
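Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningless form instead of the element data. `largeElementLimit` is the number of elements above which an ElementsAttr is considered "large".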
Definition: IR.cpp:201
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)
Destroys printing flags created with mlirOpPrintingFlagsCreate.
Definition: IR.cpp:197
MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm)
Enable or disable printing of debug information (based on enable).
Definition: IR.cpp:211
MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void)
Creates new printing flags with defaults, intended for customization.
Definition: IR.cpp:193
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populatePassManagerSubmodule(nanobind::module_ &m)
Populate the `mlir.passmanager` Python submodule (registers the `PassManager` class and its methods).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
Accumulates into a Python string from a method that accepts an MlirStringCallback.
MlirStringCallback getCallback()
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1294
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:426