MLIR 22.0.0git
IRAttributes.cpp
Go to the documentation of this file.
1//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cstdint>
10#include <optional>
11#include <string>
12#include <string_view>
13#include <utility>
14
15#include "IRModule.h"
16#include "NanobindUtils.h"
18#include "mlir-c/BuiltinTypes.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/Support/raw_ostream.h"
23
24namespace nb = nanobind;
25using namespace nanobind::literals;
26using namespace mlir;
27using namespace mlir::python;
28
30
31//------------------------------------------------------------------------------
32// Docstrings (trivial, non-duplicated docstrings are included inline).
33//------------------------------------------------------------------------------
34
// Python docstring for building a DenseElementsAttr from a Python buffer or
// array: element-type inference from the buffer format, the explicit `type=`
// and `shape=` overrides, the i1 bit-packing layout and splat creation. The
// raw-string contents are user-visible text.
35static const char kDenseElementsAttrGetDocstring[] =
36 R"(Gets a DenseElementsAttr from a Python buffer or array.
37
38When `type` is not provided, then some limited type inferencing is done based
39on the buffer format. Support presently exists for 8/16/32/64 signed and
40unsigned integers and float16/float32/float64. DenseElementsAttrs of these
41types can also be converted back to a corresponding buffer.
42
43For conversions outside of these types, a `type=` must be explicitly provided
44and the buffer contents must be bit-castable to the MLIR internal
45representation:
46
47 * Integer types (except for i1): the buffer must be byte aligned to the
48 next byte boundary.
49 * Floating point types: Must be bit-castable to the given floating point
50 size.
51 * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
52 row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
53 this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
54
55If a single element buffer is passed (or for i1, a single byte with value 0
56or 255), then a splat will be created.
57
58Args:
59 array: The array or buffer to convert.
60 signless: If inferring an appropriate MLIR type, use signless types for
61 integers (defaults True).
62 type: Skips inference of the MLIR element type and uses this instead. The
63 storage size must be consistent with the actual contents of the buffer.
64 shape: Overrides the shape of the buffer when constructing the MLIR
65 shaped type. This is needed when the physical and logical shape differ (as
66 for i1).
67 context: Explicit context, if not from context manager.
68
69Returns:
70 DenseElementsAttr on success.
71
72Raises:
73 ValueError: If the type of the buffer or array cannot be matched to an MLIR
74 type or if the buffer does not meet expectations.
75)";
76
// Python docstring for building a DenseElementsAttr from a list of attributes.
// NOTE(review): the "Raises" section refers to `shaped_type`, but the "Args"
// section documents this parameter as `type`; consider aligning the wording.
78 R"(Gets a DenseElementsAttr from a Python list of attributes.
79
80Note that it can be expensive to construct attributes individually.
81For a large number of elements, consider using a Python buffer or array instead.
82
83Args:
84 attrs: A list of attributes.
85 type: The desired shape and type of the resulting DenseElementsAttr.
86 If not provided, the element type is determined based on the type
87 of the 0th attribute and the shape is `[len(attrs)]`.
88 context: Explicit context, if not from context manager.
89
90Returns:
91 DenseElementsAttr on success.
92
93Raises:
94 ValueError: If the type of the attributes does not match the type
95 specified by `shaped_type`.
96)";
97
// Python docstring for building a DenseResourceElementsAttr from a Python
// buffer. The text documents the lifetime of the backing buffer: it is
// retained for as long as the resource blob lives.
99 R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
100
101This function does minimal validation or massaging of the data, and it is
102up to the caller to ensure that the buffer meets the characteristics
103implied by the shape.
104
105The backing buffer and any user objects will be retained for the lifetime
106of the resource blob. This is typically bounded to the context but the
107resource can have a shorter lifespan depending on how it is used in
108subsequent processing.
109
110Args:
111 buffer: The array or buffer to convert.
112 name: Name to provide to the resource (may be changed upon collision).
113 type: The explicit ShapedType to construct the attribute with.
114 context: Explicit context, if not from context manager.
115
116Returns:
117 DenseResourceElementsAttr on success.
118
119Raises:
120 ValueError: If the type of the buffer or array cannot be matched to an MLIR
121 type or if the buffer does not meet expectations.
122)";
123
124namespace {
125
/// Describes a Python buffer: data pointer, element size, format string,
/// dimensions, shape/strides and read-only flag. It can optionally own the
/// Py_buffer view it was built from.
126struct nb_buffer_info {
127 void *ptr = nullptr;
128 ssize_t itemsize = 0;
129 ssize_t size = 0;
130 const char *format = nullptr;
131 ssize_t ndim = 0;
134 bool readonly = false;
135
 /// Builds the description from explicit fields. `size` is computed as the
 /// product of the first `ndim` entries of `shape`, i.e. an element count,
 /// not a byte count.
136 nb_buffer_info(
137 void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
139 bool readonly = false,
140 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
141 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
142 : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
143 shape(std::move(shape_in)), strides(std::move(strides_in)),
144 readonly(readonly), owned_view(std::move(owned_view_in)) {
145 size = 1;
146 for (ssize_t i = 0; i < ndim; ++i) {
147 size *= shape[i];
148 }
149 }
150
 /// Copies the fields of an already-acquired Py_buffer and takes ownership of
 /// it.
 /// NOTE(review): the deleter installed here only calls PyBuffer_Release. It
 /// never frees the Py_buffer struct itself, so a heap-allocated `view` passed
 /// to this overload is leaked unless the caller frees it separately.
151 explicit nb_buffer_info(Py_buffer *view)
152 : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
153 {view->shape, view->shape + view->ndim},
154 // TODO(phawkins): check for null strides
 // (Per the buffer protocol, strides is NULL when the consumer
 // did not request PyBUF_STRIDES.)
155 {view->strides, view->strides + view->ndim},
156 view->readonly != 0,
157 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
158 view, PyBuffer_Release)) {}
159
 // Move-only, so that an owned view is released exactly once.
160 nb_buffer_info(const nb_buffer_info &) = delete;
161 nb_buffer_info(nb_buffer_info &&) = default;
162 nb_buffer_info &operator=(const nb_buffer_info &) = delete;
163 nb_buffer_info &operator=(nb_buffer_info &&) = default;
164
165private:
 /// The owned view, handed to its deleter on destruction. It is null (and
 /// nothing is released) when this object does not own a view.
166 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
167};
168
169class nb_buffer : public nb::object {
170 NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
171
172 nb_buffer_info request() const {
173 int flags = PyBUF_STRIDES | PyBUF_FORMAT;
174 auto *view = new Py_buffer();
175 if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
176 delete view;
177 throw nb::python_error();
178 }
179 return nb_buffer_info(view);
180 }
181};
182
/// Compile-time map from a C++ element type to its Python buffer-protocol
/// (`struct` module) format code. The primary template is empty, so only the
/// element types specialized below provide a `format()` function.
template <typename T>
struct nb_format_descriptor {};

// Boolean.
template <> struct nb_format_descriptor<bool> {
  static const char *format() { return "?"; }
};

// Signed / unsigned 8-bit integers.
template <> struct nb_format_descriptor<int8_t> {
  static const char *format() { return "b"; }
};
template <> struct nb_format_descriptor<uint8_t> {
  static const char *format() { return "B"; }
};

// Signed / unsigned 16-bit integers.
template <> struct nb_format_descriptor<int16_t> {
  static const char *format() { return "h"; }
};
template <> struct nb_format_descriptor<uint16_t> {
  static const char *format() { return "H"; }
};

// Signed / unsigned 32-bit integers.
template <> struct nb_format_descriptor<int32_t> {
  static const char *format() { return "i"; }
};
template <> struct nb_format_descriptor<uint32_t> {
  static const char *format() { return "I"; }
};

// Signed / unsigned 64-bit integers ("long long" codes).
template <> struct nb_format_descriptor<int64_t> {
  static const char *format() { return "q"; }
};
template <> struct nb_format_descriptor<uint64_t> {
  static const char *format() { return "Q"; }
};

// IEEE single / double precision floats.
template <> struct nb_format_descriptor<float> {
  static const char *format() { return "f"; }
};
template <> struct nb_format_descriptor<double> {
  static const char *format() { return "d"; }
};
230
231static MlirStringRef toMlirStringRef(const std::string &s) {
232 return mlirStringRefCreate(s.data(), s.size());
233}
234
235static MlirStringRef toMlirStringRef(const nb::bytes &s) {
236 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
237}
238
/// Python binding for the builtin AffineMapAttr, an attribute wrapping an
/// AffineMap.
239class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
240public:
241 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
242 static constexpr const char *pyClassName = "AffineMapAttr";
244 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
246
 /// Registers the static `get(affine_map)` constructor and the read-only
 /// `value` property, which returns the wrapped map as a PyAffineMap.
247 static void bindDerived(ClassTy &c) {
248 c.def_static(
249 "get",
250 [](PyAffineMap &affineMap) {
251 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
252 return PyAffineMapAttribute(affineMap.getContext(), attr);
253 },
254 nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255 c.def_prop_ro(
256 "value",
257 [](PyAffineMapAttribute &self) {
258 return PyAffineMap(self.getContext(),
260 },
261 "Returns the value of the AffineMap attribute");
262 }
263};
264
/// Python binding for the builtin IntegerSetAttr, an attribute wrapping an
/// IntegerSet.
265class PyIntegerSetAttribute
266 : public PyConcreteAttribute<PyIntegerSetAttribute> {
267public:
268 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
269 static constexpr const char *pyClassName = "IntegerSetAttr";
271 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
273
 /// Registers the static `get(integer_set)` constructor. The attribute is
 /// created in the integer set's own context.
274 static void bindDerived(ClassTy &c) {
275 c.def_static(
276 "get",
277 [](PyIntegerSet &integerSet) {
278 MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
279 return PyIntegerSetAttribute(integerSet.getContext(), attr);
280 },
281 nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
282 }
283};
284
285template <typename T>
286static T pyTryCast(nb::handle object) {
287 try {
288 return nb::cast<T>(object);
289 } catch (nb::cast_error &err) {
290 std::string msg = std::string("Invalid attribute when attempting to "
291 "create an ArrayAttribute (") +
292 err.what() + ")";
293 throw std::runtime_error(msg.c_str());
294 } catch (std::runtime_error &err) {
295 std::string msg = std::string("Invalid attribute (None?) when attempting "
296 "to create an ArrayAttribute (") +
297 err.what() + ")";
298 throw std::runtime_error(msg.c_str());
299 }
300}
301
302/// A python-wrapped dense array attribute with an element type and a derived
303/// implementation class.
///
/// `DerivedT` is the concrete subclass, which passes itself as this parameter.
/// It must provide the static hooks used below: `getElement`, `getAttribute`
/// (the C constructor) and `pyIteratorName`.
304template <typename EltTy, typename DerivedT>
305class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
306public:
308
309 /// Iterator over the elements of a dense array.
310 class PyDenseArrayIterator {
311 public:
312 PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
313
314 /// Return a copy of the iterator.
315 PyDenseArrayIterator dunderIter() { return *this; }
316
317 /// Return the next element.
 /// Raises Python StopIteration (nb::stop_iteration) once exhausted.
318 EltTy dunderNext() {
319 // Throw if the index has reached the end.
321 throw nb::stop_iteration();
322 return DerivedT::getElement(attr.get(), nextIndex++);
323 }
324
325 /// Bind the iterator class.
326 static void bind(nb::module_ &m) {
327 nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
328 .def("__iter__", &PyDenseArrayIterator::dunderIter)
329 .def("__next__", &PyDenseArrayIterator::dunderNext);
330 }
331
332 private:
333 /// The referenced dense array attribute.
334 PyAttribute attr;
335 /// The next index to read.
336 int nextIndex = 0;
337 };
338
339 /// Get the element at the given index.
 /// No bounds check is done here; callers must validate `i`.
340 EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
341
342 /// Bind the attribute class.
343 static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
344 // Bind the constructor.
 // Bool arrays accept any Python sequence. Each item is converted using
 // Python truthiness (PyObject_IsTrue), and Python errors are propagated.
 // Other element types rely on nanobind's conversion to
 // std::vector<EltTy>.
345 if constexpr (std::is_same_v<EltTy, bool>) {
346 c.def_static(
347 "get",
348 [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
349 std::vector<bool> values;
350 for (nb::handle py_value : py_values) {
351 int is_true = PyObject_IsTrue(py_value.ptr());
352 if (is_true < 0) {
353 throw nb::python_error();
354 }
355 values.push_back(is_true);
356 }
357 return getAttribute(values, ctx->getRef());
358 },
359 nb::arg("values"), nb::arg("context") = nb::none(),
360 "Gets a uniqued dense array attribute");
361 } else {
362 c.def_static(
363 "get",
364 [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
365 return getAttribute(values, ctx->getRef());
366 },
367 nb::arg("values"), nb::arg("context") = nb::none(),
368 "Gets a uniqued dense array attribute");
369 }
370 // Bind the array methods.
 // NOTE(review): only the upper bound is checked in __getitem__. A negative
 // `i` reaches getItem/getElement unchanged; consider rejecting it (or
 // normalizing it Python-style).
371 c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
372 if (i >= mlirDenseArrayGetNumElements(arr))
373 throw nb::index_error("DenseArray index out of range");
374 return arr.getItem(i);
375 });
376 c.def("__len__", [](const DerivedT &arr) {
378 });
379 c.def("__iter__",
380 [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
 // `arr + [extras...]` returns a new attribute holding this array's elements
 // followed by `extras`. Each extra is converted to EltTy via pyTryCast.
381 c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
382 std::vector<EltTy> values;
383 intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
384 values.reserve(numOldElements + nb::len(extras));
385 for (intptr_t i = 0; i < numOldElements; ++i)
386 values.push_back(arr.getItem(i));
387 for (nb::handle attr : extras)
388 values.push_back(pyTryCast<EltTy>(attr));
389 return getAttribute(values, arr.getContext());
390 });
391 }
392
393private:
 // Creates the attribute through DerivedT's C constructor. For bool, the
 // values are first widened into a std::vector<int>, because
 // std::vector<bool> is bit-packed and has no data(). That int buffer is what
 // the C constructor receives.
394 static DerivedT getAttribute(const std::vector<EltTy> &values,
395 PyMlirContextRef ctx) {
396 if constexpr (std::is_same_v<EltTy, bool>) {
397 std::vector<int> intValues(values.begin(), values.end());
398 MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
399 intValues.data());
400 return DerivedT(ctx, attr);
401 } else {
402 MlirAttribute attr =
403 DerivedT::getAttribute(ctx->get(), values.size(), values.data());
404 return DerivedT(ctx, attr);
405 }
406 }
407};
408
409/// Instantiate the python dense array classes.
410struct PyDenseBoolArrayAttribute
411 : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
412 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
413 static constexpr auto getAttribute = mlirDenseBoolArrayGet;
414 static constexpr auto getElement = mlirDenseBoolArrayGetElement;
415 static constexpr const char *pyClassName = "DenseBoolArrayAttr";
416 static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
417 using PyDenseArrayAttribute::PyDenseArrayAttribute;
418};
419struct PyDenseI8ArrayAttribute
420 : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
421 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
422 static constexpr auto getAttribute = mlirDenseI8ArrayGet;
423 static constexpr auto getElement = mlirDenseI8ArrayGetElement;
424 static constexpr const char *pyClassName = "DenseI8ArrayAttr";
425 static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
426 using PyDenseArrayAttribute::PyDenseArrayAttribute;
427};
428struct PyDenseI16ArrayAttribute
429 : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
430 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
431 static constexpr auto getAttribute = mlirDenseI16ArrayGet;
432 static constexpr auto getElement = mlirDenseI16ArrayGetElement;
433 static constexpr const char *pyClassName = "DenseI16ArrayAttr";
434 static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
435 using PyDenseArrayAttribute::PyDenseArrayAttribute;
436};
437struct PyDenseI32ArrayAttribute
438 : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
439 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
440 static constexpr auto getAttribute = mlirDenseI32ArrayGet;
441 static constexpr auto getElement = mlirDenseI32ArrayGetElement;
442 static constexpr const char *pyClassName = "DenseI32ArrayAttr";
443 static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
444 using PyDenseArrayAttribute::PyDenseArrayAttribute;
445};
446struct PyDenseI64ArrayAttribute
447 : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
448 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
449 static constexpr auto getAttribute = mlirDenseI64ArrayGet;
450 static constexpr auto getElement = mlirDenseI64ArrayGetElement;
451 static constexpr const char *pyClassName = "DenseI64ArrayAttr";
452 static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
453 using PyDenseArrayAttribute::PyDenseArrayAttribute;
454};
455struct PyDenseF32ArrayAttribute
456 : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
457 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
458 static constexpr auto getAttribute = mlirDenseF32ArrayGet;
459 static constexpr auto getElement = mlirDenseF32ArrayGetElement;
460 static constexpr const char *pyClassName = "DenseF32ArrayAttr";
461 static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
462 using PyDenseArrayAttribute::PyDenseArrayAttribute;
463};
464struct PyDenseF64ArrayAttribute
465 : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
466 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
467 static constexpr auto getAttribute = mlirDenseF64ArrayGet;
468 static constexpr auto getElement = mlirDenseF64ArrayGetElement;
469 static constexpr const char *pyClassName = "DenseF64ArrayAttr";
470 static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
471 using PyDenseArrayAttribute::PyDenseArrayAttribute;
472};
473
/// Python binding for the builtin ArrayAttr. It is a list-like sequence of
/// attributes that supports indexing, len(), iteration and `+` with a Python
/// list.
474class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
475public:
476 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
477 static constexpr const char *pyClassName = "ArrayAttr";
479 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
481
 /// Python iterator over an ArrayAttr. Elements are returned through
 /// maybeDownCast(), so they surface as their most specific attribute class.
482 class PyArrayAttributeIterator {
483 public:
484 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
485
 /// `__iter__` returns the iterator itself.
486 PyArrayAttributeIterator &dunderIter() { return *this; }
487
488 nb::typed<nb::object, PyAttribute> dunderNext() {
489 // TODO: Throw is an inefficient way to stop iteration.
491 throw nb::stop_iteration();
492 return PyAttribute(this->attr.getContext(),
494 .maybeDownCast();
495 }
496
497 static void bind(nb::module_ &m) {
498 nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
499 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
500 .def("__next__", &PyArrayAttributeIterator::dunderNext);
501 }
502
503 private:
504 PyAttribute attr;
505 int nextIndex = 0;
506 };
507
 /// Returns the raw element at `i`, without bounds checking.
508 MlirAttribute getItem(intptr_t i) {
509 return mlirArrayAttrGetElement(*this, i);
510 }
511
512 static void bindDerived(ClassTy &c) {
 // Every list entry must convert to PyAttribute. pyTryCast reports failures
 // as std::runtime_error.
513 c.def_static(
514 "get",
515 [](const nb::list &attributes, DefaultingPyMlirContext context) {
516 SmallVector<MlirAttribute> mlirAttributes;
517 mlirAttributes.reserve(nb::len(attributes));
518 for (auto attribute : attributes) {
519 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
520 }
521 MlirAttribute attr = mlirArrayAttrGet(
522 context->get(), mlirAttributes.size(), mlirAttributes.data());
523 return PyArrayAttribute(context->getRef(), attr);
524 },
525 nb::arg("attributes"), nb::arg("context") = nb::none(),
526 "Gets a uniqued Array attribute");
 // NOTE(review): __getitem__ only checks the upper bound. A negative `i` is
 // passed to getItem unchanged; consider rejecting or normalizing it.
527 c.def(
528 "__getitem__",
529 [](PyArrayAttribute &arr,
530 intptr_t i) -> nb::typed<nb::object, PyAttribute> {
531 if (i >= mlirArrayAttrGetNumElements(arr))
532 throw nb::index_error("ArrayAttribute index out of range");
533 return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
534 })
535 .def("__len__",
536 [](const PyArrayAttribute &arr) {
537 return mlirArrayAttrGetNumElements(arr);
538 })
539 .def("__iter__", [](const PyArrayAttribute &arr) {
540 return PyArrayAttributeIterator(arr);
541 });
 // `arr + [attrs...]` returns a new ArrayAttr with `extras` appended after
 // the existing elements.
542 c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
543 std::vector<MlirAttribute> attributes;
544 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
545 attributes.reserve(numOldElements + nb::len(extras));
546 for (intptr_t i = 0; i < numOldElements; ++i)
547 attributes.push_back(arr.getItem(i));
548 for (nb::handle attr : extras)
549 attributes.push_back(pyTryCast<PyAttribute>(attr));
550 MlirAttribute arrayAttr = mlirArrayAttrGet(
551 arr.getContext()->get(), attributes.size(), attributes.data());
552 return PyArrayAttribute(arr.getContext(), arrayAttr);
553 });
554 }
555};
556
557/// Float Point Attribute subclass - FloatAttr.
/// Exposes checked and unchecked constructors, f32/f64 shortcuts, and the
/// stored value as a Python float.
558class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
559public:
560 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
561 static constexpr const char *pyClassName = "FloatAttr";
563 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
565
566 static void bindDerived(ClassTy &c) {
 // Checked construction at `loc`. Diagnostics emitted while building the
 // attribute are captured and re-raised as MLIRError if the result is null.
567 c.def_static(
568 "get",
569 [](PyType &type, double value, DefaultingPyLocation loc) {
570 PyMlirContext::ErrorCapture errors(loc->getContext());
571 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
572 if (mlirAttributeIsNull(attr))
573 throw MLIRError("Invalid attribute", errors.take());
574 return PyFloatAttribute(type.getContext(), attr);
575 },
576 nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
577 "Gets an uniqued float point attribute associated to a type");
 // NOTE(review): despite its name, get_unchecked still raises MLIRError
 // (with any captured diagnostics) if the constructor returns a null
 // attribute.
578 c.def_static(
579 "get_unchecked",
580 [](PyType &type, double value, DefaultingPyMlirContext context) {
581 PyMlirContext::ErrorCapture errors(context->getRef());
582 MlirAttribute attr =
583 mlirFloatAttrDoubleGet(context.get()->get(), type, value);
584 if (mlirAttributeIsNull(attr))
585 throw MLIRError("Invalid attribute", errors.take());
586 return PyFloatAttribute(type.getContext(), attr);
587 },
588 nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
589 "Gets an uniqued float point attribute associated to a type");
 // Shortcuts that build the f32 / f64 type in the given context.
590 c.def_static(
591 "get_f32",
592 [](double value, DefaultingPyMlirContext context) {
593 MlirAttribute attr = mlirFloatAttrDoubleGet(
594 context->get(), mlirF32TypeGet(context->get()), value);
595 return PyFloatAttribute(context->getRef(), attr);
596 },
597 nb::arg("value"), nb::arg("context") = nb::none(),
598 "Gets an uniqued float point attribute associated to a f32 type");
599 c.def_static(
600 "get_f64",
601 [](double value, DefaultingPyMlirContext context) {
602 MlirAttribute attr = mlirFloatAttrDoubleGet(
603 context->get(), mlirF64TypeGet(context->get()), value);
604 return PyFloatAttribute(context->getRef(), attr);
605 },
606 nb::arg("value"), nb::arg("context") = nb::none(),
607 "Gets an uniqued float point attribute associated to a f64 type");
608 c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
609 "Returns the value of the float attribute");
610 c.def("__float__", mlirFloatAttrGetValueDouble,
611 "Converts the value of the float attribute to a Python float");
612 }
613};
614
615/// Integer Attribute subclass - IntegerAttr.
/// Exposes `get(type, value)`, the integer value (as `value` and `__int__`)
/// and a class-level `static_typeid` property.
616class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
617public:
618 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
619 static constexpr const char *pyClassName = "IntegerAttr";
621
622 static void bindDerived(ClassTy &c) {
623 c.def_static(
624 "get",
625 [](PyType &type, int64_t value) {
626 MlirAttribute attr = mlirIntegerAttrGet(type, value);
627 return PyIntegerAttribute(type.getContext(), attr);
628 },
629 nb::arg("type"), nb::arg("value"),
630 "Gets an uniqued integer attribute associated to a type");
631 c.def_prop_ro("value", toPyInt,
632 "Returns the value of the integer attribute");
633 c.def("__int__", toPyInt,
634 "Converts the value of the integer attribute to a Python int");
 // Class-level, read-only `static_typeid` property; its signature declares
 // a TypeID result.
635 c.def_prop_ro_static(
636 "static_typeid",
637 [](nb::object & /*class*/) {
639 },
640 nanobind::sig("def static_typeid(/) -> TypeID"));
641 }
642
643private:
 /// Converts the attribute's value to int64_t for Python. The accessor is
 /// chosen from the attribute's type:
 ///  - the first branch uses mlirIntegerAttrGetValueInt (presumably for
 ///    signless/index types, TODO confirm);
 ///  - signed integer types use mlirIntegerAttrGetValueSInt;
 ///  - all other types use mlirIntegerAttrGetValueUInt.
 /// NOTE(review): the UInt accessor's result (uint64_t in the C API) is
 /// narrowed to the int64_t return type, so unsigned values >= 2^63 come back
 /// negative.
644 static int64_t toPyInt(PyIntegerAttribute &self) {
645 MlirType type = mlirAttributeGetType(self);
647 return mlirIntegerAttrGetValueInt(self);
648 if (mlirIntegerTypeIsSigned(type))
649 return mlirIntegerAttrGetValueSInt(self);
650 return mlirIntegerAttrGetValueUInt(self);
651 }
652};
653
654/// Bool Attribute subclass - BoolAttr.
655class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
656public:
657 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
658 static constexpr const char *pyClassName = "BoolAttr";
660
 /// Registers `get(value, context=None)`, the read-only `value` property and
 /// `__bool__`. The latter two both read mlirBoolAttrGetValue.
661 static void bindDerived(ClassTy &c) {
662 c.def_static(
663 "get",
664 [](bool value, DefaultingPyMlirContext context) {
665 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
666 return PyBoolAttribute(context->getRef(), attr);
667 },
668 nb::arg("value"), nb::arg("context") = nb::none(),
669 "Gets an uniqued bool attribute");
670 c.def_prop_ro("value", mlirBoolAttrGetValue,
671 "Returns the value of the bool attribute");
672 c.def("__bool__", mlirBoolAttrGetValue,
673 "Converts the value of the bool attribute to a Python bool");
674 }
675};
676
/// Python binding for the builtin SymbolRefAttr: a root symbol optionally
/// followed by nested references.
677class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
678public:
679 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
680 static constexpr const char *pyClassName = "SymbolRefAttr";
682
 /// Builds a SymbolRefAttr whose root is `symbols[0]`. The remaining names
 /// become nested FlatSymbolRefAttr references, in order. Throws
 /// std::runtime_error when `symbols` is empty.
683 static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
684 PyMlirContext &context) {
685 if (symbols.empty())
686 throw std::runtime_error("SymbolRefAttr must be composed of at least "
687 "one symbol.");
688 MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
689 SmallVector<MlirAttribute, 3> referenceAttrs;
690 for (size_t i = 1; i < symbols.size(); ++i) {
691 referenceAttrs.push_back(
692 mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
693 }
694 return PySymbolRefAttribute(context.getRef(),
695 mlirSymbolRefAttrGet(context.get(), rootSymbol,
696 referenceAttrs.size(),
697 referenceAttrs.data()));
698 }
699
700 static void bindDerived(ClassTy &c) {
701 c.def_static(
702 "get",
703 [](const std::vector<std::string> &symbols,
704 DefaultingPyMlirContext context) {
705 return PySymbolRefAttribute::fromList(symbols, context.resolve());
706 },
707 nb::arg("symbols"), nb::arg("context") = nb::none(),
708 "Gets a uniqued SymbolRef attribute from a list of symbol names");
 // Returns the symbol path as a list[str]. One entry per nested reference
 // is appended, in order, after the initial entry.
709 c.def_prop_ro(
710 "value",
711 [](PySymbolRefAttribute &self) {
712 std::vector<std::string> symbols = {
714 for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
715 ++i)
716 symbols.push_back(
719 .str());
720 return symbols;
721 },
722 "Returns the value of the SymbolRef attribute as a list[str]");
723 }
724};
725
/// Python binding for the builtin FlatSymbolRefAttr, which names a single
/// symbol.
726class PyFlatSymbolRefAttribute
727 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
728public:
729 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
730 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
732
 /// Registers `get(value, context=None)` and the read-only `value` property,
 /// which returns the symbol name as a Python str.
733 static void bindDerived(ClassTy &c) {
734 c.def_static(
735 "get",
736 [](const std::string &value, DefaultingPyMlirContext context) {
737 MlirAttribute attr =
738 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
739 return PyFlatSymbolRefAttribute(context->getRef(), attr);
740 },
741 nb::arg("value"), nb::arg("context") = nb::none(),
742 "Gets a uniqued FlatSymbolRef attribute");
743 c.def_prop_ro(
744 "value",
745 [](PyFlatSymbolRefAttribute &self) {
747 return nb::str(stringRef.data, stringRef.length);
748 },
749 "Returns the value of the FlatSymbolRef attribute as a string");
750 }
751};
752
/// Python binding for the builtin OpaqueAttr: raw data tagged with a dialect
/// namespace and a type.
753class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
754public:
755 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
756 static constexpr const char *pyClassName = "OpaqueAttr";
758 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
760
761 static void bindDerived(ClassTy &c) {
 // NOTE(review): `bufferInfo.size` is an element count (the product of the
 // shape, as computed by nb_buffer_info), not a byte count, yet it is
 // passed as the data length. nb_buffer::request() also does not ask for a
 // contiguous view. As a result, the data is only read correctly for
 // contiguous buffers with an itemsize of 1 (e.g. `bytes`). Consider using
 // itemsize * size and requesting a contiguous buffer.
762 c.def_static(
763 "get",
764 [](const std::string &dialectNamespace, const nb_buffer &buffer,
765 PyType &type, DefaultingPyMlirContext context) {
766 const nb_buffer_info bufferInfo = buffer.request();
767 intptr_t bufferSize = bufferInfo.size;
768 MlirAttribute attr = mlirOpaqueAttrGet(
769 context->get(), toMlirStringRef(dialectNamespace), bufferSize,
770 static_cast<char *>(bufferInfo.ptr), type);
771 return PyOpaqueAttribute(context->getRef(), attr);
772 },
773 nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
774 nb::arg("context") = nb::none(),
775 // clang-format off
776 nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
777 // clang-format on
778 "Gets an Opaque attribute.");
779 c.def_prop_ro(
780 "dialect_namespace",
781 [](PyOpaqueAttribute &self) {
783 return nb::str(stringRef.data, stringRef.length);
784 },
785 "Returns the dialect namespace for the Opaque attribute as a string");
 // Returns a copy of the raw data as Python `bytes`.
786 c.def_prop_ro(
787 "data",
788 [](PyOpaqueAttribute &self) {
789 MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
790 return nb::bytes(stringRef.data, stringRef.length);
791 },
792 "Returns the data for the Opaqued attributes as `bytes`");
793 }
794};
795
796// TODO: Support construction of string elements.
797class PyDenseElementsAttribute
798 : public PyConcreteAttribute<PyDenseElementsAttribute> {
799public:
800 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
801 static constexpr const char *pyClassName = "DenseElementsAttr";
803
/// Builds a DenseElementsAttr from a non-empty Python list of attributes.
///
/// With `explicitType`, that type must be a statically shaped ShapedType.
/// Without it, the result is a 1-D ranked tensor of length len(attributes)
/// whose element type is the type of attributes[0]. Every attribute's type
/// must equal the shaped type's element type. Raises ValueError
/// (nb::value_error) for an empty list, an unsuitable explicit type, or a
/// type mismatch.
/// NOTE(review): the number of attributes is not compared with the element
/// count of `explicitType` before mlirDenseElementsAttrGet is called;
/// consider validating it here.
804 static PyDenseElementsAttribute
805 getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
806 DefaultingPyMlirContext contextWrapper) {
807 const size_t numAttributes = nb::len(attributes);
808 if (numAttributes == 0)
809 throw nb::value_error("Attributes list must be non-empty.");
810
811 MlirType shapedType;
812 if (explicitType) {
813 if ((!mlirTypeIsAShaped(*explicitType) ||
814 !mlirShapedTypeHasStaticShape(*explicitType))) {
815
816 std::string message;
817 llvm::raw_string_ostream os(message);
818 os << "Expected a static ShapedType for the shaped_type parameter: "
819 << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
820 throw nb::value_error(message.c_str());
821 }
822 shapedType = *explicitType;
823 } else {
824 SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
825 shapedType = mlirRankedTensorTypeGet(
826 shape.size(), shape.data(),
827 mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
829 }
830
 // Collect the attributes, checking each element type against the shaped
 // type as it is collected.
831 SmallVector<MlirAttribute> mlirAttributes;
832 mlirAttributes.reserve(numAttributes);
833 for (const nb::handle &attribute : attributes) {
834 MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
835 MlirType attrType = mlirAttributeGetType(mlirAttribute);
836 mlirAttributes.push_back(mlirAttribute);
837
838 if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
839 std::string message;
840 llvm::raw_string_ostream os(message);
841 os << "All attributes must be of the same type and match "
842 << "the type parameter: expected="
843 << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
844 << ", but got="
845 << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
846 throw nb::value_error(message.c_str());
847 }
848 }
849
850 MlirAttribute elements = mlirDenseElementsAttrGet(
851 shapedType, mlirAttributes.size(), mlirAttributes.data());
852
853 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
854 }
855
856 static PyDenseElementsAttribute
857 getFromBuffer(const nb_buffer &array, bool signless,
858 const std::optional<PyType> &explicitType,
859 std::optional<std::vector<int64_t>> explicitShape,
860 DefaultingPyMlirContext contextWrapper) {
861 // Request a contiguous view. In exotic cases, this will cause a copy.
862 int flags = PyBUF_ND;
863 if (!explicitType) {
864 flags |= PyBUF_FORMAT;
865 }
866 Py_buffer view;
867 if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
868 throw nb::python_error();
869 }
870 auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
871
872 MlirContext context = contextWrapper->get();
873 MlirAttribute attr = getAttributeFromBuffer(
874 view, signless, explicitType, std::move(explicitShape), context);
875 if (mlirAttributeIsNull(attr)) {
876 throw std::invalid_argument(
877 "DenseElementsAttr could not be constructed from the given buffer. "
878 "This may mean that the Python buffer layout does not match that "
879 "MLIR expected layout and is a bug.");
880 }
881 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
882 }
883
/// Creates a DenseElementsAttr of `shapedType` in which every element equals
/// `elementAttr`.
///
/// Raises ValueError (nb::value_error) when `elementAttr` is neither an
/// IntegerAttr nor a FloatAttr, when `shapedType` is not a statically shaped
/// ShapedType, or when the shaped type's element type differs from the
/// attribute's type.
static PyDenseElementsAttribute getSplat(const PyType &shapedType,
                                         PyAttribute &elementAttr) {
  auto contextWrapper =
  // Only integer and float scalars can be splatted.
  if (!mlirAttributeIsAInteger(elementAttr) &&
      !mlirAttributeIsAFloat(elementAttr)) {
    std::string message = "Illegal element type for DenseElementsAttr: ";
    message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
    throw nb::value_error(message.c_str());
  }
  // The splat must describe a fixed number of elements.
  if (!mlirTypeIsAShaped(shapedType) ||
      !mlirShapedTypeHasStaticShape(shapedType)) {
    std::string message =
        "Expected a static ShapedType for the shaped_type parameter: ";
    message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
    throw nb::value_error(message.c_str());
  }
  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
  MlirType attrType = mlirAttributeGetType(elementAttr);
  if (!mlirTypeEqual(shapedElementType, attrType)) {
    std::string message =
        "Shaped element type and attribute type must be equal: shaped=";
    message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
    message.append(", element=");
    message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
    throw nb::value_error(message.c_str());
  }

  MlirAttribute elements =
      mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
916
/// Implements `len(attr)`: the total number of elements in the attribute.
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
918
919 std::unique_ptr<nb_buffer_info> accessBuffer() {
920 MlirType shapedType = mlirAttributeGetType(*this);
921 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
922 std::string format;
923
924 if (mlirTypeIsAF32(elementType)) {
925 // f32
926 return bufferInfo<float>(shapedType);
927 }
928 if (mlirTypeIsAF64(elementType)) {
929 // f64
930 return bufferInfo<double>(shapedType);
931 }
932 if (mlirTypeIsAF16(elementType)) {
933 // f16
934 return bufferInfo<uint16_t>(shapedType, "e");
935 }
936 if (mlirTypeIsAIndex(elementType)) {
937 // Same as IndexType::kInternalStorageBitWidth
938 return bufferInfo<int64_t>(shapedType);
939 }
940 if (mlirTypeIsAInteger(elementType) &&
941 mlirIntegerTypeGetWidth(elementType) == 32) {
942 if (mlirIntegerTypeIsSignless(elementType) ||
943 mlirIntegerTypeIsSigned(elementType)) {
944 // i32
945 return bufferInfo<int32_t>(shapedType);
946 }
947 if (mlirIntegerTypeIsUnsigned(elementType)) {
948 // unsigned i32
949 return bufferInfo<uint32_t>(shapedType);
950 }
951 } else if (mlirTypeIsAInteger(elementType) &&
952 mlirIntegerTypeGetWidth(elementType) == 64) {
953 if (mlirIntegerTypeIsSignless(elementType) ||
954 mlirIntegerTypeIsSigned(elementType)) {
955 // i64
956 return bufferInfo<int64_t>(shapedType);
957 }
958 if (mlirIntegerTypeIsUnsigned(elementType)) {
959 // unsigned i64
960 return bufferInfo<uint64_t>(shapedType);
961 }
962 } else if (mlirTypeIsAInteger(elementType) &&
963 mlirIntegerTypeGetWidth(elementType) == 8) {
964 if (mlirIntegerTypeIsSignless(elementType) ||
965 mlirIntegerTypeIsSigned(elementType)) {
966 // i8
967 return bufferInfo<int8_t>(shapedType);
968 }
969 if (mlirIntegerTypeIsUnsigned(elementType)) {
970 // unsigned i8
971 return bufferInfo<uint8_t>(shapedType);
972 }
973 } else if (mlirTypeIsAInteger(elementType) &&
974 mlirIntegerTypeGetWidth(elementType) == 16) {
975 if (mlirIntegerTypeIsSignless(elementType) ||
976 mlirIntegerTypeIsSigned(elementType)) {
977 // i16
978 return bufferInfo<int16_t>(shapedType);
979 }
980 if (mlirIntegerTypeIsUnsigned(elementType)) {
981 // unsigned i16
982 return bufferInfo<uint16_t>(shapedType);
983 }
984 } else if (mlirTypeIsAInteger(elementType) &&
985 mlirIntegerTypeGetWidth(elementType) == 1) {
986 // i1 / bool
987 // We can not send the buffer directly back to Python, because the i1
988 // values are bitpacked within MLIR. We call numpy's unpackbits function
989 // to convert the bytes.
990 return getBooleanBufferFromBitpackedAttribute();
991 }
992
993 // TODO: Currently crashes the program.
994 // Reported as https://github.com/pybind/pybind11/issues/3336
995 throw std::invalid_argument(
996 "unsupported data type for conversion to Python buffer");
997 }
998
/// Registers the DenseElementsAttr Python API:
///  - `__len__`;
///  - two `get` overloads (from a buffer, and from a list of attributes);
///  - `get_splat`;
///  - the `is_splat` property;
///  - `get_splat_value`.
static void bindDerived(ClassTy &c) {
// Python 3.8 type specs cannot carry buffer slots (see `slots` below), so on
// those versions the buffer hooks are patched onto the type object directly.
#if PY_VERSION_HEX < 0x03090000
  PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
  tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
  tp->tp_as_buffer->bf_releasebuffer =
      PyDenseElementsAttribute::bf_releasebuffer;
#endif
  c.def("__len__", &PyDenseElementsAttribute::dunderLen)
      .def_static(
          "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
          nb::arg("signless") = true, nb::arg("type") = nb::none(),
          nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
          // clang-format off
          nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
          // clang-format on
      .def_static("get", PyDenseElementsAttribute::getFromList,
                  nb::arg("attrs"), nb::arg("type") = nb::none(),
                  nb::arg("context") = nb::none(),
      .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                  nb::arg("shaped_type"), nb::arg("element_attr"),
                  "Gets a DenseElementsAttr where all values are the same")
      .def_prop_ro("is_splat",
                   [](PyDenseElementsAttribute &self) -> bool {
                     return mlirDenseElementsAttrIsSplat(self);
                   })
      .def("get_splat_value",
           [](PyDenseElementsAttribute &self)
               -> nb::typed<nb::object, PyAttribute> {
             throw nb::value_error(
                 "get_splat_value called on a non-splat attribute");
             return PyAttribute(self.getContext(),
                 .maybeDownCast();
           });
}
1037
1038 static PyType_Slot slots[];
1039
1040private:
1041 static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
1042 static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
1043
/// Returns true when `format` starts with a struct-module code for an
/// unsigned integer (B, H, I, L, Q). Only the first character is inspected;
/// an empty format is never an integer format.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "BHILQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
1051
/// Returns true when `format` starts with a struct-module code for a signed
/// integer (b, h, i, l, q). Only the first character is inspected; an empty
/// format is never an integer format.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "bhilq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
1059
/// Computes the shaped type used to build a DenseElementsAttr from a buffer.
///
/// If `*bulkLoadElementType` is already a ShapedType it is returned as-is, and
/// passing `explicitShape` as well throws std::invalid_argument. Otherwise a
/// ranked tensor (no encoding) of that element type is built. Its shape is
/// `explicitShape` when given, else the buffer view's own shape.
///
/// `bulkLoadElementType` is dereferenced unconditionally, so callers must
/// pass an engaged optional. When no explicit shape is given, `view.shape`
/// is read; the caller in this file (getFromBuffer) requests the view with
/// PyBUF_ND so that field is populated.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  if (explicitShape) {
    shape.append(explicitShape->begin(), explicitShape->end());
  } else {
    shape.append(view.shape, view.shape + view.ndim);
  }

  if (mlirTypeIsAShaped(*bulkLoadElementType)) {
    if (explicitShape) {
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    }
    return *bulkLoadElementType;
  }
  MlirAttribute encodingAttr = mlirAttributeGetNull();
  return mlirRankedTensorTypeGet(shape.size(), shape.data(),
                                 *bulkLoadElementType, encodingAttr);
}
1082
1083 static MlirAttribute getAttributeFromBuffer(
1084 Py_buffer &view, bool signless, std::optional<PyType> explicitType,
1085 const std::optional<std::vector<int64_t>> &explicitShape,
1086 MlirContext &context) {
1087 // Detect format codes that are suitable for bulk loading. This includes
1088 // all byte aligned integer and floating point types up to 8 bytes.
1089 // Notably, this excludes exotics types which do not have a direct
1090 // representation in the buffer protocol (i.e. complex, etc).
1091 std::optional<MlirType> bulkLoadElementType;
1092 if (explicitType) {
1093 bulkLoadElementType = *explicitType;
1094 } else {
1095 std::string_view format(view.format);
1096 if (format == "f") {
1097 // f32
1098 assert(view.itemsize == 4 && "mismatched array itemsize");
1099 bulkLoadElementType = mlirF32TypeGet(context);
1100 } else if (format == "d") {
1101 // f64
1102 assert(view.itemsize == 8 && "mismatched array itemsize");
1103 bulkLoadElementType = mlirF64TypeGet(context);
1104 } else if (format == "e") {
1105 // f16
1106 assert(view.itemsize == 2 && "mismatched array itemsize");
1107 bulkLoadElementType = mlirF16TypeGet(context);
1108 } else if (format == "?") {
1109 // i1
1110 // The i1 type needs to be bit-packed, so we will handle it separately
1111 return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
1112 context);
1113 } else if (isSignedIntegerFormat(format)) {
1114 if (view.itemsize == 4) {
1115 // i32
1116 bulkLoadElementType = signless
1117 ? mlirIntegerTypeGet(context, 32)
1118 : mlirIntegerTypeSignedGet(context, 32);
1119 } else if (view.itemsize == 8) {
1120 // i64
1121 bulkLoadElementType = signless
1122 ? mlirIntegerTypeGet(context, 64)
1123 : mlirIntegerTypeSignedGet(context, 64);
1124 } else if (view.itemsize == 1) {
1125 // i8
1126 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1127 : mlirIntegerTypeSignedGet(context, 8);
1128 } else if (view.itemsize == 2) {
1129 // i16
1130 bulkLoadElementType = signless
1131 ? mlirIntegerTypeGet(context, 16)
1132 : mlirIntegerTypeSignedGet(context, 16);
1133 }
1134 } else if (isUnsignedIntegerFormat(format)) {
1135 if (view.itemsize == 4) {
1136 // unsigned i32
1137 bulkLoadElementType = signless
1138 ? mlirIntegerTypeGet(context, 32)
1139 : mlirIntegerTypeUnsignedGet(context, 32);
1140 } else if (view.itemsize == 8) {
1141 // unsigned i64
1142 bulkLoadElementType = signless
1143 ? mlirIntegerTypeGet(context, 64)
1144 : mlirIntegerTypeUnsignedGet(context, 64);
1145 } else if (view.itemsize == 1) {
1146 // i8
1147 bulkLoadElementType = signless
1148 ? mlirIntegerTypeGet(context, 8)
1149 : mlirIntegerTypeUnsignedGet(context, 8);
1150 } else if (view.itemsize == 2) {
1151 // i16
1152 bulkLoadElementType = signless
1153 ? mlirIntegerTypeGet(context, 16)
1154 : mlirIntegerTypeUnsignedGet(context, 16);
1155 }
1156 }
1157 if (!bulkLoadElementType) {
1158 throw std::invalid_argument(
1159 std::string("unimplemented array format conversion from format: ") +
1160 std::string(format));
1161 }
1162 }
1163
1164 MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1165 return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1166 }
1167
// There is a complication for boolean numpy arrays, as numpy represents
// them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
// booleans per byte.
/// Builds an i1 DenseElementsAttr from a buffer of one-byte booleans.
///
/// The bytes are packed with numpy.packbits (little bit order) and the packed
/// bytes are passed to mlirDenseElementsAttrRawBufferGet. The shape comes from
/// `explicitShape` when present, otherwise from the buffer view (via
/// getShapedType). Throws nb::type_error on big-endian hosts.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
    Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
    MlirContext &context) {
  if (llvm::endianness::native != llvm::endianness::little) {
    // Given we have no good way of testing the behavior on big-endian
    // systems we will throw
    throw nb::type_error("Constructing a bit-packed MLIR attribute is "
                         "unsupported on big-endian systems");
  }
  // Expose the view.len input bytes as a flat uint8 numpy array.
  // NOTE(review): no owner is passed, so this presumably borrows view.buf
  // rather than copying; it is only used inside this call — confirm.
  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
      /*data=*/static_cast<uint8_t *>(view.buf),
      /*shape=*/{static_cast<size_t>(view.len)});

  nb::module_ numpy = nb::module_::import_("numpy");
  nb::object packbitsFunc = numpy.attr("packbits");
  // Little bit order matches MLIR's bit-packed i1 layout.
  nb::object packedBooleans =
      packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
  nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();

  MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
                                         std::move(explicitShape), view);
  assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
  // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
  // packedBooleans, hence the MlirAttribute will remain valid even when
  // packedBooleans get reclaimed by the end of the function.
  return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
                                           pythonBuffer.ptr);
}
1199
1200 // This does the opposite transformation of
1201 // `getBitpackedAttributeFromBooleanBuffer`
1202 std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
1203 if (llvm::endianness::native != llvm::endianness::little) {
1204 // Given we have no good way of testing the behavior on big-endian
1205 // systems we will throw
1206 throw nb::type_error("Constructing a numpy array from a MLIR attribute "
1207 "is unsupported on big-endian systems");
1208 }
1209
1210 int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1211 int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1212 uint8_t *bitpackedData = static_cast<uint8_t *>(
1213 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1214 nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
1215 /*data=*/bitpackedData,
1216 /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
1217
1218 nb::module_ numpy = nb::module_::import_("numpy");
1219 nb::object unpackbitsFunc = numpy.attr("unpackbits");
1220 nb::object equalFunc = numpy.attr("equal");
1221 nb::object reshapeFunc = numpy.attr("reshape");
1222 nb::object unpackedBooleans =
1223 unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
1224
1225 // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1226 // We need to:
1227 // 1. Slice away the padded bits
1228 // 2. Make the boolean array have the correct shape
1229 // 3. Convert the array to a boolean array
1230 unpackedBooleans = unpackedBooleans[nb::slice(
1231 nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
1232 unpackedBooleans = equalFunc(unpackedBooleans, 1);
1233
1234 MlirType shapedType = mlirAttributeGetType(*this);
1235 intptr_t rank = mlirShapedTypeGetRank(shapedType);
1236 std::vector<intptr_t> shape(rank);
1237 for (intptr_t i = 0; i < rank; ++i) {
1238 shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
1239 }
1240 unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1241
1242 // Make sure the returned nb::buffer_view claims ownership of the data in
1243 // `pythonBuffer` so it remains valid when Python reads it
1244 nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
1245 return std::make_unique<nb_buffer_info>(pythonBuffer.request());
1246 }
1247
/// Builds a read-only, C-ordered buffer description over the attribute's raw
/// element storage, interpreting each element as a C++ `Type`.
///
/// The buffer format is `explicitFormat` when given (e.g. "e" for f16 stored
/// as uint16_t), else nb_format_descriptor<Type>::format(). Splat attributes
/// get all-zero strides because only one element is stored. The data pointer
/// aliases the attribute's storage; no copy is made.
template <typename Type>
std::unique_ptr<nb_buffer_info>
bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
  intptr_t rank = mlirShapedTypeGetRank(shapedType);
  // Prepare the data for the buffer_info.
  // Buffer is configured for read-only access below.
  Type *data = static_cast<Type *>(
      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
  // Prepare the shape for the buffer_info.
  for (intptr_t i = 0; i < rank; ++i)
    shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
  // Prepare the strides for the buffer_info.
  if (mlirDenseElementsAttrIsSplat(*this)) {
    // Splats are special, only the single value is stored.
    strides.assign(rank, 0);
  } else {
    // Row-major strides: dimension i advances by the product of the sizes of
    // all later dimensions (in bytes); the innermost stride is one element.
    for (intptr_t i = 1; i < rank; ++i) {
      intptr_t strideFactor = 1;
      for (intptr_t j = i; j < rank; ++j)
        strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
      strides.push_back(sizeof(Type) * strideFactor);
    }
    strides.push_back(sizeof(Type));
  }
  const char *format;
  if (explicitFormat) {
    format = explicitFormat;
  } else {
    format = nb_format_descriptor<Type>::format();
  }
  return std::make_unique<nb_buffer_info>(
      data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
      /*readonly=*/true);
}
1284}; // namespace
1285
// Extra type slots for DenseElementsAttr: on Python >= 3.9 these install the
// buffer protocol hooks. On older versions only the {0, nullptr} terminator
// remains and bindDerived patches tp_as_buffer instead.
PyType_Slot PyDenseElementsAttribute::slots[] = {
// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
#if PY_VERSION_HEX >= 0x03090000
    {Py_bf_getbuffer,
     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
    {Py_bf_releasebuffer,
     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
#endif
    {0, nullptr},
};
1296
1297/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
1298 Py_buffer *view,
1299 int flags) {
1300 view->obj = nullptr;
1301 std::unique_ptr<nb_buffer_info> info;
1302 try {
1303 auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
1304 info = attr->accessBuffer();
1305 } catch (nb::python_error &e) {
1306 e.restore();
1307 nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
1308 return -1;
1309 } catch (std::exception &e) {
1310 nb::chain_error(PyExc_BufferError,
1311 "Error converting attribute to buffer: %s", e.what());
1312 return -1;
1313 }
1314 view->obj = obj;
1315 view->ndim = 1;
1316 view->buf = info->ptr;
1317 view->itemsize = info->itemsize;
1318 view->len = info->itemsize;
1319 for (auto s : info->shape) {
1320 view->len *= s;
1321 }
1322 view->readonly = info->readonly;
1323 if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
1324 view->format = const_cast<char *>(info->format);
1325 }
1326 if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
1327 view->ndim = static_cast<int>(info->ndim);
1328 view->strides = info->strides.data();
1329 view->shape = info->shape.data();
1330 }
1331 view->suboffsets = nullptr;
1332 view->internal = info.release();
1333 Py_INCREF(obj);
1334 return 0;
1335}
1336
1337/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
1338 Py_buffer *view) {
1339 delete reinterpret_cast<nb_buffer_info *>(view->internal);
1340}
1341
1342/// Refinement of the PyDenseElementsAttribute for attributes containing
1343/// integer (and boolean) values. Supports element access.
1344class PyDenseIntElementsAttribute
1345 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1346 PyDenseElementsAttribute> {
1347public:
1348 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1349 static constexpr const char *pyClassName = "DenseIntElementsAttr";
1351
1352 /// Returns the element at the given linear position. Asserts if the index
1353 /// is out of range.
1354 nb::int_ dunderGetItem(intptr_t pos) {
1355 if (pos < 0 || pos >= dunderLen()) {
1356 throw nb::index_error("attempt to access out of bounds element");
1357 }
1358
1359 MlirType type = mlirAttributeGetType(*this);
1360 type = mlirShapedTypeGetElementType(type);
1361 // Index type can also appear as a DenseIntElementsAttr and therefore can be
1362 // casted to integer.
1363 assert(mlirTypeIsAInteger(type) ||
1364 mlirTypeIsAIndex(type) && "expected integer/index element type in "
1365 "dense int elements attribute");
1366 // Dispatch element extraction to an appropriate C function based on the
1367 // elemental type of the attribute. nb::int_ is implicitly constructible
1368 // from any C++ integral type and handles bitwidth correctly.
1369 // TODO: consider caching the type properties in the constructor to avoid
1370 // querying them on each element access.
1371 if (mlirTypeIsAIndex(type)) {
1372 return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
1373 }
1374 unsigned width = mlirIntegerTypeGetWidth(type);
1375 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1376 if (isUnsigned) {
1377 if (width == 1) {
1378 return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1379 }
1380 if (width == 8) {
1381 return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
1382 }
1383 if (width == 16) {
1384 return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
1385 }
1386 if (width == 32) {
1387 return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
1388 }
1389 if (width == 64) {
1390 return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
1391 }
1392 } else {
1393 if (width == 1) {
1394 return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1395 }
1396 if (width == 8) {
1397 return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
1398 }
1399 if (width == 16) {
1400 return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
1401 }
1402 if (width == 32) {
1403 return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
1404 }
1405 if (width == 64) {
1406 return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
1407 }
1408 }
1409 throw nb::type_error("Unsupported integer type");
1410 }
1411
/// Registers linear element access (`attr[i]`) on DenseIntElementsAttr.
static void bindDerived(ClassTy &c) {
  c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
}
1415};
1416
1417// Check if the python version is less than 3.13. Py_IsFinalizing is a part
1418// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing.
1419#if PY_VERSION_HEX < 0x030d0000
1420#define Py_IsFinalizing _Py_IsFinalizing
1421#endif
1422
1423class PyDenseResourceElementsAttribute
1424 : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1425public:
1426 static constexpr IsAFunctionTy isaFunction =
1428 static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1430
1431 static PyDenseResourceElementsAttribute
1432 getFromBuffer(const nb_buffer &buffer, const std::string &name,
1433 const PyType &type, std::optional<size_t> alignment,
1434 bool isMutable, DefaultingPyMlirContext contextWrapper) {
1435 if (!mlirTypeIsAShaped(type)) {
1436 throw std::invalid_argument(
1437 "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1438 }
1439
1440 // Do not request any conversions as we must ensure to use caller
1441 // managed memory.
1442 int flags = PyBUF_STRIDES;
1443 std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1444 if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1445 throw nb::python_error();
1446 }
1447
1448 // This scope releaser will only release if we haven't yet transferred
1449 // ownership.
1450 auto freeBuffer = llvm::make_scope_exit([&]() {
1451 if (view)
1452 PyBuffer_Release(view.get());
1453 });
1454
1455 if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1456 throw std::invalid_argument("Contiguous buffer is required.");
1457 }
1458
1459 // Infer alignment to be the stride of one element if not explicit.
1460 size_t inferredAlignment;
1461 if (alignment)
1462 inferredAlignment = *alignment;
1463 else
1464 inferredAlignment = view->strides[view->ndim - 1];
1465
1466 // The userData is a Py_buffer* that the deleter owns.
1467 auto deleter = [](void *userData, const void *data, size_t size,
1468 size_t align) {
1469 if (Py_IsFinalizing())
1470 return;
1471 assert(Py_IsInitialized() && "expected interpreter to be initialized");
1472 Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1473 nb::gil_scoped_acquire gil;
1474 PyBuffer_Release(ownedView);
1475 delete ownedView;
1476 };
1477
1478 size_t rawBufferSize = view->len;
1479 MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1480 type, toMlirStringRef(name), view->buf, rawBufferSize,
1481 inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1482 if (mlirAttributeIsNull(attr)) {
1483 throw std::invalid_argument(
1484 "DenseResourceElementsAttr could not be constructed from the given "
1485 "buffer. "
1486 "This may mean that the Python buffer layout does not match that "
1487 "MLIR expected layout and is a bug.");
1488 }
1489 view.release();
1490 return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1491 }
1492
/// Registers the `get_from_buffer` static constructor on
/// DenseResourceElementsAttr.
static void bindDerived(ClassTy &c) {
  c.def_static("get_from_buffer",
               PyDenseResourceElementsAttribute::getFromBuffer,
               nb::arg("array"), nb::arg("name"), nb::arg("type"),
               nb::arg("alignment") = nb::none(),
               nb::arg("is_mutable") = false, nb::arg("context") = nb::none(),
               // clang-format off
               nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
               // clang-format on
}
1504};
1505
1506class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1507public:
1508 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1509 static constexpr const char *pyClassName = "DictAttr";
1511 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1513
/// Implements `len(attr)`: the number of entries in the dictionary.
intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1515
/// Implements `name in attr`: true when an entry named `name` exists, i.e.
/// when the lookup does not yield a null attribute.
bool dunderContains(const std::string &name) {
  return !mlirAttributeIsNull(
}
1520
/// Registers the DictAttr Python API:
///  - `__contains__` and `__len__`;
///  - `get(value, context)`, which builds a dict attribute from a Python dict
///    mapping str names to Attributes;
///  - `__getitem__` by name (KeyError when absent);
///  - `__getitem__` by integer index (IndexError when out of range), which
///    returns a NamedAttribute.
static void bindDerived(ClassTy &c) {
  c.def("__contains__", &PyDictAttribute::dunderContains);
  c.def("__len__", &PyDictAttribute::dunderLen);
  c.def_static(
      "get",
      [](const nb::dict &attributes, DefaultingPyMlirContext context) {
        SmallVector<MlirNamedAttribute> mlirNamedAttributes;
        mlirNamedAttributes.reserve(attributes.size());
        for (std::pair<nb::handle, nb::handle> it : attributes) {
          auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
          auto name = nb::cast<std::string>(it.first);
          mlirNamedAttributes.push_back(mlirNamedAttributeGet(
                  toMlirStringRef(name)),
              mlirAttr));
        }
        MlirAttribute attr =
            mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
                                  mlirNamedAttributes.data());
        return PyDictAttribute(context->getRef(), attr);
      },
      nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
      "Gets an uniqued dict attribute");
  c.def("__getitem__",
        [](PyDictAttribute &self,
           const std::string &name) -> nb::typed<nb::object, PyAttribute> {
          MlirAttribute attr =
          if (mlirAttributeIsNull(attr))
            throw nb::key_error("attempt to access a non-existent attribute");
          return PyAttribute(self.getContext(), attr).maybeDownCast();
        });
  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
    if (index < 0 || index >= self.dunderLen()) {
      throw nb::index_error("attempt to access out of bounds attribute");
    }
    MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
    // NOTE(review): this builds the name from `.data` alone, i.e. it assumes
    // the identifier's MlirStringRef is NUL-terminated; constructing from
    // (data, length) would not depend on that — confirm.
    return PyNamedAttribute(
        namedAttr.attribute,
        std::string(mlirIdentifierStr(namedAttr.name).data));
  });
}
1563};
1564
1565/// Refinement of PyDenseElementsAttribute for attributes containing
1566/// floating-point values. Supports element access.
1567class PyDenseFPElementsAttribute
1568 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1569 PyDenseElementsAttribute> {
1570public:
1571 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1572 static constexpr const char *pyClassName = "DenseFPElementsAttr";
1574
1575 nb::float_ dunderGetItem(intptr_t pos) {
1576 if (pos < 0 || pos >= dunderLen()) {
1577 throw nb::index_error("attempt to access out of bounds element");
1578 }
1579
1580 MlirType type = mlirAttributeGetType(*this);
1581 type = mlirShapedTypeGetElementType(type);
1582 // Dispatch element extraction to an appropriate C function based on the
1583 // elemental type of the attribute. nb::float_ is implicitly constructible
1584 // from float and double.
1585 // TODO: consider caching the type properties in the constructor to avoid
1586 // querying them on each element access.
1587 if (mlirTypeIsAF32(type)) {
1588 return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
1589 }
1590 if (mlirTypeIsAF64(type)) {
1591 return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
1592 }
1593 throw nb::type_error("Unsupported floating-point type");
1594 }
1595
/// Registers linear element access (`attr[i]`) on DenseFPElementsAttr.
static void bindDerived(ClassTy &c) {
  c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
}
1599};
1600
1601class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1602public:
1603 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1604 static constexpr const char *pyClassName = "TypeAttr";
1606 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1608
1609 static void bindDerived(ClassTy &c) {
1610 c.def_static(
1611 "get",
1612 [](const PyType &value, DefaultingPyMlirContext context) {
1613 MlirAttribute attr = mlirTypeAttrGet(value.get());
1614 return PyTypeAttribute(context->getRef(), attr);
1615 },
1616 nb::arg("value"), nb::arg("context") = nb::none(),
1617 "Gets a uniqued Type attribute");
1618 c.def_prop_ro(
1619 "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
1620 return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
1621 .maybeDownCast();
1622 });
1623 }
1624};
1625
1626/// Unit Attribute subclass. Unit attributes don't have values.
1627class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1628public:
1629 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1630 static constexpr const char *pyClassName = "UnitAttr";
1632 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1634
1635 static void bindDerived(ClassTy &c) {
1636 c.def_static(
1637 "get",
1638 [](DefaultingPyMlirContext context) {
1639 return PyUnitAttribute(context->getRef(),
1640 mlirUnitAttrGet(context->get()));
1641 },
1642 nb::arg("context") = nb::none(), "Create a Unit attribute.");
1643 }
1644};
1645
1646/// Strided layout attribute subclass.
1647class PyStridedLayoutAttribute
1648 : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1649public:
1650 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1651 static constexpr const char *pyClassName = "StridedLayoutAttr";
1653 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1655
/// Registers the StridedLayoutAttr Python API:
///  - `get(offset, strides)`;
///  - `get_fully_dynamic(rank)`, where the offset and all `rank` strides are
///    dynamic;
///  - read-only `offset` and `strides` properties.
static void bindDerived(ClassTy &c) {
  c.def_static(
      "get",
      [](int64_t offset, const std::vector<int64_t> &strides,
         DefaultingPyMlirContext ctx) {
        MlirAttribute attr = mlirStridedLayoutAttrGet(
            ctx->get(), offset, strides.size(), strides.data());
        return PyStridedLayoutAttribute(ctx->getRef(), attr);
      },
      nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
      "Gets a strided layout attribute.");
  c.def_static(
      "get_fully_dynamic",
      [](int64_t rank, DefaultingPyMlirContext ctx) {
        // Every stride, and the offset, use the same `dynamic` marker value.
        std::vector<int64_t> strides(rank);
        llvm::fill(strides, dynamic);
        MlirAttribute attr = mlirStridedLayoutAttrGet(
            ctx->get(), dynamic, strides.size(), strides.data());
        return PyStridedLayoutAttribute(ctx->getRef(), attr);
      },
      nb::arg("rank"), nb::arg("context") = nb::none(),
      "Gets a strided layout attribute with dynamic offset and strides of "
      "a "
      "given rank.");
  // NOTE(review): the docstrings of both `offset` and `strides` below read
  // "Returns the value of the float point attribute", which looks copy-pasted;
  // they return the layout's offset and its per-dimension strides.
  c.def_prop_ro(
      "offset",
      [](PyStridedLayoutAttribute &self) {
        return mlirStridedLayoutAttrGetOffset(self);
      },
      "Returns the value of the float point attribute");
  c.def_prop_ro(
      "strides",
      [](PyStridedLayoutAttribute &self) {
        intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
        std::vector<int64_t> strides(size);
        for (intptr_t i = 0; i < size; i++) {
          strides[i] = mlirStridedLayoutAttrGetStride(self, i);
        }
        return strides;
      },
      "Returns the value of the float point attribute");
}
1699};
1700
1701nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1702 if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1703 return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
1704 if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1705 return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
1706 if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1707 return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
1708 if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1709 return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
1710 if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1711 return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
1712 if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1713 return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
1714 if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1715 return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
1716 std::string msg =
1717 std::string("Can't cast unknown element type DenseArrayAttr (") +
1718 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1719 throw nb::type_error(msg.c_str());
1720}
1721
1722nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1723 if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1724 return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
1725 if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1726 return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
1727 std::string msg =
1728 std::string(
1729 "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1730 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1731 throw nb::type_error(msg.c_str());
1732}
1733
1734nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1735 if (PyBoolAttribute::isaFunction(pyAttribute))
1736 return nb::cast(PyBoolAttribute(pyAttribute));
1737 if (PyIntegerAttribute::isaFunction(pyAttribute))
1738 return nb::cast(PyIntegerAttribute(pyAttribute));
1739 std::string msg = std::string("Can't cast unknown attribute type Attr (") +
1740 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1741 ")";
1742 throw nb::type_error(msg.c_str());
1743}
1744
1745nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1746 if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1747 return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
1748 if (PySymbolRefAttribute::isaFunction(pyAttribute))
1749 return nb::cast(PySymbolRefAttribute(pyAttribute));
1750 std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1751 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1752 ")";
1753 throw nb::type_error(msg.c_str());
1754}
1755
1756} // namespace
1757
1759 c.def_static(
1760 "get",
1761 [](const std::string &value, DefaultingPyMlirContext context) {
1762 MlirAttribute attr =
1763 mlirStringAttrGet(context->get(), toMlirStringRef(value));
1764 return PyStringAttribute(context->getRef(), attr);
1765 },
1766 nb::arg("value"), nb::arg("context") = nb::none(),
1767 "Gets a uniqued string attribute");
1768 c.def_static(
1769 "get",
1770 [](const nb::bytes &value, DefaultingPyMlirContext context) {
1771 MlirAttribute attr =
1772 mlirStringAttrGet(context->get(), toMlirStringRef(value));
1773 return PyStringAttribute(context->getRef(), attr);
1774 },
1775 nb::arg("value"), nb::arg("context") = nb::none(),
1776 "Gets a uniqued string attribute");
1777 c.def_static(
1778 "get_typed",
1779 [](PyType &type, const std::string &value) {
1780 MlirAttribute attr =
1782 return PyStringAttribute(type.getContext(), attr);
1783 },
1784 nb::arg("type"), nb::arg("value"),
1785 "Gets a uniqued string attribute associated to a type");
1786 c.def_prop_ro(
1787 "value",
1788 [](PyStringAttribute &self) {
1789 MlirStringRef stringRef = mlirStringAttrGetValue(self);
1790 return nb::str(stringRef.data, stringRef.length);
1791 },
1792 "Returns the value of the string attribute");
1793 c.def_prop_ro(
1794 "value_bytes",
1795 [](PyStringAttribute &self) {
1796 MlirStringRef stringRef = mlirStringAttrGetValue(self);
1797 return nb::bytes(stringRef.data, stringRef.length);
1798 },
1799 "Returns the value of the string attribute as `bytes`");
1800}
1801
// Binds the Python classes for the builtin attributes in this file onto module
// `m` and wraps the downcasting helpers defined above
// (denseArrayAttributeCaster, denseIntOrFPElementsAttributeCaster,
// symbolRefOrFlatSymbolRefAttributeCaster, integerOrBoolAttributeCaster) as
// Python callables.
// NOTE(review): this extract is missing several source lines, including the
// opening of each caster-registration call. The callables appear to be passed
// to PyGlobals::registerTypeCaster; confirm against the full source.
1802void mlir::python::populateIRAttributes(nb::module_ &m) {
1803 PyAffineMapAttribute::bind(m);
  // Dense array attributes: each element type binds its attribute class and
  // its nested PyDenseArrayIterator.
1804 PyDenseBoolArrayAttribute::bind(m);
1805 PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1806 PyDenseI8ArrayAttribute::bind(m);
1807 PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1808 PyDenseI16ArrayAttribute::bind(m);
1809 PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1810 PyDenseI32ArrayAttribute::bind(m);
1811 PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1812 PyDenseI64ArrayAttribute::bind(m);
1813 PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1814 PyDenseF32ArrayAttribute::bind(m);
1815 PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1816 PyDenseF64ArrayAttribute::bind(m);
1817 PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
  // Wrap denseArrayAttributeCaster as a Python callable so a generic
  // DenseArrayAttr can be downcast by element type.
1820 nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
1821
1822 PyArrayAttribute::bind(m);
1823 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1824 PyBoolAttribute::bind(m);
  // The dense elements base class is bound with an extra PyType_Slot table
  // (PyDenseElementsAttribute::slots).
1825 PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
1826 PyDenseFPElementsAttribute::bind(m);
1827 PyDenseIntElementsAttribute::bind(m);
  // Callable that downcasts to the FP or Int dense elements wrapper.
1830 nb::cast<nb::callable>(
1831 nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
1832 PyDenseResourceElementsAttribute::bind(m);
1833
1834 PyDictAttribute::bind(m);
1835 PySymbolRefAttribute::bind(m);
  // Callable that downcasts to the flat or nested symbol reference wrapper.
1838 nb::cast<nb::callable>(
1839 nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
1840
1841 PyFlatSymbolRefAttribute::bind(m);
1842 PyOpaqueAttribute::bind(m);
1843 PyFloatAttribute::bind(m);
1844 PyIntegerAttribute::bind(m);
1845 PyIntegerSetAttribute::bind(m);
1847 PyTypeAttribute::bind(m);
  // Callable that downcasts to the Bool or Integer attribute wrapper.
1850 nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
1851 PyUnitAttribute::bind(m);
1852
1853 PyStridedLayoutAttribute::bind(m);
1854}
#define Py_IsFinalizing
static const char kDenseElementsAttrGetDocstring[]
static const char kDenseResourceElementsAttrGetFromBufferDocstring[]
static const char kDenseElementsAttrGetFromListDocstring[]
static MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.cpp:77
MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Definition IR.cpp:1275
MlirType mlirAttributeGetType(MlirAttribute attribute)
Definition IR.cpp:1279
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
std::string str() const
Converts the diagnostic to a string.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRModule.h:292
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRModule.h:499
Used in function arguments when None should resolve to the current context manager set instance.
Definition IRModule.h:273
static PyMlirContext & resolve()
Definition IRCore.cpp:672
ReferrentTy * get() const
MlirAffineMap get() const
Definition IRModule.h:1230
Wrapper around the generic MlirAttribute.
Definition IRModule.h:1008
nanobind::object maybeDownCast()
Definition IRCore.cpp:2044
MlirAttribute get() const
Definition IRModule.h:1014
CRTP base classes for Python attributes that subclass Attribute and should be castable from it (i....
Definition IRModule.h:1060
nanobind::class_< DerivedTy, BaseTy > ClassTy
Definition IRModule.h:1065
static void bind(nanobind::module_ &m, PyType_Slot *slots=nullptr)
Definition IRModule.h:1089
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace=false)
Adds a user-friendly type caster.
Definition IRModule.cpp:87
static PyGlobals & get()
Most code should get the globals via this static accessor.
Definition Globals.h:39
MlirIntegerSet get() const
Definition IRModule.h:1251
MlirContext get()
Accesses the underlying MlirContext.
Definition IRModule.h:204
PyMlirContextRef getRef()
Gets a strong reference to this context, which will ensure it is kept alive for the life of the refer...
Definition IRModule.h:208
static PyMlirContextRef forContext(MlirContext context)
Returns a context reference for the singleton PyMlirContext wrapper for the given context.
Definition IRCore.cpp:568
static void bindDerived(ClassTy &c)
A TypeID provides an efficient and unique identifier for a specific C++ type.
Definition IRModule.h:904
Wrapper around the generic MlirType.
Definition IRModule.h:878
MlirType get() const
Definition IRModule.h:884
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map)
Creates an affine map attribute wrapping the given map.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type)
Creates an opaque attribute in the given context associated with the dialect identified by its namesp...
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, double value)
Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a construction of a FloatAttr,...
MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr)
MLIR_CAPI_EXPORTED uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr)
Returns the affine map wrapped in the given affine map attribute.
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, const int64_t *strides)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr)
Checks whether the given attribute is a unit attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr)
Checks whether the given attribute is a dense elements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr)
Checks whether the given attribute is an integer set attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAAffineMap(MlirAttribute attr)
Checks whether the given attribute is an affine map attribute.
MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol)
Creates a flat symbol reference attribute in the given context referencing a symbol identified by the...
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void)
Returns the typeID of a StridedLayout attribute.
MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED const void * mlirDenseElementsAttrGetRawData(MlirAttribute attr)
Returns the raw data of the given dense elements attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void)
Returns the typeID of an Integer attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseResourceElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr)
Returns the string reference to the root referenced symbol.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAType(MlirAttribute attr)
Checks whether the given attribute is a type attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void)
Returns the typeID of an DenseIntOrFPElements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAArray(MlirAttribute attr)
Checks whether the given attribute is an array attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAInteger(MlirAttribute attr)
Checks whether the given attribute is an integer attribute.
MLIR_CAPI_EXPORTED intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr)
Returns the number of attributes contained in a dictionary attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set)
Creates an integer set attribute wrapping the given set.
MLIR_CAPI_EXPORTED uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr)
Returns the value stored in the given bool attribute.
MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value)
Creates an integer attribute of the given type with the given integer value.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void)
Returns the typeID of an IntegerSet attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos)
Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense e...
MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, MlirNamedAttribute const *elements)
Creates a dictionary attribute containing the given list of elements in the provided context.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, size_t dataAlignment, bool dataIsMutable, void(*deleter)(void *userData, const void *data, size_t size, size_t align), void *userData)
Unlike the typed accessors below, constructs the attribute with a raw data buffer and no type/alignme...
MLIR_CAPI_EXPORTED bool mlirAttributeIsABool(MlirAttribute attr)
Checks whether the given attribute is a bool attribute.
MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos)
Get an element of a dense array.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, int64_t const *values)
MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void)
Returns the typeID of an AffineMap attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos)
Returns pos-th reference nested in the given symbol reference attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void)
Returns the typeID of an Array attribute.
MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, double const *values)
MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of signless type and f...
MLIR_CAPI_EXPORTED intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr)
Returns the number of references nested in the given symbol reference attribute.
MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr)
Returns the type stored in the given type attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrIsSplat(MlirAttribute attr)
Checks whether the given dense elements attribute contains a single replicated value (splat).
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, MlirAttribute const *elements)
Creates a dense elements attribute with the given Shaped type and elements in the same context as the...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr)
Checks whether the given attribute is a dense array attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value)
Creates a bool attribute in the given context with the given value.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void)
Returns the typeID of a Float attribute.
MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of signed type and fit...
MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos)
Returns pos-th element of the given dictionary attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void)
Returns the typeID of a Unit attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, float const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values)
Create a dense array attribute with the given elements.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAOpaque(MlirAttribute attr)
Checks whether the given attribute is an opaque attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos)
Returns pos-th element stored in the given array attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name)
Returns the dictionary attribute element with the given name or NULL if the given name does not exist...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void)
Returns the typeID of an SymbolRef attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr)
Returns the single replicated value (splat) of a specific type contained by the given dense elements ...
MLIR_CAPI_EXPORTED bool mlirAttributeIsASymbolRef(MlirAttribute attr)
Checks whether the given attribute is a symbol reference attribute.
MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr)
Gets the total number of elements in the given elements attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr)
Returns the namespace of the dialect with which the given opaque attribute is associated.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADictionary(MlirAttribute attr)
Checks whether the given attribute is a dictionary attribute.
MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr)
Returns the attribute values as a string reference.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, int16_t const *values)
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void)
Returns the typeID of an Opaque attribute.
MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr)
Returns the value stored in the given floating point attribute, interpreting the value as double.
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of unsigned type and f...
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, int8_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx)
Creates a unit attribute in the given context.
MLIR_CAPI_EXPORTED double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, double value)
Creates a floating point attribute in the given context with the given double value and double-precis...
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements)
Creates an array attribute containing the given list of elements in the given context.
MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr)
Get the size of a dense array.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFloat(MlirAttribute attr)
Checks whether the given attribute is a floating point attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, intptr_t numReferences, MlirAttribute const *references)
Creates a symbol reference attribute in the given context referencing a symbol identified by the give...
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void)
Returns the typeID of a Type attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element)
Creates a dense elements attribute with the given Shaped type containing a single replicated element ...
MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void)
Returns the typeID of a Dictionary attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type)
Creates a type attribute wrapping the given type in the same context as the type.
MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr)
Returns the number of elements stored in the given array attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, size_t rawBufferSize, const void *rawBuffer)
Creates a dense elements attribute with the given Shaped type and elements populated from a packed,...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr)
Checks whether the given attribute is a flat symbol reference attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr)
Returns the referenced symbol as a string reference.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition IR.cpp:1306
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition IR.cpp:1327
struct MlirNamedAttribute MlirNamedAttribute
Definition IR.h:80
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition IR.cpp:1244
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition IR.cpp:1256
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition IR.cpp:1315
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:82
void populateIRAttributes(nanobind::module_ &m)
PyObjectRef< PyMlirContext > PyMlirContextRef
Wrapper around MlirContext.
Definition IRModule.h:190
Include the generated interface declarations.
MlirAttribute attribute
Definition IR.h:78
MlirIdentifier name
Definition IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition Support.h:73
const char * data
Pointer to the first symbol.
Definition Support.h:74
size_t length
Length of the fragment.
Definition Support.h:75
Custom exception that allows access to error diagnostic information.
Definition IRModule.h:1318
RAII object that captures any error diagnostics emitted to the provided context.
Definition IRModule.h:408
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition IRModule.h:418
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.