MLIR 23.0.0git
IRInterfaces.cpp
Go to the documentation of this file.
1//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cstdint>
10#include <optional>
11#include <string>
12#include <utility>
13#include <vector>
14
16#include "mlir-c/IR.h"
17#include "mlir-c/Interfaces.h"
18#include "mlir-c/Support.h"
21
22namespace nb = nanobind;
23
24namespace mlir {
25namespace python {
/// Python docstring text describing `inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Python docstring text describing `inferReturnTypeComponents`.
static constexpr const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
34
35namespace {
36
37/// Takes in an optional ist of operands and converts them into a std::vector
38/// of MlirVlaues. Returns an empty std::vector if the list is empty.
39std::vector<MlirValue> wrapOperands(std::optional<nb::sequence> operandList) {
40 std::vector<MlirValue> mlirOperands;
41
42 if (!operandList || nb::len(*operandList) == 0) {
43 return mlirOperands;
44 }
45
46 // Note: as the list may contain other lists this may not be final size.
47 mlirOperands.reserve(nb::len(*operandList));
48 for (size_t i = 0, e = nb::len(*operandList); i < e; ++i) {
49 nb::handle operand = (*operandList)[i];
50 intptr_t index = static_cast<intptr_t>(i);
51 if (operand.is_none())
52 continue;
53
54 PyValue *val;
55 try {
56 val = nb::cast<PyValue *>(operand);
57 if (!val)
58 throw nb::cast_error();
59 mlirOperands.push_back(val->get());
60 continue;
61 } catch (nb::cast_error &err) {
62 // Intentionally unhandled to try sequence below first.
63 (void)err;
64 }
65
66 try {
67 auto vals = nb::cast<nb::sequence>(operand);
68 for (nb::handle v : vals) {
69 try {
70 val = nb::cast<PyValue *>(v);
71 if (!val)
72 throw nb::cast_error();
73 mlirOperands.push_back(val->get());
74 } catch (nb::cast_error &err) {
75 throw nb::value_error(
76 nanobind::detail::join("Operand ", index,
77 " must be a Value or Sequence of Values (",
78 err.what(), ")")
79 .c_str());
80 }
81 }
82 continue;
83 } catch (nb::cast_error &err) {
84 throw nb::value_error(
85 nanobind::detail::join("Operand ", index,
86 " must be a Value or Sequence of Values (",
87 err.what(), ")")
88 .c_str());
89 }
90
91 throw nb::cast_error();
92 }
93
94 return mlirOperands;
95}
96
97/// Takes in an optional vector of PyRegions and returns a std::vector of
98/// MlirRegion. Returns an empty std::vector if the list is empty.
99std::vector<MlirRegion>
100wrapRegions(std::optional<std::vector<PyRegion>> regions) {
101 std::vector<MlirRegion> mlirRegions;
102
103 if (regions) {
104 mlirRegions.reserve(regions->size());
105 for (PyRegion &region : *regions) {
106 mlirRegions.push_back(region);
107 }
108 }
109
110 return mlirRegions;
111}
112
113} // namespace
114
115/// Python wrapper for InferTypeOpInterface. This interface has only static
116/// methods.
118 : public PyConcreteOpInterface<PyInferTypeOpInterface> {
119public:
121
122 constexpr static const char *pyClassName = "InferTypeOpInterface";
125
126 /// C-style user-data structure for type appending callback.
131
132 /// Appends the types provided as the two first arguments to the user-data
133 /// structure (expects AppendResultsCallbackData).
134 static void appendResultsCallback(intptr_t nTypes, MlirType *types,
135 void *userData) {
136 auto *data = static_cast<AppendResultsCallbackData *>(userData);
137 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
138 for (intptr_t i = 0; i < nTypes; ++i) {
139 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
140 }
141 }
142
143 /// Given the arguments required to build an operation, attempts to infer its
144 /// return types. Throws value_error on failure.
145 std::vector<PyType>
146 inferReturnTypes(std::optional<nb::sequence> operandList,
147 std::optional<PyAttribute> attributes, void *properties,
148 std::optional<std::vector<PyRegion>> regions,
150 DefaultingPyLocation location) {
151 std::vector<MlirValue> mlirOperands = wrapOperands(std::move(operandList));
152 std::vector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
153
154 std::vector<PyType> inferredTypes;
155 PyMlirContext &pyContext = context.resolve();
156 AppendResultsCallbackData data{inferredTypes, pyContext};
157 MlirStringRef opNameRef =
158 mlirStringRefCreate(getOpName().data(), getOpName().length());
159 MlirAttribute attributeDict =
160 attributes ? attributes->get() : mlirAttributeGetNull();
161
163 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
164 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
165 mlirRegions.data(), &appendResultsCallback, &data);
166
168 throw nb::value_error("Failed to infer result types");
169 }
170
171 return inferredTypes;
172 }
173
174 static void bindDerived(ClassTy &cls) {
175 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
176 nb::arg("operands") = nb::none(),
177 nb::arg("attributes") = nb::none(),
178 nb::arg("properties") = nb::none(), nb::arg("regions") = nb::none(),
179 nb::arg("context") = nb::none(), nb::arg("loc") = nb::none(),
181 }
182};
183
/// Wrapper around shaped type components.
186public:
187 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
188 PyShapedTypeComponents(nb::list shape, MlirType elementType)
189 : shape(std::move(shape)), elementType(elementType), ranked(true) {}
190 PyShapedTypeComponents(nb::list shape, MlirType elementType,
191 MlirAttribute attribute)
192 : shape(std::move(shape)), elementType(elementType), attribute(attribute),
193 ranked(true) {}
196 : shape(other.shape), elementType(other.elementType),
197 attribute(other.attribute), ranked(other.ranked) {}
198
199 static void bind(nb::module_ &m) {
200 nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
201 .def_prop_ro(
202 "element_type",
203 [](PyShapedTypeComponents &self) { return self.elementType; },
204 nb::sig("def element_type(self) -> Type"),
205 "Returns the element type of the shaped type components.")
206 .def_static(
207 "get",
208 [](PyType &elementType) {
209 return PyShapedTypeComponents(elementType);
210 },
211 nb::arg("element_type"),
212 "Create an shaped type components object with only the element "
213 "type.")
214 .def_static(
215 "get",
216 [](nb::typed<nb::list, nb::int_> shape, PyType &elementType) {
217 return PyShapedTypeComponents(std::move(shape), elementType);
218 },
219 nb::arg("shape"), nb::arg("element_type"),
220 "Create a ranked shaped type components object.")
221 .def_static(
222 "get",
223 [](nb::typed<nb::list, nb::int_> shape, PyType &elementType,
224 PyAttribute &attribute) {
225 return PyShapedTypeComponents(std::move(shape), elementType,
226 attribute);
227 },
228 nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
229 "Create a ranked shaped type components object with attribute.")
230 .def_prop_ro(
231 "has_rank",
232 [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
233 "Returns whether the given shaped type component is ranked.")
234 .def_prop_ro(
235 "rank",
236 [](PyShapedTypeComponents &self) -> std::optional<nb::int_> {
237 if (!self.ranked)
238 return {};
239 return nb::int_(self.shape.size());
240 },
241 "Returns the rank of the given ranked shaped type components. If "
242 "the shaped type components does not have a rank, None is "
243 "returned.")
244 .def_prop_ro(
245 "shape",
246 [](PyShapedTypeComponents &self) -> std::optional<nb::list> {
247 if (!self.ranked)
248 return {};
249 return nb::list(self.shape);
250 },
251 "Returns the shape of the ranked shaped type components as a list "
252 "of integers. Returns none if the shaped type component does not "
253 "have a rank.");
254 }
255
256 nb::object getCapsule();
257 static PyShapedTypeComponents createFromCapsule(nb::object capsule);
258
259private:
260 nb::list shape;
261 MlirType elementType;
262 MlirAttribute attribute;
263 bool ranked{false};
264};
265
266/// Python wrapper for InferShapedTypeOpInterface. This interface has only
267/// static methods.
269 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
270public:
273
274 constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
277
278 /// C-style user-data structure for type appending callback.
280 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
281 };
282
283 /// Appends the shaped type components provided as unpacked shape, element
284 /// type, attribute to the user-data.
285 static void appendResultsCallback(bool hasRank, intptr_t rank,
286 const int64_t *shape, MlirType elementType,
287 MlirAttribute attribute, void *userData) {
288 auto *data = static_cast<AppendResultsCallbackData *>(userData);
289 if (!hasRank) {
290 data->inferredShapedTypeComponents.emplace_back(elementType);
291 } else {
292 nb::list shapeList;
293 for (intptr_t i = 0; i < rank; ++i) {
294 shapeList.append(shape[i]);
295 }
296 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
297 attribute);
298 }
299 }
300
301 /// Given the arguments required to build an operation, attempts to infer the
302 /// shaped type components. Throws value_error on failure.
303 std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
304 std::optional<nb::sequence> operandList,
305 std::optional<PyAttribute> attributes, void *properties,
306 std::optional<std::vector<PyRegion>> regions,
308 std::vector<MlirValue> mlirOperands = wrapOperands(std::move(operandList));
309 std::vector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
310
311 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
312 PyMlirContext &pyContext = context.resolve();
313 AppendResultsCallbackData data{inferredShapedTypeComponents};
314 MlirStringRef opNameRef =
315 mlirStringRefCreate(getOpName().data(), getOpName().length());
316 MlirAttribute attributeDict =
317 attributes ? attributes->get() : mlirAttributeGetNull();
318
320 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
321 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
322 mlirRegions.data(), &appendResultsCallback, &data);
323
325 throw nb::value_error("Failed to infer result shape type components");
326 }
327
328 return inferredShapedTypeComponents;
329 }
330
331 static void bindDerived(ClassTy &cls) {
332 cls.def("inferReturnTypeComponents",
334 nb::arg("operands") = nb::none(),
335 nb::arg("attributes") = nb::none(), nb::arg("regions") = nb::none(),
336 nb::arg("properties") = nb::none(), nb::arg("context") = nb::none(),
337 nb::arg("loc") = nb::none(), inferReturnTypeComponentsDoc);
338 }
339};
340
341/// Wrapper around the ConditionallySpeculatable interface.
343 : public PyConcreteOpInterface<PyConditionallySpeculatableOpInterface> {
344public:
347
348 constexpr static const char *pyClassName = "ConditionallySpeculatable";
351
352 /// Attach a new ConditionallySpeculatable FallbackModel to the named
353 /// operation. The FallbackModel acts as a trampoline for callbacks on the
354 /// Python class.
355 static void attach(nb::object &target, const std::string &opName,
358 callbacks.userData = target.ptr();
359 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
360 callbacks.construct = nullptr;
361 callbacks.destruct = [](void *userData) {
362 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
363 };
364 callbacks.getSpeculatability = [](MlirOperation op, void *userData) {
365 nb::handle pyClass(static_cast<PyObject *>(userData));
366
367 auto pyGetSpeculatability =
368 nb::cast<nb::callable>(nb::getattr(pyClass, "get_speculatability"));
369
370 PyMlirContextRef context =
372 auto opview = PyOperation::forOperation(context, op)->createOpView();
373
374 return nb::cast<MlirSpeculatability>(pyGetSpeculatability(opview));
375 };
376
378 ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
379 callbacks);
380 }
381
382 static void bindDerived(ClassTy &cls) {
383 cls.def(
384 "getSpeculatability",
386 if (self.isStatic())
387 throw nb::type_error(
388 "Cannot query speculatability on a static interface");
389 auto operation = self.getOperationObject();
390 auto *pyOperation = nb::cast<PyOperation *>(operation);
392 pyOperation->get());
393 },
394 "Returns the speculatability of the given operation.");
395 cls.attr("attach") = classmethod(
396 [](const nb::object &cls, const nb::object &opName, nb::object target,
397 DefaultingPyMlirContext context) {
398 if (target.is_none())
399 target = cls;
400 return attach(target, nb::cast<std::string>(opName), context);
401 },
402 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
403 nb::arg("target").none() = nb::none(),
404 nb::arg("context").none() = nb::none(),
405 "Attach the interface subclass to the given operation name.");
406 }
407};
408
409/// Wrapper around the MemoryEffectsOpInterface.
411 : public PyConcreteOpInterface<PyMemoryEffectsOpInterface> {
412public:
415
416 constexpr static const char *pyClassName = "MemoryEffectsOpInterface";
419
420 /// Attach a new MemoryEffectsOpInterface FallbackModel to the named
421 /// operation. The FallbackModel acts as a trampoline for callbacks on the
422 /// Python class.
423 static void attach(nb::object &target, const std::string &opName,
426 callbacks.userData = target.ptr();
427 nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
428 callbacks.construct = nullptr;
429 callbacks.destruct = [](void *userData) {
430 nb::handle(static_cast<PyObject *>(userData)).dec_ref();
431 };
432 callbacks.getEffects = [](MlirOperation op,
433 MlirMemoryEffectInstancesList effects,
434 void *userData) {
435 nb::handle pyClass(static_cast<PyObject *>(userData));
436
437 // Get the 'get_effects' method from the Python class.
438 auto pyGetEffects =
439 nb::cast<nb::callable>(nb::getattr(pyClass, "get_effects"));
440
441 PyMemoryEffectsInstanceList effectsWrapper{effects};
442
443 PyMlirContextRef context =
445 auto opview = PyOperation::forOperation(context, op)->createOpView();
446
447 // Invoke `pyClass.get_effects(op, effects)`.
448 pyGetEffects(opview, effectsWrapper);
449 };
450
452 ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
453 callbacks);
454 }
455
456 static void bindDerived(ClassTy &cls) {
457 cls.attr("attach") = classmethod(
458 [](const nb::object &cls, const nb::object &opName, nb::object target,
459 DefaultingPyMlirContext context) {
460 if (target.is_none())
461 target = cls;
462 return attach(target, nb::cast<std::string>(opName), context);
463 },
464 nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
465 nb::arg("target").none() = nb::none(),
466 nb::arg("context").none() = nb::none(),
467 "Attach the interface subclass to the given operation name.");
468 }
469};
470
471void populateIRInterfaces(nb::module_ &m) {
472 nb::enum_<MlirSpeculatability>(m, "Speculatability")
473 .value("NotSpeculatable", MlirSpeculatabilityNotSpeculatable)
474 .value("Speculatable", MlirSpeculatabilitySpeculatable)
475 .value("RecursivelySpeculatable",
477 auto memoryEffectsInstanceListClass =
478 nb::class_<PyMemoryEffectsInstanceList>(m, "MemoryEffectInstancesList");
479 (void)memoryEffectsInstanceListClass;
480
486}
487} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
488} // namespace python
489} // namespace mlir
true
Given two iterators into the same block, return `true` if `a` is before `b`.
MlirContext mlirOperationGetContext(MlirOperation op)
Definition IR.cpp:658
ReferrentTy * get() const
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:541
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRCore.h:291
Wrapper around the generic MlirAttribute.
Definition IRCore.h:1018
PyConcreteOpInterface(nanobind::object object, DefaultingPyMlirContext context)
nanobind::typed< nanobind::object, PyOperation > getOperationObject()
Returns the operation instance from which this object was constructed.
bool isStatic()
Returns true if this object was constructed from a subclass of OpView rather than from an operation i...
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new ConditionallySpeculatable FallbackModel to the named operation.
static void appendResultsCallback(bool hasRank, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute attribute, void *userData)
Appends the shaped type components provided as unpacked shape, element type, attribute to the user-data.
std::vector< PyShapedTypeComponents > inferReturnTypeComponents(std::optional< nb::sequence > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion > > regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer the shaped type components.
std::vector< PyType > inferReturnTypes(std::optional< nb::sequence > operandList, std::optional< PyAttribute > attributes, void *properties, std::optional< std::vector< PyRegion > > regions, DefaultingPyMlirContext context, DefaultingPyLocation location)
Given the arguments required to build an operation, attempts to infer its return types.
static void appendResultsCallback(intptr_t nTypes, MlirType *types, void *userData)
Appends the types provided as the first two arguments to the user-data structure (expects AppendResultsCallbackData).
static void attach(nb::object &target, const std::string &opName, DefaultingPyMlirContext ctx)
Attach a new MemoryEffectsOpInterface FallbackModel to the named operation.
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:461
MlirContext get()
Accesses the underlying MlirContext.
Definition IRCore.h:224
nanobind::object createOpView()
Creates an OpView suitable for this operation.
Definition IRCore.cpp:1352
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive=nanobind::object())
Returns a PyOperation for the given MlirOperation, optionally associating it with a parentKeepAlive.
Definition IRCore.cpp:958
PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute)
static PyShapedTypeComponents createFromCapsule(nb::object capsule)
PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
Wrapper around the generic MlirType.
Definition IRCore.h:891
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirShapedTypeComponentsCallback callback, void *userData)
Infers the return shaped type components of the operation.
MLIR_CAPI_EXPORTED MlirSpeculatability mlirConditionallySpeculatableOpInterfaceGetSpeculatability(MlirOperation operation)
Returns the speculatability of the given operation.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void)
Returns the interface TypeID of the InferTypeOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirConditionallySpeculatableOpInterfaceTypeID(void)
Returns the interface TypeID of the ConditionallySpeculatable interface.
MLIR_CAPI_EXPORTED void mlirConditionallySpeculatableOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirConditionallySpeculatableOpInterfaceCallbacks callbacks)
Attach a new FallbackModel for the ConditionallySpeculatable interface to the named operation.
@ MlirSpeculatabilityRecursivelySpeculatable
The operation is speculatable if all nested operations are speculatable.
Definition Interfaces.h:111
@ MlirSpeculatabilitySpeculatable
The operation is speculatable.
Definition Interfaces.h:109
@ MlirSpeculatabilityNotSpeculatable
The operation is not speculatable.
Definition Interfaces.h:107
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData)
Infers the return types of the operation identified by its canonical name given the arguments that will be...
MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void)
Returns the interface TypeID of the MemoryEffectsOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void)
Returns the interface TypeID of the InferShapedTypeOpInterface.
MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(MlirContext ctx, MlirStringRef opName, MlirMemoryEffectsOpInterfaceCallbacks callbacks)
Attach a new FallbackModel for the MemoryEffectsOpInterface to the named operation.
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:87
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:132
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRCore.h:210
static constexpr const char * inferReturnTypesDoc
static constexpr const char * inferReturnTypeComponentsDoc
nanobind::object classmethod(Func f, Args... args)
Helper for creating an @classmethod.
Definition IRCore.h:2002
Include the generated interface declarations.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
Callbacks for implementing ConditionallySpeculatable from external code.
Definition Interfaces.h:119
void(* destruct)(void *userData)
Optional destructor for user data. Set to nullptr to disable it.
Definition Interfaces.h:123
void(* construct)(void *userData)
Optional constructor for user data. Set to nullptr to disable it.
Definition Interfaces.h:121
MlirSpeculatability(* getSpeculatability)(MlirOperation op, void *userData)
Returns the speculatability of the given operation.
Definition Interfaces.h:125
A logical result value, essentially a boolean with named states.
Definition Support.h:121
Callbacks for implementing MemoryEffectsOpInterface from external code.
Definition Interfaces.h:151
void(* construct)(void *userData)
Optional constructor for user data. Set to nullptr to disable it.
Definition Interfaces.h:153
void(* getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects, void *userData)
Get memory effects callback.
Definition Interfaces.h:157
void(* destruct)(void *userData)
Optional destructor for user data. Set to nullptr to disable it.
Definition Interfaces.h:155
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:78