MLIR 23.0.0git
DialectTransform.cpp
Go to the documentation of this file.
1//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <string>
10
11#include "IRInterfaces.h"
12#include "Rewrite.h"
14#include "mlir-c/IR.h"
15#include "mlir-c/Support.h"
17#include "nanobind/nanobind.h"
18#include <nanobind/trampoline.h>
19
20namespace nb = nanobind;
22
23namespace mlir {
24namespace python {
26namespace transform {
27
28//===----------------------------------------------------------------------===//
29// TransformRewriter
30//===----------------------------------------------------------------------===//
/// Python-side wrapper for the transform dialect's rewriter, exposed to Python
/// under the class name "TransformRewriter" (see `pyClassName`). It derives
/// from the CRTP base PyRewriterBase<PyTransformRewriter>, which presumably
/// provides the shared rewriter bindings.
/// NOTE(review): the constructor's initializer line is not visible in this
/// listing; presumably it converts the handle with mlirTransformRewriterAsBase
/// before handing it to the base class — confirm.
31class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
32public:
33 static constexpr const char *pyClassName = "TransformRewriter";
34
35 PyTransformRewriter(MlirTransformRewriter rewriter)
37};
38
39//===----------------------------------------------------------------------===//
40// TransformResults
41//===----------------------------------------------------------------------===//
43public:
44 PyTransformResults(MlirTransformResults results) : results(results) {}
45
46 MlirTransformResults get() const { return results; }
47
48 void setOps(PyValue &result, const nb::list &ops) {
49 std::vector<MlirOperation> opsVec;
50 opsVec.reserve(ops.size());
51 for (auto op : ops) {
52 opsVec.push_back(nb::cast<MlirOperation>(op));
53 }
54 mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
55 }
56
57 void setValues(PyValue &result, const nb::list &values) {
58 std::vector<MlirValue> valuesVec;
59 valuesVec.reserve(values.size());
60 for (auto item : values) {
61 valuesVec.push_back(nb::cast<MlirValue>(item));
62 }
63 mlirTransformResultsSetValues(results, result, valuesVec.size(),
64 valuesVec.data());
65 }
66
67 void setParams(PyValue &result, const nb::list &params) {
68 std::vector<MlirAttribute> paramsVec;
69 paramsVec.reserve(params.size());
70 for (auto item : params) {
71 paramsVec.push_back(nb::cast<MlirAttribute>(item));
72 }
73 mlirTransformResultsSetParams(results, result, paramsVec.size(),
74 paramsVec.data());
75 }
76
77 static void bind(nanobind::module_ &m) {
78 nb::class_<PyTransformResults>(m, "TransformResults")
79 .def(nb::init<MlirTransformResults>())
80 .def("set_ops", &PyTransformResults::setOps,
81 "Set the payload operations for a transform result.",
82 nb::arg("result"), nb::arg("ops"))
83 .def("set_values", &PyTransformResults::setValues,
84 "Set the payload values for a transform result.",
85 nb::arg("result"), nb::arg("values"))
86 .def("set_params", &PyTransformResults::setParams,
87 "Set the parameters for a transform result.", nb::arg("result"),
88 nb::arg("params"));
89 }
90
91private:
92 MlirTransformResults results;
93};
94
95//===----------------------------------------------------------------------===//
96// TransformState
97//===----------------------------------------------------------------------===//
99public:
100 PyTransformState(MlirTransformState state) : state(state) {}
101
102 MlirTransformState get() const { return state; }
103
104 static void bind(nanobind::module_ &m) {
105 nb::class_<PyTransformState>(m, "TransformState")
106 .def(nb::init<MlirTransformState>())
107 .def("get_payload_ops", &PyTransformState::getPayloadOps,
108 "Get the payload operations associated with a transform IR value.",
109 nb::arg("operand"))
110 .def("get_payload_values", &PyTransformState::getPayloadValues,
111 "Get the payload values associated with a transform IR value.",
112 nb::arg("operand"))
113 .def("get_params", &PyTransformState::getParams,
114 "Get the parameters (attributes) associated with a transform IR "
115 "value.",
116 nb::arg("operand"));
117 }
118
119private:
/// Returns a Python list of OpView objects, one appended per payload
/// operation that the transform state reports for the transform IR value
/// `value`.
/// The address of the local `result` list is passed as the C callback's
/// userData, and the captureless lambda casts it back in order to append.
/// NOTE(review): the C API call line is not visible in this listing;
/// presumably mlirTransformStateForEachPayloadOp — confirm.
120 nanobind::list getPayloadOps(PyValue &value) {
121 nanobind::list result;
123 state, value,
124 [](MlirOperation op, void *userData) {
// Wrap the raw operation in a PyOperation and append its OpView. The
// context-lookup line is not visible here; presumably
// PyMlirContext::forContext(mlirOperationGetContext(op)) — confirm.
125 PyMlirContextRef context =
127 auto opview = PyOperation::forOperation(context, op)->createOpView();
128 static_cast<nanobind::list *>(userData)->append(opview);
129 },
130 &result);
131 return result;
132 }
133
/// Returns a Python list of the payload values reported for the transform IR
/// value `value`. Each MlirValue is appended as-is and converted to a Python
/// object by list::append.
/// NOTE(review): the C API call line is not visible in this listing;
/// presumably mlirTransformStateForEachPayloadValue — confirm.
134 nanobind::list getPayloadValues(PyValue &value) {
135 nanobind::list result;
137 state, value,
138 [](MlirValue val, void *userData) {
139 static_cast<nanobind::list *>(userData)->append(val);
140 },
141 &result);
142 return result;
143 }
144
/// Returns a Python list of the parameters (attributes) reported for the
/// transform IR value `value`. It uses the same userData-list pattern as the
/// payload getters.
/// NOTE(review): the C API call line is not visible in this listing;
/// presumably mlirTransformStateForEachParam — confirm.
145 nanobind::list getParams(PyValue &value) {
146 nanobind::list result;
148 state, value,
149 [](MlirAttribute attr, void *userData) {
150 static_cast<nanobind::list *>(userData)->append(attr);
151 },
152 &result);
153 return result;
154 }
155
156 MlirTransformState state;
157};
158
159//===----------------------------------------------------------------------===//
160// TransformOpInterface
161//===----------------------------------------------------------------------===//
163 : public PyConcreteOpInterface<PyTransformOpInterface> {
164public:
166
167 constexpr static const char *pyClassName = "TransformOpInterface";
170
171 /// Attach a new TransformOpInterface FallbackModel to the named operation.
172 /// The FallbackModel acts as a trampoline for callbacks on the Python class.
///
/// \param target Python object whose `apply` and
///   `allow_repeated_handle_operands` attributes are called from the
///   callbacks. bindDerived passes the interface subclass itself when no
///   explicit target is given. A strong reference to `target` is taken here
///   and released when the FallbackModel is destroyed.
/// \param opName Name of the operation that receives the FallbackModel.
/// \param ctx Context whose handle (`ctx->get()`) is passed to the C API
///   attach call. Its parameter line is not visible in this listing.
173 static void attach(nb::object &target, const std::string &opName,
175 // Prepare the callbacks that will be used by the FallbackModel.
177 // Make the pointer to the Python class available to the callbacks.
178 callbacks.userData = target.ptr();
179 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
180
181 // The above ref bump is all we need as initialization, no need to run the
182 // construct callback.
183 callbacks.construct = nullptr;
184 // Upon the FallbackModel's destruction, drop the ref to the Python class.
185 callbacks.destruct = [](void *userData) {
186 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
187 };
188 // The apply callback which calls into Python.
// `apply` is looked up with getattr on every invocation, so rebinding it on
// the Python class after attach() takes effect. The rewriter, results and
// state handles are wrapped in their Python classes before the call.
// NOTE(review): nothing here catches C++ exceptions. A Python error raised
// by `apply`, or a return value that nb::cast cannot convert to
// MlirDiagnosedSilenceableFailure, propagates out of this C API callback as a
// C++ exception. Confirm the caller can unwind such exceptions safely.
189 callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
190 MlirTransformResults results, MlirTransformState state,
191 void *userData) -> MlirDiagnosedSilenceableFailure {
192 nb::handle pyClass(static_cast<PyObject *>(userData));
193
194 auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
195
196 auto pyRewriter = PyTransformRewriter(rewriter);
197 auto pyResults = PyTransformResults(results);
198 auto pyState = PyTransformState(state);
199
200 // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
201 // staticmethod.
202 PyMlirContextRef context =
204 auto opview = PyOperation::forOperation(context, op)->createOpView();
205 nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
206
207 return nb::cast<MlirDiagnosedSilenceableFailure>(res);
208 };
209
210 // The allows_repeated_handle_operands callback which calls into Python.
// This attribute is also looked up per call. The Python result is converted
// with nb::cast<bool>, which throws if the conversion fails; the same
// exception-propagation caveat as for `apply` applies.
211 callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
212 void *userData) -> bool {
213 nb::handle pyClass(static_cast<PyObject *>(userData));
214
215 auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
216 nb::getattr(pyClass, "allow_repeated_handle_operands"));
217
218 // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
219 // staticmethod.
220 PyMlirContextRef context =
222 auto opview = PyOperation::forOperation(context, op)->createOpView();
223 nb::object res = pyAllowRepeatedHandleOperands(opview);
224
225 return nb::cast<bool>(res);
226 };
227
// NOTE(review): the first line of the attach call is not visible in this
// listing; presumably mlirTransformOpInterfaceAttachFallbackModel(...).
228 // Attach a FallbackModel, which calls into Python, to the named operation.
230 ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
231 }
232
233 static void bindDerived(ClassTy &cls) {
234 cls.attr("attach") = classmethod(
235 [](const nb::object &cls, const nb::object &opName, nb::object target,
236 DefaultingPyMlirContext context) {
237 if (target.is_none())
238 target = cls;
239 return attach(target, nb::cast<std::string>(opName), context);
240 },
241 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
242 nb::arg("target").none() = nb::none(),
243 nb::arg("context").none() = nb::none(),
244 "Attach the interface subclass to the given operation name.");
245 }
246};
247
248//===-------------------------------------------------------------------===//
249// AnyOpType
250//===-------------------------------------------------------------------===//
251
252struct AnyOpType : PyConcreteType<AnyOpType> {
256 static constexpr const char *pyClassName = "AnyOpType";
258 using Base::Base;
259
260 static void bindDerived(ClassTy &c) {
261 c.def_static(
262 "get",
263 [](DefaultingPyMlirContext context) {
264 return AnyOpType(context->getRef(),
265 mlirTransformAnyOpTypeGet(context.get()->get()));
266 },
267 "Get an instance of AnyOpType in the given context.",
268 nb::arg("context").none() = nb::none());
269 }
270};
271
272//===-------------------------------------------------------------------===//
273// AnyParamType
274//===-------------------------------------------------------------------===//
275
276struct AnyParamType : PyConcreteType<AnyParamType> {
280 static constexpr const char *pyClassName = "AnyParamType";
282 using Base::Base;
283
284 static void bindDerived(ClassTy &c) {
285 c.def_static(
286 "get",
287 [](DefaultingPyMlirContext context) {
288 return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet(
289 context.get()->get()));
290 },
291 "Get an instance of AnyParamType in the given context.",
292 nb::arg("context").none() = nb::none());
293 }
294};
295
296//===-------------------------------------------------------------------===//
297// AnyValueType
298//===-------------------------------------------------------------------===//
299
300struct AnyValueType : PyConcreteType<AnyValueType> {
304 static constexpr const char *pyClassName = "AnyValueType";
306 using Base::Base;
307
308 static void bindDerived(ClassTy &c) {
309 c.def_static(
310 "get",
311 [](DefaultingPyMlirContext context) {
312 return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet(
313 context.get()->get()));
314 },
315 "Get an instance of AnyValueType in the given context.",
316 nb::arg("context").none() = nb::none());
317 }
318};
319
320//===-------------------------------------------------------------------===//
321// OperationType
322//===-------------------------------------------------------------------===//
323
324struct OperationType : PyConcreteType<OperationType> {
325 static constexpr IsAFunctionTy isaFunction =
329 static constexpr const char *pyClassName = "OperationType";
331 using Base::Base;
332
/// Binds the static factory `OperationType.get(operation_name, context=None)`
/// and the read-only property `operation_name`.
/// NOTE(review): the `get` docstring ends in "context" without a trailing
/// period, unlike the sibling docstrings in this file.
333 static void bindDerived(ClassTy &c) {
334 c.def_static(
335 "get",
336 [](const std::string &operationName, DefaultingPyMlirContext context) {
// Build the MlirStringRef from data() and size() so the length is carried
// explicitly rather than recomputed.
337 MlirStringRef cOperationName =
338 mlirStringRefCreate(operationName.data(), operationName.size());
339 return OperationType(context->getRef(),
341 context.get()->get(), cOperationName));
342 },
343 "Get an instance of OperationType for the given kind in the given "
344 "context",
345 nb::arg("operation_name"), nb::arg("context").none() = nb::none());
346 c.def_prop_ro(
347 "operation_name",
348 [](const OperationType &type) {
// Build a Python str from the (data, length) pair. The fragment is not
// required to be NUL-terminated. The getter call line is not visible
// here; presumably mlirTransformOperationTypeGetOperationName(type).
349 MlirStringRef operationName =
351 return nb::str(operationName.data, operationName.length);
352 },
353 "Get the name of the payload operation accepted by the handle.");
354 }
355};
356
357//===-------------------------------------------------------------------===//
358// ParamType
359//===-------------------------------------------------------------------===//
360
361struct ParamType : PyConcreteType<ParamType> {
365 static constexpr const char *pyClassName = "ParamType";
367 using Base::Base;
368
369 static void bindDerived(ClassTy &c) {
370 c.def_static(
371 "get",
372 [](const PyType &type, DefaultingPyMlirContext context) {
373 return ParamType(context->getRef(), mlirTransformParamTypeGet(
374 context.get()->get(), type));
375 },
376 "Get an instance of ParamType for the given type in the given context.",
377 nb::arg("type"), nb::arg("context").none() = nb::none());
378 c.def_prop_ro(
379 "type",
380 [](ParamType type) {
381 return PyType(type.getContext(), mlirTransformParamTypeGetType(type))
382 .maybeDownCast();
383 },
384 "Get the type this ParamType is associated with.");
385 }
386};
387
388//===----------------------------------------------------------------------===//
389// MemoryEffectsOpInterface helpers
390//===----------------------------------------------------------------------===//
391
392namespace {
/// Marks each operand in `operands` as only reading its transform handle.
/// The effects are recorded through mlirTransformOnlyReadsHandle.
/// Each element is cast to PyOpOperand and stored as a C API MlirOpOperand.
/// NOTE(review): the second parameter line is not visible in this listing.
/// The body reads `effects.effects`, so it is presumably a
/// PyMemoryEffectsInstanceList parameter as in consumesHandle. The `;` after
/// the closing brace is a redundant empty declaration (-Wextra-semi).
393void onlyReadsHandle(nb::iterable &operands,
395 std::vector<MlirOpOperand> operandsVec;
396 for (auto operand : operands)
397 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
398 mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
399 effects.effects);
400};
401
402void consumesHandle(nb::iterable &operands,
403 PyMemoryEffectsInstanceList effects) {
404 std::vector<MlirOpOperand> operandsVec;
405 for (auto operand : operands)
406 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
407 mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
408 effects.effects);
409};
410
411void producesHandle(nb::iterable &results,
412 PyMemoryEffectsInstanceList effects) {
413 std::vector<MlirValue> resultsVec;
414 for (auto result : results)
415 resultsVec.push_back(nb::cast<PyOpResult>(result).get());
416 mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
417 effects.effects);
418};
419
420void modifiesPayload(PyMemoryEffectsInstanceList effects) {
421 mlirTransformModifiesPayload(effects.effects);
422}
423
424void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
425 mlirTransformOnlyReadsPayload(effects.effects);
426}
427} // namespace
428
/// Populates module `m` with the transform-dialect bindings. These are the
/// `DiagnosedSilenceableFailure` enum and the module-level memory-effect
/// helpers: `only_reads_handle`, `consumes_handle`, `produces_handle`,
/// `modifies_payload` and `only_reads_payload`.
/// NOTE(review): several lines of this function are not visible in this
/// listing, namely part of the enum definition and the statements between the
/// enum and the helpers. Presumably they bind the `Success` value and register
/// the wrapper classes defined above — confirm.
429static void populateDialectTransformSubmodule(nb::module_ &m) {
430 nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
432 .value("SilenceableFailure",
434 .value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
435
441
446
// Each helper below forwards to the file-local function of the same purpose,
// which in turn calls the matching mlirTransform* C API helper.
447 m.def("only_reads_handle", onlyReadsHandle,
448 "Mark operands as only reading handles.", nb::arg("operands"),
449 nb::arg("effects"));
450
451 m.def("consumes_handle", consumesHandle,
452 "Mark operands as consuming handles.", nb::arg("operands"),
453 nb::arg("effects"));
454
455 m.def("produces_handle", producesHandle, "Mark results as producing handles.",
456 nb::arg("results"), nb::arg("effects"));
457
458 m.def("modifies_payload", modifiesPayload,
459 "Mark the transform as modifying the payload.", nb::arg("effects"));
460
461 m.def("only_reads_payload", onlyReadsPayload,
462 "Mark the transform as only reading the payload.", nb::arg("effects"));
463}
464} // namespace transform
465} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
466} // namespace python
467} // namespace mlir
468
/// Entry point of the `_mlirDialectsTransform` nanobind extension module.
/// NOTE(review): the body lines between the docstring assignment and the
/// closing brace are not visible in this listing. Presumably they call
/// populateDialectTransformSubmodule(m) — confirm.
469NB_MODULE(_mlirDialectsTransform, m) {
470 m.doc() = "MLIR Transform dialect.";
473}
NB_MODULE(_mlirDialectsTransform, m)
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
ReferrentTy * get() const
DefaultingPyMlirContext
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:279
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:877
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new TransformOpInterface FallbackModel to the named operation.
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx)
Definition Transform.cpp:37
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void)
Definition Transform.cpp:53
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void)
Returns the interface TypeID of the TransformOpInterface.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void)
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type)
Definition Transform.cpp:69
MLIR_CAPI_EXPORTED void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects)
Helper to mark potential modifications to the payload IR.
MLIR_CAPI_EXPORTED void mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result, intptr_t numParams, MlirAttribute *params)
Set the parameters for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type)
Definition Transform.cpp:89
MLIR_CAPI_EXPORTED void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Helper to mark operands as only reading handles.
MLIR_CAPI_EXPORTED void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value, MlirAttributeCallback callback, void *userData)
Iterate over parameters associated with the transform IR value.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type)
Definition Transform.cpp:49
MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result, intptr_t numOps, MlirOperation *ops)
Set the payload operations for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void)
Definition Transform.cpp:73
MLIR_CAPI_EXPORTED void mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value, MlirValueCallback callback, void *userData)
Iterate over payload values associated with the transform IR value.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void)
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type)
MLIR_CAPI_EXPORTED MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName)
Definition Transform.cpp:97
MLIR_CAPI_EXPORTED MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter)
Cast the TransformRewriter to a RewriterBase.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetName(void)
MLIR_CAPI_EXPORTED void mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value, MlirOperationCallback callback, void *userData)
Iterate over payload operations associated with the transform IR value.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type)
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type)
MlirDiagnosedSilenceableFailure
Enum representing the result of a transform operation.
Definition Transform.h:41
@ MlirDiagnosedSilenceableFailureSuccess
The operation succeeded.
Definition Transform.h:43
@ MlirDiagnosedSilenceableFailureDefiniteFailure
The operation failed definitively.
Definition Transform.h:47
@ MlirDiagnosedSilenceableFailureSilenceableFailure
The operation failed in a silenceable way.
Definition Transform.h:45
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void)
Definition Transform.cpp:93
MLIR_CAPI_EXPORTED void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults, MlirMemoryEffectInstancesList effects)
Helper to mark results as producing handles.
MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirTransformOpInterfaceCallbacks callbacks)
Attach TransformOpInterface to the operation with the given name using the provided callbacks.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyValueTypeGetName(void)
Definition Transform.cpp:81
MLIR_CAPI_EXPORTED void mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result, intptr_t numValues, MlirValue *values)
Set the payload values for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx)
Definition Transform.cpp:77
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx)
Definition Transform.cpp:57
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type)
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type)
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyOpTypeGetName(void)
Definition Transform.cpp:41
MLIR_CAPI_EXPORTED void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects)
Helper to mark potential reads from the payload IR.
MLIR_CAPI_EXPORTED void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Helper to mark operands as consuming handles.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyParamTypeGetName(void)
Definition Transform.cpp:61
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void)
Definition Transform.cpp:33
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:198
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:1878
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
MlirStringRef
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
MlirTransformOpInterfaceCallbacks
Callbacks for implementing TransformOpInterface from external code.
Definition Transform.h:186
MlirDiagnosedSilenceableFailure(* apply)(MlirOperation op, MlirTransformRewriter rewriter, MlirTransformResults results, MlirTransformState state, void *userData)
Apply callback that implements the transformation.
Definition Transform.h:194
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Transform.h:192
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Transform.h:189
bool(* allowsRepeatedHandleOperands)(MlirOperation op, void *userData)
Callback to check if repeated handle operands are allowed.
Definition Transform.h:200