MLIR 23.0.0git
DialectTransform.cpp
Go to the documentation of this file.
1//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <string>
10
11#include "Rewrite.h"
13#include "mlir-c/IR.h"
14#include "mlir-c/Rewrite.h"
15#include "mlir-c/Support.h"
18#include "nanobind/nanobind.h"
19#include <nanobind/trampoline.h>
20
21namespace nb = nanobind;
23
24namespace mlir {
25namespace python {
27namespace transform {
28
29//===----------------------------------------------------------------------===//
30// TransformRewriter
31//===----------------------------------------------------------------------===//
32class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
33public:
34 static constexpr const char *pyClassName = "TransformRewriter";
35
36 PyTransformRewriter(MlirTransformRewriter rewriter)
38};
39
40//===----------------------------------------------------------------------===//
41// TransformResults
42//===----------------------------------------------------------------------===//
44public:
45 PyTransformResults(MlirTransformResults results) : results(results) {}
46
47 MlirTransformResults get() const { return results; }
48
50 const nb::typed<nb::sequence, PyOperationBase> &ops) {
51 std::vector<MlirOperation> opsVec;
52 opsVec.reserve(nb::len(ops));
53 for (auto op : ops) {
54 opsVec.push_back(nb::cast<MlirOperation>(op));
55 }
56 mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
57 }
58
60 const nb::typed<nb::sequence, PyValue> &values) {
61 std::vector<MlirValue> valuesVec;
62 valuesVec.reserve(nb::len(values));
63 for (auto item : values) {
64 valuesVec.push_back(nb::cast<MlirValue>(item));
65 }
66 mlirTransformResultsSetValues(results, result, valuesVec.size(),
67 valuesVec.data());
68 }
69
71 const nb::typed<nb::sequence, PyAttribute> &params) {
72 std::vector<MlirAttribute> paramsVec;
73 paramsVec.reserve(nb::len(params));
74 for (auto item : params) {
75 paramsVec.push_back(nb::cast<MlirAttribute>(item));
76 }
77 mlirTransformResultsSetParams(results, result, paramsVec.size(),
78 paramsVec.data());
79 }
80
81 static void bind(nanobind::module_ &m) {
82 nb::class_<PyTransformResults>(m, "TransformResults")
83 .def(nb::init<MlirTransformResults>())
84 .def("set_ops", &PyTransformResults::setOps,
85 "Set the payload operations for a transform result.",
86 nb::arg("result"), nb::arg("ops"))
87 .def("set_values", &PyTransformResults::setValues,
88 "Set the payload values for a transform result.",
89 nb::arg("result"), nb::arg("values"))
90 .def("set_params", &PyTransformResults::setParams,
91 "Set the parameters for a transform result.", nb::arg("result"),
92 nb::arg("params"));
93 }
94
95private:
96 MlirTransformResults results;
97};
98
99//===----------------------------------------------------------------------===//
100// TransformState
101//===----------------------------------------------------------------------===//
103public:
104 PyTransformState(MlirTransformState state) : state(state) {}
105
106 MlirTransformState get() const { return state; }
107
108 static void bind(nanobind::module_ &m) {
109 nb::class_<PyTransformState>(m, "TransformState")
110 .def(nb::init<MlirTransformState>())
111 .def("get_payload_ops", &PyTransformState::getPayloadOps,
112 "Get the payload operations associated with a transform IR value.",
113 nb::arg("operand"))
114 .def("get_payload_values", &PyTransformState::getPayloadValues,
115 "Get the payload values associated with a transform IR value.",
116 nb::arg("operand"))
117 .def("get_params", &PyTransformState::getParams,
118 "Get the parameters (attributes) associated with a transform IR "
119 "value.",
120 nb::arg("operand"));
121 }
122
123private:
124 nanobind::list getPayloadOps(PyValue &value) {
125 nanobind::list result;
127 state, value,
128 [](MlirOperation op, void *userData) {
129 PyMlirContextRef context =
131 auto opview = PyOperation::forOperation(context, op)->createOpView();
132 static_cast<nanobind::list *>(userData)->append(opview);
133 },
134 &result);
135 return result;
136 }
137
138 nanobind::list getPayloadValues(PyValue &value) {
139 nanobind::list result;
141 state, value,
142 [](MlirValue val, void *userData) {
143 static_cast<nanobind::list *>(userData)->append(val);
144 },
145 &result);
146 return result;
147 }
148
149 nanobind::list getParams(PyValue &value) {
150 nanobind::list result;
152 state, value,
153 [](MlirAttribute attr, void *userData) {
154 static_cast<nanobind::list *>(userData)->append(attr);
155 },
156 &result);
157 return result;
158 }
159
160 MlirTransformState state;
161};
162
163//===----------------------------------------------------------------------===//
164// TransformOpInterface
165//===----------------------------------------------------------------------===//
167 : public PyConcreteOpInterface<PyTransformOpInterface> {
168public:
170
171 constexpr static const char *pyClassName = "TransformOpInterface";
174
175 /// Attach a new TransformOpInterface FallbackModel to the named operation.
176 /// The FallbackModel acts as a trampoline for callbacks on the Python class.
177 static void attach(nb::object &target, const std::string &opName,
179 // Prepare the callbacks that will be used by the FallbackModel.
181 // Make the pointer to the Python class available to the callbacks.
182 callbacks.userData = target.ptr();
183 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
184
185 // The above ref bump is all we need as initialization, no need to run the
186 // construct callback.
187 callbacks.construct = nullptr;
188 // Upon the FallbackModel's destruction, drop the ref to the Python class.
189 callbacks.destruct = [](void *userData) {
190 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
191 };
192 // The apply callback which calls into Python.
193 callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
194 MlirTransformResults results, MlirTransformState state,
195 void *userData) -> MlirDiagnosedSilenceableFailure {
196 nb::handle pyClass(static_cast<PyObject *>(userData));
197
198 auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
199
200 auto pyRewriter = PyTransformRewriter(rewriter);
201 auto pyResults = PyTransformResults(results);
202 auto pyState = PyTransformState(state);
203
204 // Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
205 // staticmethod.
206 PyMlirContextRef context =
208 auto opview = PyOperation::forOperation(context, op)->createOpView();
209 nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
210
211 return nb::cast<MlirDiagnosedSilenceableFailure>(res);
212 };
213
214 // The allows_repeated_handle_operands callback which calls into Python.
215 callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
216 void *userData) -> bool {
217 nb::handle pyClass(static_cast<PyObject *>(userData));
218
219 auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
220 nb::getattr(pyClass, "allow_repeated_handle_operands"));
221
222 // Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
223 // staticmethod.
224 PyMlirContextRef context =
226 auto opview = PyOperation::forOperation(context, op)->createOpView();
227 nb::object res = pyAllowRepeatedHandleOperands(opview);
228
229 return nb::cast<bool>(res);
230 };
231
232 // Attach a FallbackModel, which calls into Python, to the named operation.
234 ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
235 callbacks);
236 }
237
238 static void bindDerived(ClassTy &cls) {
239 cls.attr("attach") = classmethod(
240 [](const nb::object &cls, const nb::object &opName, nb::object target,
241 DefaultingPyMlirContext context) {
242 if (target.is_none())
243 target = cls;
244 return attach(target, nb::cast<std::string>(opName), context);
245 },
246 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
247 nb::arg("target").none() = nb::none(),
248 nb::arg("context").none() = nb::none(),
249 "Attach the interface subclass to the given operation name.");
250 }
251};
252
253//===----------------------------------------------------------------------===//
254// PatternDescriptorOpInterface
255//===----------------------------------------------------------------------===//
257 : public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
258public:
261
262 constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
265
266 /// Attach a new PatternDescriptorOpInterface FallbackModel to the named
267 /// operation. The FallbackModel acts as a trampoline for callbacks on the
268 /// Python class.
269 static void attach(nb::object &target, const std::string &opName,
271 // Prepare the callbacks that will be used by the FallbackModel.
273 // Make the pointer to the Python class available to the callbacks.
274 callbacks.userData = target.ptr();
275 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
276
277 // The above ref bump is all we need as initialization, no need to run the
278 // construct callback.
279 callbacks.construct = nullptr;
280 // Upon the FallbackModel's destruction, drop the ref to the Python class.
281 callbacks.destruct = [](void *userData) {
282 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
283 };
284
285 // The populatePatterns callback which calls into Python.
286 callbacks.populatePatterns =
287 [](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
288 nb::handle pyClass(static_cast<PyObject *>(userData));
289
290 auto pyPopulatePatterns =
291 nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
292
293 auto pyPatterns = PyRewritePatternSet(patterns);
294
295 // Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
296 // staticmethod.
297 MlirContext ctx = mlirOperationGetContext(op);
299 auto opview = PyOperation::forOperation(context, op)->createOpView();
300 pyPopulatePatterns(opview, pyPatterns);
301 };
302
303 // The populatePatternsWithState callback which calls into Python.
304 // Check if the Python class has populate_patterns_with_state method.
305 if (nb::hasattr(target, "populate_patterns_with_state")) {
306 callbacks.populatePatternsWithState = [](MlirOperation op,
307 MlirRewritePatternSet patterns,
308 MlirTransformState state,
309 void *userData) {
310 nb::handle pyClass(static_cast<PyObject *>(userData));
311
312 auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
313 nb::getattr(pyClass, "populate_patterns_with_state"));
314
315 auto pyPatterns = PyRewritePatternSet(patterns);
316 auto pyState = PyTransformState(state);
317
318 // Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
319 // state)` as a staticmethod.
320 MlirContext ctx = mlirOperationGetContext(op);
322 auto opview = PyOperation::forOperation(context, op)->createOpView();
323 pyPopulatePatternsWithState(opview, pyPatterns, pyState);
324 };
325 } else {
326 // Use default implementation (will call populatePatterns).
327 callbacks.populatePatternsWithState = nullptr;
328 }
329
330 // Attach a FallbackModel, which calls into Python, to the named operation.
332 ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
333 callbacks);
334 }
335
336 static void bindDerived(ClassTy &cls) {
337 cls.attr("attach") = classmethod(
338 [](const nb::object &cls, const nb::object &opName, nb::object target,
339 DefaultingPyMlirContext context) {
340 if (target.is_none())
341 target = cls;
342 return attach(target, nb::cast<std::string>(opName), context);
343 },
344 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
345 nb::arg("target").none() = nb::none(),
346 nb::arg("context").none() = nb::none(),
347 "Attach the interface subclass to the given operation name.");
348 }
349};
350
351//===-------------------------------------------------------------------===//
352// AnyOpType
353//===-------------------------------------------------------------------===//
354
355struct AnyOpType : PyConcreteType<AnyOpType> {
359 static constexpr const char *pyClassName = "AnyOpType";
361 using Base::Base;
362
363 static void bindDerived(ClassTy &c) {
364 c.def_static(
365 "get",
366 [](DefaultingPyMlirContext context) {
367 return AnyOpType(context->getRef(),
368 mlirTransformAnyOpTypeGet(context.get()->get()));
369 },
370 "Get an instance of AnyOpType in the given context.",
371 nb::arg("context").none() = nb::none());
372 }
373};
374
375//===-------------------------------------------------------------------===//
376// AnyParamType
377//===-------------------------------------------------------------------===//
378
379struct AnyParamType : PyConcreteType<AnyParamType> {
383 static constexpr const char *pyClassName = "AnyParamType";
385 using Base::Base;
386
387 static void bindDerived(ClassTy &c) {
388 c.def_static(
389 "get",
390 [](DefaultingPyMlirContext context) {
391 return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet(
392 context.get()->get()));
393 },
394 "Get an instance of AnyParamType in the given context.",
395 nb::arg("context").none() = nb::none());
396 }
397};
398
399//===-------------------------------------------------------------------===//
400// AnyValueType
401//===-------------------------------------------------------------------===//
402
403struct AnyValueType : PyConcreteType<AnyValueType> {
407 static constexpr const char *pyClassName = "AnyValueType";
409 using Base::Base;
410
411 static void bindDerived(ClassTy &c) {
412 c.def_static(
413 "get",
414 [](DefaultingPyMlirContext context) {
415 return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet(
416 context.get()->get()));
417 },
418 "Get an instance of AnyValueType in the given context.",
419 nb::arg("context").none() = nb::none());
420 }
421};
422
423//===-------------------------------------------------------------------===//
424// OperationType
425//===-------------------------------------------------------------------===//
426
427struct OperationType : PyConcreteType<OperationType> {
428 static constexpr IsAFunctionTy isaFunction =
432 static constexpr const char *pyClassName = "OperationType";
434 using Base::Base;
435
436 static void bindDerived(ClassTy &c) {
437 c.def_static(
438 "get",
439 [](const std::string &operationName, DefaultingPyMlirContext context) {
440 MlirStringRef cOperationName =
441 mlirStringRefCreate(operationName.data(), operationName.size());
442 return OperationType(context->getRef(),
444 context.get()->get(), cOperationName));
445 },
446 "Get an instance of OperationType for the given kind in the given "
447 "context",
448 nb::arg("operation_name"), nb::arg("context").none() = nb::none());
449 c.def_prop_ro(
450 "operation_name",
451 [](const OperationType &type) {
452 MlirStringRef operationName =
454 return nb::str(operationName.data, operationName.length);
455 },
456 "Get the name of the payload operation accepted by the handle.");
457 }
458};
459
460//===-------------------------------------------------------------------===//
461// ParamType
462//===-------------------------------------------------------------------===//
463
464struct ParamType : PyConcreteType<ParamType> {
468 static constexpr const char *pyClassName = "ParamType";
470 using Base::Base;
471
472 static void bindDerived(ClassTy &c) {
473 c.def_static(
474 "get",
475 [](const PyType &type, DefaultingPyMlirContext context) {
476 return ParamType(context->getRef(), mlirTransformParamTypeGet(
477 context.get()->get(), type));
478 },
479 "Get an instance of ParamType for the given type in the given context.",
480 nb::arg("type"), nb::arg("context").none() = nb::none());
481 c.def_prop_ro(
482 "type",
483 [](ParamType type) {
484 return PyType(type.getContext(), mlirTransformParamTypeGetType(type))
485 .maybeDownCast();
486 },
487 "Get the type this ParamType is associated with.");
488 }
489};
490
491//===----------------------------------------------------------------------===//
492// MemoryEffectsOpInterface helpers
493//===----------------------------------------------------------------------===//
494
495namespace {
496void onlyReadsHandle(nb::iterable &operands,
498 std::vector<MlirOpOperand> operandsVec;
499 for (auto operand : operands)
500 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
501 mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
502 effects.effects);
503};
504
505void consumesHandle(nb::iterable &operands,
506 PyMemoryEffectsInstanceList effects) {
507 std::vector<MlirOpOperand> operandsVec;
508 for (auto operand : operands)
509 operandsVec.push_back(nb::cast<PyOpOperand>(operand));
510 mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
511 effects.effects);
512};
513
514void producesHandle(nb::iterable &results,
515 PyMemoryEffectsInstanceList effects) {
516 std::vector<MlirValue> resultsVec;
517 for (auto result : results)
518 resultsVec.push_back(nb::cast<PyOpResult>(result).get());
519 mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
520 effects.effects);
521};
522
523void modifiesPayload(PyMemoryEffectsInstanceList effects) {
524 mlirTransformModifiesPayload(effects.effects);
525}
526
527void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
528 mlirTransformOnlyReadsPayload(effects.effects);
529}
530} // namespace
531
532static void populateDialectTransformSubmodule(nb::module_ &m) {
533 nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
535 .value("SilenceableFailure",
537 .value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
538
544
550
551 m.def("only_reads_handle", onlyReadsHandle,
552 "Mark operands as only reading handles.", nb::arg("operands"),
553 nb::arg("effects"));
554
555 m.def("consumes_handle", consumesHandle,
556 "Mark operands as consuming handles.", nb::arg("operands"),
557 nb::arg("effects"));
558
559 m.def("produces_handle", producesHandle, "Mark results as producing handles.",
560 nb::arg("results"), nb::arg("effects"));
561
562 m.def("modifies_payload", modifiesPayload,
563 "Mark the transform as modifying the payload.", nb::arg("effects"));
564
565 m.def("only_reads_payload", onlyReadsPayload,
566 "Mark the transform as only reading the payload.", nb::arg("effects"));
567}
568} // namespace transform
569} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
570} // namespace python
571} // namespace mlir
572
573NB_MODULE(_mlirDialectsTransform, m) {
574 m.doc() = "MLIR Transform dialect.";
577}
NB_MODULE(_mlirDialectsTransform, m)
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:650
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:291
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:460
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1351
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:957
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRCore.h:889
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new PatternDescriptorOpInterface FallbackModel to the named operation.
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new TransformOpInterface FallbackModel to the named operation.
void setOps(PyValue &result, const nb::typed< nb::sequence, PyOperationBase > &ops)
void setValues(PyValue &result, const nb::typed< nb::sequence, PyValue > &values)
void setParams(PyValue &result, const nb::typed< nb::sequence, PyAttribute > &params)
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx)
Definition Transform.cpp:37
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void)
Definition Transform.cpp:53
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void)
Returns the interface TypeID of the TransformOpInterface.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void)
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type)
Definition Transform.cpp:69
MLIR_CAPI_EXPORTED void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects)
Helper to mark potential modifications to the payload IR.
MLIR_CAPI_EXPORTED void mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result, intptr_t numParams, MlirAttribute *params)
Set the parameters for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type)
Definition Transform.cpp:89
MLIR_CAPI_EXPORTED void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Helper to mark operands as only reading handles.
MLIR_CAPI_EXPORTED void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value, MlirAttributeCallback callback, void *userData)
Iterate over parameters associated with the transform IR value.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type)
Definition Transform.cpp:49
MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirPatternDescriptorOpInterfaceCallbacks callbacks)
Attach PatternDescriptorOpInterface to the operation with the given name using the provided callbacks...
MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result, intptr_t numOps, MlirOperation *ops)
Set the payload operations for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void)
Definition Transform.cpp:73
MLIR_CAPI_EXPORTED void mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value, MlirValueCallback callback, void *userData)
Iterate over payload values associated with the transform IR value.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void)
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type)
MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void)
Returns the interface TypeID of the PatternDescriptorOpInterface.
MLIR_CAPI_EXPORTED MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName)
Definition Transform.cpp:97
MLIR_CAPI_EXPORTED MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter)
Cast the TransformRewriter to a RewriterBase.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetName(void)
MLIR_CAPI_EXPORTED void mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value, MlirOperationCallback callback, void *userData)
Iterate over payload operations associated with the transform IR value.
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type)
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type)
MlirDiagnosedSilenceableFailure
Enum representing the result of a transform operation.
Definition Transform.h:41
@ MlirDiagnosedSilenceableFailureSuccess
The operation succeeded.
Definition Transform.h:43
@ MlirDiagnosedSilenceableFailureDefiniteFailure
The operation failed definitively.
Definition Transform.h:47
@ MlirDiagnosedSilenceableFailureSilenceableFailure
The operation failed in a silenceable way.
Definition Transform.h:45
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void)
Definition Transform.cpp:93
MLIR_CAPI_EXPORTED void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults, MlirMemoryEffectInstancesList effects)
Helper to mark results as producing handles.
MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirTransformOpInterfaceCallbacks callbacks)
Attach TransformOpInterface to the operation with the given name using the provided callbacks.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyValueTypeGetName(void)
Definition Transform.cpp:81
MLIR_CAPI_EXPORTED void mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result, intptr_t numValues, MlirValue *values)
Set the payload values for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx)
Definition Transform.cpp:77
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx)
Definition Transform.cpp:57
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type)
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type)
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyOpTypeGetName(void)
Definition Transform.cpp:41
MLIR_CAPI_EXPORTED void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects)
Helper to mark potential reads from the payload IR.
MLIR_CAPI_EXPORTED void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands, MlirMemoryEffectInstancesList effects)
Helper to mark operands as consuming handles.
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyParamTypeGetName(void)
Definition Transform.cpp:61
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void)
Definition Transform.cpp:33
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:210
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:1873
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Callbacks for implementing PatternDescriptorOpInterface from external code.
Definition Transform.h:218
void(* populatePatternsWithState)(MlirOperation op, MlirRewritePatternSet patterns, MlirTransformState state, void *userData)
Optional callback to populate rewrite patterns with transform state.
Definition Transform.h:230
void(* populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns, void *userData)
Callback to populate rewrite patterns into the given pattern set.
Definition Transform.h:226
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Transform.h:221
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Transform.h:224
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78
const char * data
Pointer to the first symbol.
Definition Support.h:79
size_t length
Length of the fragment.
Definition Support.h:80
Callbacks for implementing TransformOpInterface from external code.
Definition Transform.h:186
MlirDiagnosedSilenceableFailure(* apply)(MlirOperation op, MlirTransformRewriter rewriter, MlirTransformResults results, MlirTransformState state, void *userData)
Apply callback that implements the transformation.
Definition Transform.h:194
void(* destruct)(void *userData)
Optional destructor for the user data.
Definition Transform.h:192
void(* construct)(void *userData)
Optional constructor for the user data.
Definition Transform.h:189
bool(* allowsRepeatedHandleOperands)(MlirOperation op, void *userData)
Callback to check if repeated handle operands are allowed.
Definition Transform.h:200