MLIR  20.0.0git
TransformInterpreter.cpp
Go to the documentation of this file.
1 //===- TransformInterpreter.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pybind classes for the transform dialect interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Support.h"
19 
20 namespace nb = nanobind;
21 
22 namespace {
23 struct PyMlirTransformOptions {
24  PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
25  PyMlirTransformOptions(PyMlirTransformOptions &&other) {
26  options = other.options;
27  other.options.ptr = nullptr;
28  }
29  PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
30 
31  ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
32 
33  MlirTransformOptions options;
34 };
35 } // namespace
36 
// Registers the transform-interpreter bindings on module `m`:
//  - class `TransformOptions` (wrapping PyMlirTransformOptions) with a default
//    constructor and read/write properties `expensive_checks` and
//    `enforce_single_top_level_transform_op`;
//  - function `apply_named_sequence`;
//  - function `copy_symbols_and_merge_into`.
// NOTE(review): original lines 43, 46, 51, 55, 63, 72 and 89 are missing from
// this listing. Going by the symbol index, they presumably hold the
// mlirTransformOptions{Get,Enable,Enforce}* calls of the property accessors,
// the diagnostics-capturing `scope` declarations and the
// mlirTransformApplyNamedSequence call that defines `result`. Confirm against
// the real file.
37 static void populateTransformInterpreterSubmodule(nb::module_ &m) {
38  nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
39  .def(nb::init<>())
40  .def_prop_rw(
41  "expensive_checks",
42  [](const PyMlirTransformOptions &self) {
44  },
45  [](PyMlirTransformOptions &self, bool value) {
47  })
48  .def_prop_rw(
49  "enforce_single_top_level_transform_op",
50  [](const PyMlirTransformOptions &self) {
52  self.options);
53  },
54  [](PyMlirTransformOptions &self, bool value) {
56  value);
57  });
58
// apply_named_sequence(payload_root, transform_root, transform_module,
//                      transform_options=TransformOptions())
// First invalidates live Python operation objects under `payload_root`, then
// applies the transform. It returns None on success and raises ValueError
// (nb::value_error) containing the captured diagnostic text on failure.
59  m.def(
60  "apply_named_sequence",
61  [](MlirOperation payloadRoot, MlirOperation transformRoot,
62  MlirOperation transformModule, const PyMlirTransformOptions &options) {
64  mlirOperationGetContext(transformRoot));
65
66  // Calling back into Python to invalidate everything under the payload
67  // root. This is awkward, but we don't have access to PyMlirContext
68  // object here otherwise.
69  nb::object obj = nb::cast(payloadRoot);
70  obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
71
73  payloadRoot, transformRoot, transformModule, options.options);
74  if (mlirLogicalResultIsSuccess(result))
75  return;
76
77  throw nb::value_error(
78  ("Failed to apply named transform sequence.\nDiagnostic message " +
79  scope.takeMessage())
80  .c_str());
81  },
82  nb::arg("payload_root"), nb::arg("transform_root"),
83  nb::arg("transform_module"),
// NOTE(review): nanobind converts this default value to a Python object once,
// when the function is defined. Calls that omit `transform_options` therefore
// presumably share a single TransformOptions instance. That instance is only
// read here (const reference); confirm it is never mutated through it.
84  nb::arg("transform_options") = PyMlirTransformOptions());
85
// copy_symbols_and_merge_into(target, other)
// Calls mlirMergeSymbolsIntoFromClone(target, other). On failure it raises
// ValueError (nb::value_error) containing the captured diagnostic text.
86  m.def(
87  "copy_symbols_and_merge_into",
88  [](MlirOperation target, MlirOperation other) {
90  mlirOperationGetContext(target));
91
92  MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
93  if (mlirLogicalResultIsFailure(result)) {
94  throw nb::value_error(
95  ("Failed to merge symbols.\nDiagnostic message " +
96  scope.takeMessage())
97  .c_str());
98  }
99  },
100  nb::arg("target"), nb::arg("other"));
101 }
102 
// Entry point of the `_mlirTransformInterpreter` extension module; sets the
// module docstring.
// NOTE(review): original line 105 is missing from this listing. It presumably
// calls populateTransformInterpreterSubmodule(m) to register the bindings;
// confirm against the real file.
103 NB_MODULE(_mlirTransformInterpreter, m) {
104  m.doc() = "MLIR Transform dialect interpreter functionality.";
106 }
static void populateTransformInterpreterSubmodule(nb::module_ &m)
NB_MODULE(_mlirTransformInterpreter, m)
void mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions, bool enable)
Enables or disables expensive checks in transform options.
MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other)
Merge the symbols from other into target, potentially renaming them to avoid conflicts.
MlirLogicalResult mlirTransformApplyNamedSequence(MlirOperation payload, MlirOperation transformRoot, MlirOperation transformModule, MlirTransformOptions transformOptions)
Applies the transformation script starting at the given transform root operation to the given payload operation.
void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions)
Destroys a transform options object previously created by mlirTransformOptionsCreate.
bool mlirTransformOptionsGetExpensiveChecksEnabled(MlirTransformOptions transformOptions)
Returns true if expensive checks are enabled in transform options.
void mlirTransformOptionsEnforceSingleTopLevelTransformOp(MlirTransformOptions transformOptions, bool enable)
Enables or disables the enforcement of the top-level transform op being single in transform options.
MlirTransformOptions mlirTransformOptionsCreate()
Creates a default-initialized transform options object.
bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(MlirTransformOptions transformOptions)
Returns true if the enforcement of the top-level transform op being single is enabled in transform options.
static llvm::ManagedStatic< PassManagerOptions > options
RAII scope intercepting all diagnostics into a string.
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op)
Gets the context this operation is associated with.
Definition: IR.cpp:515
static bool mlirLogicalResultIsSuccess(MlirLogicalResult res)
Checks if the given logical result represents a success.
Definition: Support.h:122
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
A logical result value, essentially a boolean with named states.
Definition: Support.h:116