17#include "nanobind/nanobind.h"
18#include <nanobind/trampoline.h>
33 static constexpr const char *
pyClassName =
"TransformRewriter";
46 MlirTransformResults
get()
const {
return results; }
49 std::vector<MlirOperation> opsVec;
50 opsVec.reserve(ops.size());
52 opsVec.push_back(nb::cast<MlirOperation>(op));
58 std::vector<MlirValue> valuesVec;
59 valuesVec.reserve(values.size());
60 for (
auto item : values) {
61 valuesVec.push_back(nb::cast<MlirValue>(item));
68 std::vector<MlirAttribute> paramsVec;
69 paramsVec.reserve(params.size());
70 for (
auto item : params) {
71 paramsVec.push_back(nb::cast<MlirAttribute>(item));
77 static void bind(nanobind::module_ &m) {
78 nb::class_<PyTransformResults>(m,
"TransformResults")
79 .def(nb::init<MlirTransformResults>())
81 "Set the payload operations for a transform result.",
82 nb::arg(
"result"), nb::arg(
"ops"))
84 "Set the payload values for a transform result.",
85 nb::arg(
"result"), nb::arg(
"values"))
87 "Set the parameters for a transform result.", nb::arg(
"result"),
92 MlirTransformResults results;
102 MlirTransformState
get()
const {
return state; }
104 static void bind(nanobind::module_ &m) {
105 nb::class_<PyTransformState>(m,
"TransformState")
106 .def(nb::init<MlirTransformState>())
107 .def(
"get_payload_ops", &PyTransformState::getPayloadOps,
108 "Get the payload operations associated with a transform IR value.",
110 .def(
"get_payload_values", &PyTransformState::getPayloadValues,
111 "Get the payload values associated with a transform IR value.",
113 .def(
"get_params", &PyTransformState::getParams,
114 "Get the parameters (attributes) associated with a transform IR "
120 nanobind::list getPayloadOps(
PyValue &value) {
124 [](MlirOperation op,
void *userData) {
128 static_cast<nanobind::list *
>(userData)->append(opview);
134 nanobind::list getPayloadValues(PyValue &value) {
138 [](MlirValue val,
void *userData) {
139 static_cast<nanobind::list *
>(userData)->append(val);
145 nanobind::list getParams(PyValue &value) {
149 [](MlirAttribute attr,
void *userData) {
150 static_cast<nanobind::list *
>(userData)->append(attr);
156 MlirTransformState state;
167 constexpr static const char *
pyClassName =
"TransformOpInterface";
179 nb::handle(
static_cast<PyObject *
>(callbacks.
userData)).inc_ref();
185 callbacks.
destruct = [](
void *userData) {
186 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
189 callbacks.
apply = [](MlirOperation op, MlirTransformRewriter rewriter,
190 MlirTransformResults results, MlirTransformState state,
192 nb::handle pyClass(
static_cast<PyObject *
>(userData));
194 auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass,
"apply"));
205 nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
207 return nb::cast<MlirDiagnosedSilenceableFailure>(res);
212 void *userData) ->
bool {
213 nb::handle pyClass(
static_cast<PyObject *
>(userData));
215 auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
216 nb::getattr(pyClass,
"allow_repeated_handle_operands"));
223 nb::object res = pyAllowRepeatedHandleOperands(opview);
225 return nb::cast<bool>(res);
230 ctx->
get(),
wrap(StringRef(opName.c_str())), callbacks);
235 [](
const nb::object &cls,
const nb::object &opName, nb::object
target,
239 return attach(
target, nb::cast<std::string>(opName), context);
241 nb::arg(
"cls"), nb::arg(
"op_name"), nb::kw_only(),
242 nb::arg(
"target").none() = nb::none(),
243 nb::arg(
"context").none() = nb::none(),
244 "Attach the interface subclass to the given operation name.");
267 "Get an instance of AnyOpType in the given context.",
268 nb::arg(
"context").none() = nb::none());
289 context.
get()->get()));
291 "Get an instance of AnyParamType in the given context.",
292 nb::arg(
"context").none() = nb::none());
313 context.
get()->get()));
315 "Get an instance of AnyValueType in the given context.",
316 nb::arg(
"context").none() = nb::none());
341 context.
get()->get(), cOperationName));
343 "Get an instance of OperationType for the given kind in the given "
345 nb::arg(
"operation_name"), nb::arg(
"context").none() = nb::none());
351 return nb::str(operationName.
data, operationName.
length);
353 "Get the name of the payload operation accepted by the handle.");
374 context.
get()->get(), type));
376 "Get an instance of ParamType for the given type in the given context.",
377 nb::arg(
"type"), nb::arg(
"context").none() = nb::none());
384 "Get the type this ParamType is associated with.");
393void onlyReadsHandle(nb::iterable &operands,
395 std::vector<MlirOpOperand> operandsVec;
396 for (
auto operand : operands)
397 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
402void consumesHandle(nb::iterable &operands,
403 PyMemoryEffectsInstanceList effects) {
404 std::vector<MlirOpOperand> operandsVec;
405 for (
auto operand : operands)
406 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
411void producesHandle(nb::iterable &results,
412 PyMemoryEffectsInstanceList effects) {
413 std::vector<MlirValue> resultsVec;
414 for (
auto result : results)
415 resultsVec.push_back(nb::cast<PyOpResult>(
result).
get());
420void modifiesPayload(PyMemoryEffectsInstanceList effects) {
430 nb::enum_<MlirDiagnosedSilenceableFailure>(m,
"DiagnosedSilenceableFailure")
432 .value(
"SilenceableFailure",
447 m.def(
"only_reads_handle", onlyReadsHandle,
448 "Mark operands as only reading handles.", nb::arg(
"operands"),
451 m.def(
"consumes_handle", consumesHandle,
452 "Mark operands as consuming handles.", nb::arg(
"operands"),
455 m.def(
"produces_handle", producesHandle,
"Mark results as producing handles.",
456 nb::arg(
"results"), nb::arg(
"effects"));
458 m.def(
"modifies_payload", modifiesPayload,
459 "Mark the transform as modifying the payload.", nb::arg(
"effects"));
461 m.def(
"only_reads_payload", onlyReadsPayload,
462 "Mark the transform as only reading the payload.", nb::arg(
"effects"));
470 m.doc() =
"MLIR Transform dialect.";
MlirContext mlirOperationGetContext(MlirOperation op)
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
nanobind::class_< PyTransformOpInterface > ClassTy
nanobind::class_< AnyOpType, PyType > ClassTy
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
bool(*)(MlirType) IsAFunctionTy
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
static void bind(nanobind::module_ &m)
PyRewriterBase(MlirRewriterBase rewriter)
PyType(PyMlirContextRef contextRef, MlirType type)
static constexpr const char * pyClassName
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new TransformOpInterface FallbackModel to the named operation.
static constexpr GetTypeIDFunctionTy getInterfaceID
static void bindDerived(ClassTy &cls)
void setOps(PyValue &result, const nb::list &ops)
MlirTransformResults get() const
static void bind(nanobind::module_ &m)
void setValues(PyValue &result, const nb::list &values)
PyTransformResults(MlirTransformResults results)
void setParams(PyValue &result, const nb::list &params)
static constexpr const char * pyClassName
PyTransformRewriter(MlirTransformRewriter rewriter)
PyTransformState(MlirTransformState state)
MlirTransformState get() const
static void bind(nanobind::module_ &m)
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static void populateDialectTransformSubmodule(nb::module_ &m)
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
MlirMemoryEffectInstancesList effects
static const MlirStringRef name
static constexpr IsAFunctionTy isaFunction
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
static constexpr const char * pyClassName
static void bindDerived(ClassTy &c)
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static const MlirStringRef name
static constexpr IsAFunctionTy isaFunction
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static const MlirStringRef name
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr IsAFunctionTy isaFunction
static const MlirStringRef name
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static const MlirStringRef name
static constexpr const char * pyClassName