18#include "nanobind/nanobind.h"
19#include <nanobind/trampoline.h>
34 static constexpr const char *
pyClassName =
"TransformRewriter";
47 MlirTransformResults
get()
const {
return results; }
50 std::vector<MlirOperation> opsVec;
51 opsVec.reserve(ops.size());
53 opsVec.push_back(nb::cast<MlirOperation>(op));
59 std::vector<MlirValue> valuesVec;
60 valuesVec.reserve(values.size());
61 for (
auto item : values) {
62 valuesVec.push_back(nb::cast<MlirValue>(item));
69 std::vector<MlirAttribute> paramsVec;
70 paramsVec.reserve(params.size());
71 for (
auto item : params) {
72 paramsVec.push_back(nb::cast<MlirAttribute>(item));
78 static void bind(nanobind::module_ &m) {
79 nb::class_<PyTransformResults>(m,
"TransformResults")
80 .def(nb::init<MlirTransformResults>())
82 "Set the payload operations for a transform result.",
83 nb::arg(
"result"), nb::arg(
"ops"))
85 "Set the payload values for a transform result.",
86 nb::arg(
"result"), nb::arg(
"values"))
88 "Set the parameters for a transform result.", nb::arg(
"result"),
93 MlirTransformResults results;
103 MlirTransformState
get()
const {
return state; }
105 static void bind(nanobind::module_ &m) {
106 nb::class_<PyTransformState>(m,
"TransformState")
107 .def(nb::init<MlirTransformState>())
108 .def(
"get_payload_ops", &PyTransformState::getPayloadOps,
109 "Get the payload operations associated with a transform IR value.",
111 .def(
"get_payload_values", &PyTransformState::getPayloadValues,
112 "Get the payload values associated with a transform IR value.",
114 .def(
"get_params", &PyTransformState::getParams,
115 "Get the parameters (attributes) associated with a transform IR "
121 nanobind::list getPayloadOps(
PyValue &value) {
125 [](MlirOperation op,
void *userData) {
129 static_cast<nanobind::list *
>(userData)->append(opview);
135 nanobind::list getPayloadValues(PyValue &value) {
139 [](MlirValue val,
void *userData) {
140 static_cast<nanobind::list *
>(userData)->append(val);
146 nanobind::list getParams(PyValue &value) {
150 [](MlirAttribute attr,
void *userData) {
151 static_cast<nanobind::list *
>(userData)->append(attr);
157 MlirTransformState state;
168 constexpr static const char *
pyClassName =
"TransformOpInterface";
180 nb::handle(
static_cast<PyObject *
>(callbacks.
userData)).inc_ref();
186 callbacks.
destruct = [](
void *userData) {
187 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
190 callbacks.
apply = [](MlirOperation op, MlirTransformRewriter rewriter,
191 MlirTransformResults results, MlirTransformState state,
193 nb::handle pyClass(
static_cast<PyObject *
>(userData));
195 auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass,
"apply"));
206 nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
208 return nb::cast<MlirDiagnosedSilenceableFailure>(res);
213 void *userData) ->
bool {
214 nb::handle pyClass(
static_cast<PyObject *
>(userData));
216 auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
217 nb::getattr(pyClass,
"allow_repeated_handle_operands"));
224 nb::object res = pyAllowRepeatedHandleOperands(opview);
226 return nb::cast<bool>(res);
237 [](
const nb::object &cls,
const nb::object &opName, nb::object
target,
241 return attach(
target, nb::cast<std::string>(opName), context);
243 nb::arg(
"cls"), nb::arg(
"op_name"), nb::kw_only(),
244 nb::arg(
"target").none() = nb::none(),
245 nb::arg(
"context").none() = nb::none(),
246 "Attach the interface subclass to the given operation name.");
259 constexpr static const char *
pyClassName =
"PatternDescriptorOpInterface";
272 nb::handle(
static_cast<PyObject *
>(callbacks.
userData)).inc_ref();
278 callbacks.
destruct = [](
void *userData) {
279 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
284 [](MlirOperation op, MlirRewritePatternSet patterns,
void *userData) {
285 nb::handle pyClass(
static_cast<PyObject *
>(userData));
287 auto pyPopulatePatterns =
288 nb::cast<nb::callable>(nb::getattr(pyClass,
"populate_patterns"));
297 pyPopulatePatterns(opview, pyPatterns);
302 if (nb::hasattr(
target,
"populate_patterns_with_state")) {
304 MlirRewritePatternSet patterns,
305 MlirTransformState state,
307 nb::handle pyClass(
static_cast<PyObject *
>(userData));
309 auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
310 nb::getattr(pyClass,
"populate_patterns_with_state"));
320 pyPopulatePatternsWithState(opview, pyPatterns, pyState);
335 [](
const nb::object &cls,
const nb::object &opName, nb::object
target,
339 return attach(
target, nb::cast<std::string>(opName), context);
341 nb::arg(
"cls"), nb::arg(
"op_name"), nb::kw_only(),
342 nb::arg(
"target").none() = nb::none(),
343 nb::arg(
"context").none() = nb::none(),
344 "Attach the interface subclass to the given operation name.");
367 "Get an instance of AnyOpType in the given context.",
368 nb::arg(
"context").none() = nb::none());
389 context.
get()->get()));
391 "Get an instance of AnyParamType in the given context.",
392 nb::arg(
"context").none() = nb::none());
413 context.
get()->get()));
415 "Get an instance of AnyValueType in the given context.",
416 nb::arg(
"context").none() = nb::none());
441 context.
get()->get(), cOperationName));
443 "Get an instance of OperationType for the given kind in the given "
445 nb::arg(
"operation_name"), nb::arg(
"context").none() = nb::none());
451 return nb::str(operationName.
data, operationName.
length);
453 "Get the name of the payload operation accepted by the handle.");
474 context.
get()->get(), type));
476 "Get an instance of ParamType for the given type in the given context.",
477 nb::arg(
"type"), nb::arg(
"context").none() = nb::none());
484 "Get the type this ParamType is associated with.");
493void onlyReadsHandle(nb::iterable &operands,
495 std::vector<MlirOpOperand> operandsVec;
496 for (
auto operand : operands)
497 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
502void consumesHandle(nb::iterable &operands,
503 PyMemoryEffectsInstanceList effects) {
504 std::vector<MlirOpOperand> operandsVec;
505 for (
auto operand : operands)
506 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
511void producesHandle(nb::iterable &results,
512 PyMemoryEffectsInstanceList effects) {
513 std::vector<MlirValue> resultsVec;
514 for (
auto result : results)
515 resultsVec.push_back(nb::cast<PyOpResult>(
result).
get());
520void modifiesPayload(PyMemoryEffectsInstanceList effects) {
530 nb::enum_<MlirDiagnosedSilenceableFailure>(m,
"DiagnosedSilenceableFailure")
532 .value(
"SilenceableFailure",
548 m.def(
"only_reads_handle", onlyReadsHandle,
549 "Mark operands as only reading handles.", nb::arg(
"operands"),
552 m.def(
"consumes_handle", consumesHandle,
553 "Mark operands as consuming handles.", nb::arg(
"operands"),
556 m.def(
"produces_handle", producesHandle,
"Mark results as producing handles.",
557 nb::arg(
"results"), nb::arg(
"effects"));
559 m.def(
"modifies_payload", modifiesPayload,
560 "Mark the transform as modifying the payload.", nb::arg(
"effects"));
562 m.def(
"only_reads_payload", onlyReadsPayload,
563 "Mark the transform as only reading the payload.", nb::arg(
"effects"));
571 m.doc() =
"MLIR Transform dialect.";
MlirContext mlirOperationGetContext(MlirOperation op)
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
nanobind::class_< PyTransformOpInterface > ClassTy
nanobind::class_< AnyOpType, PyType > ClassTy
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
bool(*)(MlirType) IsAFunctionTy
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
static void bind(nanobind::module_ &m)
PyRewriterBase(MlirRewriterBase rewriter)
PyType(PyMlirContextRef contextRef, MlirType type)
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new PatternDescriptorOpInterface FallbackModel to the named operation.
static void bindDerived(ClassTy &cls)
static constexpr GetTypeIDFunctionTy getInterfaceID
static constexpr const char * pyClassName
static constexpr const char * pyClassName
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new TransformOpInterface FallbackModel to the named operation.
static constexpr GetTypeIDFunctionTy getInterfaceID
static void bindDerived(ClassTy &cls)
void setOps(PyValue &result, const nb::list &ops)
MlirTransformResults get() const
static void bind(nanobind::module_ &m)
void setValues(PyValue &result, const nb::list &values)
PyTransformResults(MlirTransformResults results)
void setParams(PyValue &result, const nb::list ¶ms)
static constexpr const char * pyClassName
PyTransformRewriter(MlirTransformRewriter rewriter)
PyTransformState(MlirTransformState state)
MlirTransformState get() const
static void bind(nanobind::module_ &m)
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static void populateDialectTransformSubmodule(nb::module_ &m)
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Callbacks for implementing PatternDescriptorOpInterface from external code.
void(* populatePatternsWithState)(MlirOperation op, MlirRewritePatternSet patterns, MlirTransformState state, void *userData)
Optional callback to populate rewrite patterns with transform state.
void(* populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns, void *userData)
Callback to populate rewrite patterns into the given pattern set.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
MlirMemoryEffectInstancesList effects
static const MlirStringRef name
static constexpr IsAFunctionTy isaFunction
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
static constexpr const char * pyClassName
static void bindDerived(ClassTy &c)
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static const MlirStringRef name
static constexpr IsAFunctionTy isaFunction
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static const MlirStringRef name
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr IsAFunctionTy isaFunction
static const MlirStringRef name
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static const MlirStringRef name
static constexpr const char * pyClassName