18#include "nanobind/nanobind.h"
19#include <nanobind/trampoline.h>
34 static constexpr const char *
pyClassName =
"TransformRewriter";
47 MlirTransformResults
get()
const {
return results; }
50 const nb::typed<nb::sequence, PyOperationBase> &ops) {
51 std::vector<MlirOperation> opsVec;
52 opsVec.reserve(nb::len(ops));
54 opsVec.push_back(nb::cast<MlirOperation>(op));
60 const nb::typed<nb::sequence, PyValue> &values) {
61 std::vector<MlirValue> valuesVec;
62 valuesVec.reserve(nb::len(values));
63 for (
auto item : values) {
64 valuesVec.push_back(nb::cast<MlirValue>(item));
71 const nb::typed<nb::sequence, PyAttribute> ¶ms) {
72 std::vector<MlirAttribute> paramsVec;
73 paramsVec.reserve(nb::len(params));
74 for (
auto item : params) {
75 paramsVec.push_back(nb::cast<MlirAttribute>(item));
81 static void bind(nanobind::module_ &m) {
82 nb::class_<PyTransformResults>(m,
"TransformResults")
83 .def(nb::init<MlirTransformResults>())
85 "Set the payload operations for a transform result.",
86 nb::arg(
"result"), nb::arg(
"ops"))
88 "Set the payload values for a transform result.",
89 nb::arg(
"result"), nb::arg(
"values"))
91 "Set the parameters for a transform result.", nb::arg(
"result"),
96 MlirTransformResults results;
106 MlirTransformState
get()
const {
return state; }
108 static void bind(nanobind::module_ &m) {
109 nb::class_<PyTransformState>(m,
"TransformState")
110 .def(nb::init<MlirTransformState>())
111 .def(
"get_payload_ops", &PyTransformState::getPayloadOps,
112 "Get the payload operations associated with a transform IR value.",
114 .def(
"get_payload_values", &PyTransformState::getPayloadValues,
115 "Get the payload values associated with a transform IR value.",
117 .def(
"get_params", &PyTransformState::getParams,
118 "Get the parameters (attributes) associated with a transform IR "
124 nanobind::list getPayloadOps(
PyValue &value) {
128 [](MlirOperation op,
void *userData) {
132 static_cast<nanobind::list *
>(userData)->append(opview);
138 nanobind::list getPayloadValues(PyValue &value) {
142 [](MlirValue val,
void *userData) {
143 static_cast<nanobind::list *
>(userData)->append(val);
149 nanobind::list getParams(PyValue &value) {
153 [](MlirAttribute attr,
void *userData) {
154 static_cast<nanobind::list *
>(userData)->append(attr);
160 MlirTransformState state;
171 constexpr static const char *
pyClassName =
"TransformOpInterface";
183 nb::handle(
static_cast<PyObject *
>(callbacks.
userData)).inc_ref();
189 callbacks.
destruct = [](
void *userData) {
190 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
193 callbacks.
apply = [](MlirOperation op, MlirTransformRewriter rewriter,
194 MlirTransformResults results, MlirTransformState state,
196 nb::handle pyClass(
static_cast<PyObject *
>(userData));
198 auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass,
"apply"));
209 nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
211 return nb::cast<MlirDiagnosedSilenceableFailure>(res);
216 void *userData) ->
bool {
217 nb::handle pyClass(
static_cast<PyObject *
>(userData));
219 auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
220 nb::getattr(pyClass,
"allow_repeated_handle_operands"));
227 nb::object res = pyAllowRepeatedHandleOperands(opview);
229 return nb::cast<bool>(res);
240 [](
const nb::object &cls,
const nb::object &opName, nb::object
target,
244 return attach(
target, nb::cast<std::string>(opName), context);
246 nb::arg(
"cls"), nb::arg(
"op_name"), nb::kw_only(),
247 nb::arg(
"target").none() = nb::none(),
248 nb::arg(
"context").none() = nb::none(),
249 "Attach the interface subclass to the given operation name.");
262 constexpr static const char *
pyClassName =
"PatternDescriptorOpInterface";
275 nb::handle(
static_cast<PyObject *
>(callbacks.
userData)).inc_ref();
281 callbacks.
destruct = [](
void *userData) {
282 nb::handle(
static_cast<PyObject *
>(userData)).dec_ref();
287 [](MlirOperation op, MlirRewritePatternSet patterns,
void *userData) {
288 nb::handle pyClass(
static_cast<PyObject *
>(userData));
290 auto pyPopulatePatterns =
291 nb::cast<nb::callable>(nb::getattr(pyClass,
"populate_patterns"));
300 pyPopulatePatterns(opview, pyPatterns);
305 if (nb::hasattr(
target,
"populate_patterns_with_state")) {
307 MlirRewritePatternSet patterns,
308 MlirTransformState state,
310 nb::handle pyClass(
static_cast<PyObject *
>(userData));
312 auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
313 nb::getattr(pyClass,
"populate_patterns_with_state"));
323 pyPopulatePatternsWithState(opview, pyPatterns, pyState);
338 [](
const nb::object &cls,
const nb::object &opName, nb::object
target,
342 return attach(
target, nb::cast<std::string>(opName), context);
344 nb::arg(
"cls"), nb::arg(
"op_name"), nb::kw_only(),
345 nb::arg(
"target").none() = nb::none(),
346 nb::arg(
"context").none() = nb::none(),
347 "Attach the interface subclass to the given operation name.");
370 "Get an instance of AnyOpType in the given context.",
371 nb::arg(
"context").none() = nb::none());
392 context.
get()->get()));
394 "Get an instance of AnyParamType in the given context.",
395 nb::arg(
"context").none() = nb::none());
416 context.
get()->get()));
418 "Get an instance of AnyValueType in the given context.",
419 nb::arg(
"context").none() = nb::none());
444 context.
get()->get(), cOperationName));
446 "Get an instance of OperationType for the given kind in the given "
448 nb::arg(
"operation_name"), nb::arg(
"context").none() = nb::none());
454 return nb::str(operationName.
data, operationName.
length);
456 "Get the name of the payload operation accepted by the handle.");
477 context.
get()->get(), type));
479 "Get an instance of ParamType for the given type in the given context.",
480 nb::arg(
"type"), nb::arg(
"context").none() = nb::none());
487 "Get the type this ParamType is associated with.");
496void onlyReadsHandle(nb::iterable &operands,
498 std::vector<MlirOpOperand> operandsVec;
499 for (
auto operand : operands)
500 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
505void consumesHandle(nb::iterable &operands,
506 PyMemoryEffectsInstanceList effects) {
507 std::vector<MlirOpOperand> operandsVec;
508 for (
auto operand : operands)
509 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
514void producesHandle(nb::iterable &results,
515 PyMemoryEffectsInstanceList effects) {
516 std::vector<MlirValue> resultsVec;
517 for (
auto result : results)
518 resultsVec.push_back(nb::cast<PyOpResult>(
result).
get());
523void modifiesPayload(PyMemoryEffectsInstanceList effects) {
533 nb::enum_<MlirDiagnosedSilenceableFailure>(m,
"DiagnosedSilenceableFailure")
535 .value(
"SilenceableFailure",
551 m.def(
"only_reads_handle", onlyReadsHandle,
552 "Mark operands as only reading handles.", nb::arg(
"operands"),
555 m.def(
"consumes_handle", consumesHandle,
556 "Mark operands as consuming handles.", nb::arg(
"operands"),
559 m.def(
"produces_handle", producesHandle,
"Mark results as producing handles.",
560 nb::arg(
"results"), nb::arg(
"effects"));
562 m.def(
"modifies_payload", modifiesPayload,
563 "Mark the transform as modifying the payload.", nb::arg(
"effects"));
565 m.def(
"only_reads_payload", onlyReadsPayload,
566 "Mark the transform as only reading the payload.", nb::arg(
"effects"));
574 m.doc() =
"MLIR Transform dialect.";
MlirContext mlirOperationGetContext(MlirOperation op)
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
nanobind::class_< PyTransformOpInterface > ClassTy
nanobind::class_< AnyOpType, PyType > ClassTy
static void bind(nanobind::module_ &m)
MlirTypeID(*)() GetTypeIDFunctionTy
bool(*)(MlirType) IsAFunctionTy
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
nanobind::object createOpView()
Creates an OpView suitable for this operation.
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
static void bind(nanobind::module_ &m)
PyRewriterBase(MlirRewriterBase rewriter)
PyType(PyMlirContextRef contextRef, MlirType type)
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new PatternDescriptorOpInterface FallbackModel to the named operation.
static void bindDerived(ClassTy &cls)
static constexpr GetTypeIDFunctionTy getInterfaceID
static constexpr const char * pyClassName
static constexpr const char * pyClassName
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new TransformOpInterface FallbackModel to the named operation.
static constexpr GetTypeIDFunctionTy getInterfaceID
static void bindDerived(ClassTy &cls)
MlirTransformResults get() const
static void bind(nanobind::module_ &m)
void setOps(PyValue &result, const nb::typed< nb::sequence, PyOperationBase > &ops)
void setValues(PyValue &result, const nb::typed< nb::sequence, PyValue > &values)
PyTransformResults(MlirTransformResults results)
void setParams(PyValue &result, const nb::typed< nb::sequence, PyAttribute > ¶ms)
static constexpr const char * pyClassName
PyTransformRewriter(MlirTransformRewriter rewriter)
PyTransformState(MlirTransformState state)
MlirTransformState get() const
static void bind(nanobind::module_ &m)
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
static void populateDialectTransformSubmodule(nb::module_ &m)
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Callbacks for implementing PatternDescriptorOpInterface from external code.
void(* populatePatternsWithState)(MlirOperation op, MlirRewritePatternSet patterns, MlirTransformState state, void *userData)
Optional callback to populate rewrite patterns with transform state.
void(* populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns, void *userData)
Callback to populate rewrite patterns into the given pattern set.
void(* construct)(void *userData)
Optional constructor for the user data.
void(* destruct)(void *userData)
Optional destructor for the user data.
A pointer to a sized fragment of a string, not necessarily null-terminated.
const char * data
Pointer to the first symbol.
size_t length
Length of the fragment.
MlirMemoryEffectInstancesList effects
static const MlirStringRef name
static constexpr IsAFunctionTy isaFunction
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
static constexpr const char * pyClassName
static void bindDerived(ClassTy &c)
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static const MlirStringRef name
static constexpr IsAFunctionTy isaFunction
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static const MlirStringRef name
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr IsAFunctionTy isaFunction
static const MlirStringRef name
static void bindDerived(ClassTy &c)
static constexpr const char * pyClassName
static constexpr GetTypeIDFunctionTy getTypeIdFunction
static constexpr IsAFunctionTy isaFunction
static void bindDerived(ClassTy &c)
static const MlirStringRef name
static constexpr const char * pyClassName