MLIR 23.0.0git
DialectTransform.cpp
Go to the documentation of this file.
1//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <string>
10
11#include "Rewrite.h"
13#include "mlir-c/IR.h"
14#include "mlir-c/Rewrite.h"
15#include "mlir-c/Support.h"
18#include "nanobind/nanobind.h"
19#include <nanobind/trampoline.h>
20
21namespace nb = nanobind;
23
24namespace mlir {
25namespace python {
27namespace transform {
28
29//===----------------------------------------------------------------------===//
30// TransformRewriter
31//===----------------------------------------------------------------------===//
32class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
33public:
34 static constexpr const char *pyClassName = "TransformRewriter";
35
36 PyTransformRewriter(MlirTransformRewriter rewriter)
38};
39
40//===----------------------------------------------------------------------===//
41// TransformResults
42//===----------------------------------------------------------------------===//
44public:
45 PyTransformResults(MlirTransformResults results) : results(results) {}
46
47 MlirTransformResults get() const { return results; }
48
49 void setOps(PyValue &result, const nb::list &ops) {
50 std::vector<MlirOperation> opsVec;
51 opsVec.reserve(ops.size());
52 for (auto op : ops) {
53 opsVec.push_back(nb::cast<MlirOperation>(op));
54 }
55 mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
56 }
57
58 void setValues(PyValue &result, const nb::list &values) {
59 std::vector<MlirValue> valuesVec;
60 valuesVec.reserve(values.size());
61 for (auto item : values) {
62 valuesVec.push_back(nb::cast<MlirValue>(item));
63 }
64 mlirTransformResultsSetValues(results, result, valuesVec.size(),
65 valuesVec.data());
66 }
67
68 void setParams(PyValue &result, const nb::list &params) {
69 std::vector<MlirAttribute> paramsVec;
70 paramsVec.reserve(params.size());
71 for (auto item : params) {
72 paramsVec.push_back(nb::cast<MlirAttribute>(item));
73 }
74 mlirTransformResultsSetParams(results, result, paramsVec.size(),
75 paramsVec.data());
76 }
77
78 static void bind(nanobind::module_ &m) {
79 nb::class_<PyTransformResults>(m, "TransformResults")
80 .def(nb::init<MlirTransformResults>())
81 .def("set_ops", &PyTransformResults::setOps,
82 "Set the payload operations for a transform result.",
83 nb::arg("result"), nb::arg("ops"))
84 .def("set_values", &PyTransformResults::setValues,
85 "Set the payload values for a transform result.",
86 nb::arg("result"), nb::arg("values"))
87 .def("set_params", &PyTransformResults::setParams,
88 "Set the parameters for a transform result.", nb::arg("result"),
89 nb::arg("params"));
90 }
91
92private:
93 MlirTransformResults results;
94};
95
96//===----------------------------------------------------------------------===//
97// TransformState
98//===----------------------------------------------------------------------===//
100public:
101 PyTransformState(MlirTransformState state) : state(state) {}
102
103 MlirTransformState get() const { return state; }
104
105 static void bind(nanobind::module_ &m) {
106 nb::class_<PyTransformState>(m, "TransformState")
107 .def(nb::init<MlirTransformState>())
108 .def("get_payload_ops", &PyTransformState::getPayloadOps,
109 "Get the payload operations associated with a transform IR value.",
110 nb::arg("operand"))
111 .def("get_payload_values", &PyTransformState::getPayloadValues,
112 "Get the payload values associated with a transform IR value.",
113 nb::arg("operand"))
114 .def("get_params", &PyTransformState::getParams,
115 "Get the parameters (attributes) associated with a transform IR "
116 "value.",
117 nb::arg("operand"));
118 }
119
120private:
121 nanobind::list getPayloadOps(PyValue &value) {
122 nanobind::list result;
124 state, value,
125 [](MlirOperation op, void *userData) {
126 PyMlirContextRef context =
128 auto opview = PyOperation::forOperation(context, op)->createOpView();
129 static_cast<nanobind::list *>(userData)->append(opview);
130 },
131 &result);
132 return result;
133 }
134
135 nanobind::list getPayloadValues(PyValue &value) {
136 nanobind::list result;
138 state, value,
139 [](MlirValue val, void *userData) {
140 static_cast<nanobind::list *>(userData)->append(val);
141 },
142 &result);
143 return result;
144 }
145
146 nanobind::list getParams(PyValue &value) {
147 nanobind::list result;
149 state, value,
150 [](MlirAttribute attr, void *userData) {
151 static_cast<nanobind::list *>(userData)->append(attr);
152 },
153 &result);
154 return result;
155 }
156
157 MlirTransformState state;
158};
159
160//===----------------------------------------------------------------------===//
161// TransformOpInterface
162//===----------------------------------------------------------------------===//
164 : public PyConcreteOpInterface<PyTransformOpInterface> {
165public:
167
168 constexpr static const char *pyClassName = "TransformOpInterface";
171
172 /// Attach a new TransformOpInterface FallbackModel to the named operation.
173 /// The FallbackModel acts as a trampoline for callbacks on the Python class.
174 static void attach(nb::object &target, const std::string &opName,
176 // Prepare the callbacks that will be used by the FallbackModel.
178 // Make the pointer to the Python class available to the callbacks.
179 callbacks.userData = target.ptr();
180 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
181
182 // The above ref bump is all we need as initialization, no need to run the
183 // construct callback.
184 callbacks.construct = nullptr;
185 // Upon the FallbackModel's destruction, drop the ref to the Python class.
186 callbacks.destruct = [](void *userData) {
187 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
188 };
189 // The apply callback which calls into Python.
190 callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
191 MlirTransformResults results, MlirTransformState state,
192 void *userData) -> MlirDiagnosedSilenceableFailure {
193 nb::handle pyClass(static_cast<PyObject *>(userData));
194
195 auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
196
197 auto pyRewriter = PyTransformRewriter(rewriter);
198 auto pyResults = PyTransformResults(results);
199 auto pyState = PyTransformState(state);
200
201 // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
202 // staticmethod.
203 PyMlirContextRef context =
205 auto opview = PyOperation::forOperation(context, op)->createOpView();
206 nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
207
208 return nb::cast<MlirDiagnosedSilenceableFailure>(res);
209 };
210
211 // The allows_repeated_handle_operands callback which calls into Python.
212 callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
213 void *userData) -> bool {
214 nb::handle pyClass(static_cast<PyObject *>(userData));
215
216 auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
217 nb::getattr(pyClass, "allow_repeated_handle_operands"));
218
219 // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
220 // staticmethod.
221 PyMlirContextRef context =
223 auto opview = PyOperation::forOperation(context, op)->createOpView();
224 nb::object res = pyAllowRepeatedHandleOperands(opview);
225
226 return nb::cast<bool>(res);
227 };
228
229 // Attach a FallbackModel, which calls into Python, to the named operation.
231 ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
232 callbacks);
233 }
234
235 static void bindDerived(ClassTy &cls) {
236 cls.attr("attach") = classmethod(
237 [](const nb::object &cls, const nb::object &opName, nb::object target,
238 DefaultingPyMlirContext context) {
239 if (target.is_none())
240 target = cls;
241 return attach(target, nb::cast<std::string>(opName), context);
242 },
243 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
244 nb::arg("target").none() = nb::none(),
245 nb::arg("context").none() = nb::none(),
246 "Attach the interface subclass to the given operation name.");
247 }
248};
249
250//===----------------------------------------------------------------------===//
251// PatternDescriptorOpInterface
252//===----------------------------------------------------------------------===//
254 : public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
255public:
258
259 constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
262
263 /// Attach a new PatternDescriptorOpInterface FallbackModel to the named
264 /// operation. The FallbackModel acts as a trampoline for callbacks on the
265 /// Python class.
266 static void attach(nb::object &target, const std::string &opName,
268 // Prepare the callbacks that will be used by the FallbackModel.
270 // Make the pointer to the Python class available to the callbacks.
271 callbacks.userData = target.ptr();
272 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
273
274 // The above ref bump is all we need as initialization, no need to run the
275 // construct callback.
276 callbacks.construct = nullptr;
277 // Upon the FallbackModel's destruction, drop the ref to the Python class.
278 callbacks.destruct = [](void *userData) {
279 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
280 };
281
282 // The populatePatterns callback which calls into Python.
283 callbacks.populatePatterns =
284 [](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
285 nb::handle pyClass(static_cast<PyObject *>(userData));
286
287 auto pyPopulatePatterns =
288 nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
289
290 auto pyPatterns = PyRewritePatternSet(patterns);
291
292 // Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
293 // staticmethod.
294 MlirContext ctx = mlirOperationGetContext(op);
296 auto opview = PyOperation::forOperation(context, op)->createOpView();
297 pyPopulatePatterns(opview, pyPatterns);
298 };
299
300 // The populatePatternsWithState callback which calls into Python.
301 // Check if the Python class has populate_patterns_with_state method.
302 if (nb::hasattr(target, "populate_patterns_with_state")) {
303 callbacks.populatePatternsWithState = [](MlirOperation op,
304 MlirRewritePatternSet patterns,
305 MlirTransformState state,
306 void *userData) {
307 nb::handle pyClass(static_cast<PyObject *>(userData));
308
309 auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
310 nb::getattr(pyClass, "populate_patterns_with_state"));
311
312 auto pyPatterns = PyRewritePatternSet(patterns);
313 auto pyState = PyTransformState(state);
314
315 // Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
316 // state)` as a staticmethod.
317 MlirContext ctx = mlirOperationGetContext(op);
319 auto opview = PyOperation::forOperation(context, op)->createOpView();
320 pyPopulatePatternsWithState(opview, pyPatterns, pyState);
321 };
322 } else {
323 // Use default implementation (will call populatePatterns).
324 callbacks.populatePatternsWithState = nullptr;
325 }
326
327 // Attach a FallbackModel, which calls into Python, to the named operation.
329 ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
330 callbacks);
331 }
332
333 static void bindDerived(ClassTy &cls) {
334 cls.attr("attach") = classmethod(
335 [](const nb::object &cls, const nb::object &opName, nb::object target,
336 DefaultingPyMlirContext context) {
337 if (target.is_none())
338 target = cls;
339 return attach(target, nb::cast<std::string>(opName), context);
340 },
341 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
342 nb::arg("target").none() = nb::none(),
343 nb::arg("context").none() = nb::none(),
344 "Attach the interface subclass to the given operation name.");
345 }
346};
347
348//===-------------------------------------------------------------------===//
349// AnyOpType
350//===-------------------------------------------------------------------===//
351
352struct AnyOpType : PyConcreteType<AnyOpType> {
356 static constexpr const char *pyClassName = "AnyOpType";
358 using Base::Base;
359
360 static void bindDerived(ClassTy &c) {
361 c.def_static(
362 "get",
363 [](DefaultingPyMlirContext context) {
364 return AnyOpType(context->getRef(),
365 mlirTransformAnyOpTypeGet(context.get()->get()));
366 },
367 "Get an instance of AnyOpType in the given context.",
368 nb::arg("context").none() = nb::none());
369 }
370};
371
372//===-------------------------------------------------------------------===//
373// AnyParamType
374//===-------------------------------------------------------------------===//
375
376struct AnyParamType : PyConcreteType<AnyParamType> {
380 static constexpr const char *pyClassName = "AnyParamType";
382 using Base::Base;
383
384 static void bindDerived(ClassTy &c) {
385 c.def_static(
386 "get",
387 [](DefaultingPyMlirContext context) {
388 return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet(
389 context.get()->get()));
390 },
391 "Get an instance of AnyParamType in the given context.",
392 nb::arg("context").none() = nb::none());
393 }
394};
395
396//===-------------------------------------------------------------------===//
397// AnyValueType
398//===-------------------------------------------------------------------===//
399
400struct AnyValueType : PyConcreteType<AnyValueType> {
404 static constexpr const char *pyClassName = "AnyValueType";
406 using Base::Base;
407
408 static void bindDerived(ClassTy &c) {
409 c.def_static(
410 "get",
411 [](DefaultingPyMlirContext context) {
412 return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet(
413 context.get()->get()));
414 },
415 "Get an instance of AnyValueType in the given context.",
416 nb::arg("context").none() = nb::none());
417 }
418};
419
420//===-------------------------------------------------------------------===//
421// OperationType
422//===-------------------------------------------------------------------===//
423
424struct OperationType : PyConcreteType<OperationType> {
425 static constexpr IsAFunctionTy isaFunction =
429 static constexpr const char *pyClassName = "OperationType";
431 using Base::Base;
432
433 static void bindDerived(ClassTy &c) {
434 c.def_static(
435 "get",
436 [](const std::string &operationName, DefaultingPyMlirContext context) {
437 MlirStringRef cOperationName =
438 mlirStringRefCreate(operationName.data(), operationName.size());
439 return OperationType(context->getRef(),
441 context.get()->get(), cOperationName));
442 },
443 "Get an instance of OperationType for the given kind in the given "
444 "context",
445 nb::arg("operation_name"), nb::arg("context").none() = nb::none());
446 c.def_prop_ro(
447 "operation_name",
448 [](const OperationType &type) {
449 MlirStringRef operationName =
451 return nb::str(operationName.data, operationName.length);
452 },
453 "Get the name of the payload operation accepted by the handle.");
454 }
455};
456
457//===-------------------------------------------------------------------===//
458// ParamType
459//===-------------------------------------------------------------------===//
460
461struct ParamType : PyConcreteType<ParamType> {
465 static constexpr const char *pyClassName = "ParamType";
467 using Base::Base;
468
469 static void bindDerived(ClassTy &c) {
470 c.def_static(
471 "get",
472 [](const PyType &type, DefaultingPyMlirContext context) {
473 return ParamType(context->getRef(), mlirTransformParamTypeGet(
474 context.get()->get(), type));
475 },
476 "Get an instance of ParamType for the given type in the given context.",
477 nb::arg("type"), nb::arg("context").none() = nb::none());
478 c.def_prop_ro(
479 "type",
480 [](ParamType type) {
481 return PyType(type.getContext(), mlirTransformParamTypeGetType(type))
482 .maybeDownCast();
483 },
484 "Get the type this ParamType is associated with.");
485 }
486};
487
488//===----------------------------------------------------------------------===//
489// MemoryEffectsOpInterface helpers
490//===----------------------------------------------------------------------===//
491
492namespace {
493void onlyReadsHandle(nb::iterable &operands,
495 std::vector<MlirOpOperand> operandsVec;
496 for (auto operand : operands)
497 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
498 mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
499 effects.effects);
500};
501
502void consumesHandle(nb::iterable &operands,
503 PyMemoryEffectsInstanceList effects) {
504 std::vector<MlirOpOperand> operandsVec;
505 for (auto operand : operands)
506 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
507 mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
508 effects.effects);
509};
510
511void producesHandle(nb::iterable &results,
512 PyMemoryEffectsInstanceList effects) {
513 std::vector<MlirValue> resultsVec;
514 for (auto result : results)
515 resultsVec.push_back(nb::cast<PyOpResult>(result).get());
516 mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
517 effects.effects);
518};
519
520void modifiesPayload(PyMemoryEffectsInstanceList effects) {
521 mlirTransformModifiesPayload(effects.effects);
522}
523
524void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
525 mlirTransformOnlyReadsPayload(effects.effects);
526}
527} // namespace
528
529static void populateDialectTransformSubmodule(nb::module_ &m) {
530 nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
532 .value("SilenceableFailure",
534 .value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
535
541
547
548 m.def("only_reads_handle", onlyReadsHandle,
549 "Mark operands as only reading handles.", nb::arg("operands"),
550 nb::arg("effects"));
551
552 m.def("consumes_handle", consumesHandle,
553 "Mark operands as consuming handles.", nb::arg("operands"),
554 nb::arg("effects"));
555
556 m.def("produces_handle", producesHandle, "Mark results as producing handles.",
557 nb::arg("results"), nb::arg("effects"));
558
559 m.def("modifies_payload", modifiesPayload,
560 "Mark the transform as modifying the payload.", nb::arg("effects"));
561
562 m.def("only_reads_payload", onlyReadsPayload,
563 "Mark the transform as only reading the payload.", nb::arg("effects"));
564}
565} // namespace transform
566} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
567} // namespace python
568} // namespace mlir
569
570NB_MODULE(_mlirDialectsTransform, m) {
571 m.doc() = "MLIR Transform dialect.";
574}
NB_MODULE(_mlirDialectsTransform, m)
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:278
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:486
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1377
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:983
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:876
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new PatternDescriptorOpInterface FallbackModel to the named operation.
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new TransformOpInterface FallbackModel to the named operation.
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx)
Definition Transform.cpp:37
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void)
Definition Transform.cpp:53
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void)
Returns the interface TypeID of the TransformOpInterface.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void)
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type)
Definition Transform.cpp:69
MLIR_CAPI_EXPORTED void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects)
Helper to mark potential modifications to the payload IR.
MLIR_CAPI_EXPORTED void mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result, intptr_t numParams, MlirAttribute *params)
Set the parameters for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type)
Definition Transform.cpp:89
MLIR_CAPI_EXPORTED void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Helper to mark operands as only reading handles.
MLIR_CAPI_EXPORTED void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value, MlirAttributeCallback callback, void *userData)
Iterate over parameters associated with the transform IR value.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type)
Definition Transform.cpp:49
MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirPatternDescriptorOpInterfaceCallbacks callbacks)
Attach PatternDescriptorOpInterface to the operation with the given name using the provided callbacks...
MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result, intptr_t numOps, MlirOperation *ops)
Set the payload operations for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void)
Definition Transform.cpp:73
MLIR_CAPI_EXPORTED void mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value, MlirValueCallback callback, void *userData)
Iterate over payload values associated with the transform IR value.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void)
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type)
MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void)
Returns the interface TypeID of the PatternDescriptorOpInterface.
MLIR_CAPI_EXPORTED MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName)
Definition Transform.cpp:97
MLIR_CAPI_EXPORTED MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter)
Cast the TransformRewriter to a RewriterBase.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetName(void)
MLIR_CAPI_EXPORTED void mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value, MlirOperationCallback callback, void *userData)
Iterate over payload operations associated with the transform IR value.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type)
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type)
MlirDiagnosedSilenceableFailure
Enum representing the result of a transform operation.
Definition Transform.h:41
@ MlirDiagnosedSilenceableFailureSuccess
The operation succeeded.
Definition Transform.h:43
@ MlirDiagnosedSilenceableFailureDefiniteFailure
The operation failed definitively.
Definition Transform.h:47
@ MlirDiagnosedSilenceableFailureSilenceableFailure
The operation failed in a silenceable way.
Definition Transform.h:45
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void)
Definition Transform.cpp:93
MLIR_CAPI_EXPORTED void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults, MlirMemoryEffectInstancesList effects)
Helper to mark results as producing handles.
MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirTransformOpInterfaceCallbacks callbacks)
Attach TransformOpInterface to the operation with the given name using the provided callbacks.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyValueTypeGetName(void)
Definition Transform.cpp:81
MLIR_CAPI_EXPORTED void mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result, intptr_t numValues, MlirValue *values)
Set the payload values for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx)
Definition Transform.cpp:77
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx)
Definition Transform.cpp:57
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type)
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type)
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyOpTypeGetName(void)
Definition Transform.cpp:41
MLIR_CAPI_EXPORTED void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects)
Helper to mark potential reads from the payload IR.
MLIR_CAPI_EXPORTED void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Helper to mark operands as consuming handles.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyParamTypeGetName(void)
Definition Transform.cpp:61
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void)
Definition Transform.cpp:33
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:197
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:1877
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Callbacks for implementing PatternDescriptorOpInterface from external code.
Definition Transform.h:218
void(* populatePatternsWithState)(MlirOperation op, MlirRewritePatternSet patterns, MlirTransformState state, void *userData)
Optional callback to populate rewrite patterns with transform state.
Definition Transform.h:230
void(* populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns, void *userData)
Callback to populate rewrite patterns into the given pattern set.
Definition Transform.h:226
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Transform.h:221
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Transform.h:224
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Callbacks for implementing TransformOpInterface from external code.
Definition Transform.h:186
MlirDiagnosedSilenceableFailure(* apply)(MlirOperation op, MlirTransformRewriter rewriter, MlirTransformResults results, MlirTransformState state, void *userData)
Apply callback that implements the transformation.
Definition Transform.h:194
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Transform.h:192
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Transform.h:189
bool(* allowsRepeatedHandleOperands)(MlirOperation op, void *userData)
Callback to check if repeated handle operands are allowed.
Definition Transform.h:200