MLIR  20.0.0git
IRAttributes.cpp
Go to the documentation of this file.
1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <optional>
11 #include <string>
12 #include <string_view>
13 #include <utility>
14 
15 #include "IRModule.h"
16 #include "NanobindUtils.h"
18 #include "mlir-c/BuiltinTypes.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 namespace nb = nanobind;
25 using namespace nanobind::literals;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
// Python docstring describing how a DenseElementsAttr is built from a Python
// buffer/array. Written in Google style: bullet continuations and the entries
// under Args/Raises use hanging indentation so the sections parse correctly.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
76 
// Python docstring describing how a DenseElementsAttr is built from a list of
// attributes. The Raises section names the `type` parameter (as listed under
// Args) and documents the empty-list error raised by the builder.
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If `attrs` is empty, or if the type of the attributes does not
    match the type specified by `type`.
)";
97 
// Python docstring describing how a DenseResourceElementsAttr is built from a
// Python buffer/array, including the lifetime of the retained backing buffer.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
123 
124 namespace {
125 
126 struct nb_buffer_info {
127  void *ptr = nullptr;
128  ssize_t itemsize = 0;
129  ssize_t size = 0;
130  const char *format = nullptr;
131  ssize_t ndim = 0;
133  SmallVector<ssize_t, 4> strides;
134  bool readonly = false;
135 
136  nb_buffer_info(
137  void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
139  bool readonly = false,
140  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
141  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
142  : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
143  shape(std::move(shape_in)), strides(std::move(strides_in)),
144  readonly(readonly), owned_view(std::move(owned_view_in)) {
145  size = 1;
146  for (ssize_t i = 0; i < ndim; ++i) {
147  size *= shape[i];
148  }
149  }
150 
151  explicit nb_buffer_info(Py_buffer *view)
152  : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
153  {view->shape, view->shape + view->ndim},
154  // TODO(phawkins): check for null strides
155  {view->strides, view->strides + view->ndim},
156  view->readonly != 0,
157  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
158  view, PyBuffer_Release)) {}
159 
160  nb_buffer_info(const nb_buffer_info &) = delete;
161  nb_buffer_info(nb_buffer_info &&) = default;
162  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
163  nb_buffer_info &operator=(nb_buffer_info &&) = default;
164 
165 private:
166  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
167 };
168 
169 class nb_buffer : public nb::object {
170  NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
171 
172  nb_buffer_info request() const {
173  int flags = PyBUF_STRIDES | PyBUF_FORMAT;
174  auto *view = new Py_buffer();
175  if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
176  delete view;
177  throw nb::python_error();
178  }
179  return nb_buffer_info(view);
180  }
181 };
182 
/// Compile-time mapping from a C++ scalar type to its single-character
/// Python buffer-protocol format code (the `struct` module notation).
/// The primary template is deliberately empty, so asking for the format of an
/// unsupported type fails to compile.
template <typename T>
struct nb_format_descriptor {};

/// Shared implementation for the specializations below: `format()` returns a
/// NUL-terminated, one-character string holding `Code`.
template <char Code>
struct nb_format_code {
  static const char *format() {
    static const char kFormat[] = {Code, '\0'};
    return kFormat;
  }
};

template <>
struct nb_format_descriptor<bool> : nb_format_code<'?'> {};
template <>
struct nb_format_descriptor<int8_t> : nb_format_code<'b'> {};
template <>
struct nb_format_descriptor<uint8_t> : nb_format_code<'B'> {};
template <>
struct nb_format_descriptor<int16_t> : nb_format_code<'h'> {};
template <>
struct nb_format_descriptor<uint16_t> : nb_format_code<'H'> {};
template <>
struct nb_format_descriptor<int32_t> : nb_format_code<'i'> {};
template <>
struct nb_format_descriptor<uint32_t> : nb_format_code<'I'> {};
template <>
struct nb_format_descriptor<int64_t> : nb_format_code<'q'> {};
template <>
struct nb_format_descriptor<uint64_t> : nb_format_code<'Q'> {};
template <>
struct nb_format_descriptor<float> : nb_format_code<'f'> {};
template <>
struct nb_format_descriptor<double> : nb_format_code<'d'> {};
230 
231 static MlirStringRef toMlirStringRef(const std::string &s) {
232  return mlirStringRefCreate(s.data(), s.size());
233 }
234 
235 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
236  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
237 }
238 
239 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
240 public:
241  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
242  static constexpr const char *pyClassName = "AffineMapAttr";
243  using PyConcreteAttribute::PyConcreteAttribute;
244  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
246 
247  static void bindDerived(ClassTy &c) {
248  c.def_static(
249  "get",
250  [](PyAffineMap &affineMap) {
251  MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
252  return PyAffineMapAttribute(affineMap.getContext(), attr);
253  },
254  nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255  c.def_prop_ro("value", mlirAffineMapAttrGetValue,
256  "Returns the value of the AffineMap attribute");
257  }
258 };
259 
260 class PyIntegerSetAttribute
261  : public PyConcreteAttribute<PyIntegerSetAttribute> {
262 public:
263  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
264  static constexpr const char *pyClassName = "IntegerSetAttr";
265  using PyConcreteAttribute::PyConcreteAttribute;
266  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
268 
269  static void bindDerived(ClassTy &c) {
270  c.def_static(
271  "get",
272  [](PyIntegerSet &integerSet) {
273  MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
274  return PyIntegerSetAttribute(integerSet.getContext(), attr);
275  },
276  nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
277  }
278 };
279 
280 template <typename T>
281 static T pyTryCast(nb::handle object) {
282  try {
283  return nb::cast<T>(object);
284  } catch (nb::cast_error &err) {
285  std::string msg = std::string("Invalid attribute when attempting to "
286  "create an ArrayAttribute (") +
287  err.what() + ")";
288  throw std::runtime_error(msg.c_str());
289  } catch (std::runtime_error &err) {
290  std::string msg = std::string("Invalid attribute (None?) when attempting "
291  "to create an ArrayAttribute (") +
292  err.what() + ")";
293  throw std::runtime_error(msg.c_str());
294  }
295 }
296 
297 /// A python-wrapped dense array attribute with an element type and a derived
298 /// implementation class.
299 template <typename EltTy, typename DerivedT>
300 class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
301 public:
303 
304  /// Iterator over the integer elements of a dense array.
305  class PyDenseArrayIterator {
306  public:
307  PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
308 
309  /// Return a copy of the iterator.
310  PyDenseArrayIterator dunderIter() { return *this; }
311 
312  /// Return the next element.
313  EltTy dunderNext() {
314  // Throw if the index has reached the end.
316  throw nb::stop_iteration();
317  return DerivedT::getElement(attr.get(), nextIndex++);
318  }
319 
320  /// Bind the iterator class.
321  static void bind(nb::module_ &m) {
322  nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
323  .def("__iter__", &PyDenseArrayIterator::dunderIter)
324  .def("__next__", &PyDenseArrayIterator::dunderNext);
325  }
326 
327  private:
328  /// The referenced dense array attribute.
329  PyAttribute attr;
330  /// The next index to read.
331  int nextIndex = 0;
332  };
333 
334  /// Get the element at the given index.
335  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
336 
337  /// Bind the attribute class.
338  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
339  // Bind the constructor.
340  if constexpr (std::is_same_v<EltTy, bool>) {
341  c.def_static(
342  "get",
343  [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
344  std::vector<bool> values;
345  for (nb::handle py_value : py_values) {
346  int is_true = PyObject_IsTrue(py_value.ptr());
347  if (is_true < 0) {
348  throw nb::python_error();
349  }
350  values.push_back(is_true);
351  }
352  return getAttribute(values, ctx->getRef());
353  },
354  nb::arg("values"), nb::arg("context").none() = nb::none(),
355  "Gets a uniqued dense array attribute");
356  } else {
357  c.def_static(
358  "get",
359  [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
360  return getAttribute(values, ctx->getRef());
361  },
362  nb::arg("values"), nb::arg("context").none() = nb::none(),
363  "Gets a uniqued dense array attribute");
364  }
365  // Bind the array methods.
366  c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
367  if (i >= mlirDenseArrayGetNumElements(arr))
368  throw nb::index_error("DenseArray index out of range");
369  return arr.getItem(i);
370  });
371  c.def("__len__", [](const DerivedT &arr) {
372  return mlirDenseArrayGetNumElements(arr);
373  });
374  c.def("__iter__",
375  [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
376  c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
377  std::vector<EltTy> values;
378  intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
379  values.reserve(numOldElements + nb::len(extras));
380  for (intptr_t i = 0; i < numOldElements; ++i)
381  values.push_back(arr.getItem(i));
382  for (nb::handle attr : extras)
383  values.push_back(pyTryCast<EltTy>(attr));
384  return getAttribute(values, arr.getContext());
385  });
386  }
387 
388 private:
389  static DerivedT getAttribute(const std::vector<EltTy> &values,
390  PyMlirContextRef ctx) {
391  if constexpr (std::is_same_v<EltTy, bool>) {
392  std::vector<int> intValues(values.begin(), values.end());
393  MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
394  intValues.data());
395  return DerivedT(ctx, attr);
396  } else {
397  MlirAttribute attr =
398  DerivedT::getAttribute(ctx->get(), values.size(), values.data());
399  return DerivedT(ctx, attr);
400  }
401  }
402 };
403 
404 /// Instantiate the python dense array classes.
405 struct PyDenseBoolArrayAttribute
406  : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
407  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
408  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
409  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
410  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
411  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
412  using PyDenseArrayAttribute::PyDenseArrayAttribute;
413 };
414 struct PyDenseI8ArrayAttribute
415  : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
416  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
417  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
418  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
419  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
420  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
421  using PyDenseArrayAttribute::PyDenseArrayAttribute;
422 };
423 struct PyDenseI16ArrayAttribute
424  : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
425  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
426  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
427  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
428  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
429  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
430  using PyDenseArrayAttribute::PyDenseArrayAttribute;
431 };
432 struct PyDenseI32ArrayAttribute
433  : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
434  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
435  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
436  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
437  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
438  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
439  using PyDenseArrayAttribute::PyDenseArrayAttribute;
440 };
441 struct PyDenseI64ArrayAttribute
442  : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
443  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
444  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
445  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
446  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
447  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
448  using PyDenseArrayAttribute::PyDenseArrayAttribute;
449 };
450 struct PyDenseF32ArrayAttribute
451  : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
452  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
453  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
454  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
455  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
456  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
457  using PyDenseArrayAttribute::PyDenseArrayAttribute;
458 };
459 struct PyDenseF64ArrayAttribute
460  : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
461  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
462  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
463  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
464  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
465  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
466  using PyDenseArrayAttribute::PyDenseArrayAttribute;
467 };
468 
469 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
470 public:
471  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
472  static constexpr const char *pyClassName = "ArrayAttr";
473  using PyConcreteAttribute::PyConcreteAttribute;
474  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
476 
477  class PyArrayAttributeIterator {
478  public:
479  PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
480 
481  PyArrayAttributeIterator &dunderIter() { return *this; }
482 
483  MlirAttribute dunderNext() {
484  // TODO: Throw is an inefficient way to stop iteration.
486  throw nb::stop_iteration();
487  return mlirArrayAttrGetElement(attr.get(), nextIndex++);
488  }
489 
490  static void bind(nb::module_ &m) {
491  nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
492  .def("__iter__", &PyArrayAttributeIterator::dunderIter)
493  .def("__next__", &PyArrayAttributeIterator::dunderNext);
494  }
495 
496  private:
497  PyAttribute attr;
498  int nextIndex = 0;
499  };
500 
501  MlirAttribute getItem(intptr_t i) {
502  return mlirArrayAttrGetElement(*this, i);
503  }
504 
505  static void bindDerived(ClassTy &c) {
506  c.def_static(
507  "get",
508  [](nb::list attributes, DefaultingPyMlirContext context) {
509  SmallVector<MlirAttribute> mlirAttributes;
510  mlirAttributes.reserve(nb::len(attributes));
511  for (auto attribute : attributes) {
512  mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
513  }
514  MlirAttribute attr = mlirArrayAttrGet(
515  context->get(), mlirAttributes.size(), mlirAttributes.data());
516  return PyArrayAttribute(context->getRef(), attr);
517  },
518  nb::arg("attributes"), nb::arg("context").none() = nb::none(),
519  "Gets a uniqued Array attribute");
520  c.def("__getitem__",
521  [](PyArrayAttribute &arr, intptr_t i) {
522  if (i >= mlirArrayAttrGetNumElements(arr))
523  throw nb::index_error("ArrayAttribute index out of range");
524  return arr.getItem(i);
525  })
526  .def("__len__",
527  [](const PyArrayAttribute &arr) {
528  return mlirArrayAttrGetNumElements(arr);
529  })
530  .def("__iter__", [](const PyArrayAttribute &arr) {
531  return PyArrayAttributeIterator(arr);
532  });
533  c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
534  std::vector<MlirAttribute> attributes;
535  intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
536  attributes.reserve(numOldElements + nb::len(extras));
537  for (intptr_t i = 0; i < numOldElements; ++i)
538  attributes.push_back(arr.getItem(i));
539  for (nb::handle attr : extras)
540  attributes.push_back(pyTryCast<PyAttribute>(attr));
541  MlirAttribute arrayAttr = mlirArrayAttrGet(
542  arr.getContext()->get(), attributes.size(), attributes.data());
543  return PyArrayAttribute(arr.getContext(), arrayAttr);
544  });
545  }
546 };
547 
548 /// Float Point Attribute subclass - FloatAttr.
549 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
550 public:
551  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
552  static constexpr const char *pyClassName = "FloatAttr";
553  using PyConcreteAttribute::PyConcreteAttribute;
554  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
556 
557  static void bindDerived(ClassTy &c) {
558  c.def_static(
559  "get",
560  [](PyType &type, double value, DefaultingPyLocation loc) {
561  PyMlirContext::ErrorCapture errors(loc->getContext());
562  MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
563  if (mlirAttributeIsNull(attr))
564  throw MLIRError("Invalid attribute", errors.take());
565  return PyFloatAttribute(type.getContext(), attr);
566  },
567  nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
568  "Gets an uniqued float point attribute associated to a type");
569  c.def_static(
570  "get_f32",
571  [](double value, DefaultingPyMlirContext context) {
572  MlirAttribute attr = mlirFloatAttrDoubleGet(
573  context->get(), mlirF32TypeGet(context->get()), value);
574  return PyFloatAttribute(context->getRef(), attr);
575  },
576  nb::arg("value"), nb::arg("context").none() = nb::none(),
577  "Gets an uniqued float point attribute associated to a f32 type");
578  c.def_static(
579  "get_f64",
580  [](double value, DefaultingPyMlirContext context) {
581  MlirAttribute attr = mlirFloatAttrDoubleGet(
582  context->get(), mlirF64TypeGet(context->get()), value);
583  return PyFloatAttribute(context->getRef(), attr);
584  },
585  nb::arg("value"), nb::arg("context").none() = nb::none(),
586  "Gets an uniqued float point attribute associated to a f64 type");
587  c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
588  "Returns the value of the float attribute");
589  c.def("__float__", mlirFloatAttrGetValueDouble,
590  "Converts the value of the float attribute to a Python float");
591  }
592 };
593 
594 /// Integer Attribute subclass - IntegerAttr.
595 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
596 public:
597  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
598  static constexpr const char *pyClassName = "IntegerAttr";
599  using PyConcreteAttribute::PyConcreteAttribute;
600 
601  static void bindDerived(ClassTy &c) {
602  c.def_static(
603  "get",
604  [](PyType &type, int64_t value) {
605  MlirAttribute attr = mlirIntegerAttrGet(type, value);
606  return PyIntegerAttribute(type.getContext(), attr);
607  },
608  nb::arg("type"), nb::arg("value"),
609  "Gets an uniqued integer attribute associated to a type");
610  c.def_prop_ro("value", toPyInt,
611  "Returns the value of the integer attribute");
612  c.def("__int__", toPyInt,
613  "Converts the value of the integer attribute to a Python int");
614  c.def_prop_ro_static("static_typeid",
615  [](nb::object & /*class*/) -> MlirTypeID {
616  return mlirIntegerAttrGetTypeID();
617  });
618  }
619 
620 private:
621  static int64_t toPyInt(PyIntegerAttribute &self) {
622  MlirType type = mlirAttributeGetType(self);
623  if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
624  return mlirIntegerAttrGetValueInt(self);
625  if (mlirIntegerTypeIsSigned(type))
626  return mlirIntegerAttrGetValueSInt(self);
627  return mlirIntegerAttrGetValueUInt(self);
628  }
629 };
630 
631 /// Bool Attribute subclass - BoolAttr.
632 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
633 public:
634  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
635  static constexpr const char *pyClassName = "BoolAttr";
636  using PyConcreteAttribute::PyConcreteAttribute;
637 
638  static void bindDerived(ClassTy &c) {
639  c.def_static(
640  "get",
641  [](bool value, DefaultingPyMlirContext context) {
642  MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
643  return PyBoolAttribute(context->getRef(), attr);
644  },
645  nb::arg("value"), nb::arg("context").none() = nb::none(),
646  "Gets an uniqued bool attribute");
647  c.def_prop_ro("value", mlirBoolAttrGetValue,
648  "Returns the value of the bool attribute");
649  c.def("__bool__", mlirBoolAttrGetValue,
650  "Converts the value of the bool attribute to a Python bool");
651  }
652 };
653 
654 class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
655 public:
656  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
657  static constexpr const char *pyClassName = "SymbolRefAttr";
658  using PyConcreteAttribute::PyConcreteAttribute;
659 
660  static MlirAttribute fromList(const std::vector<std::string> &symbols,
661  PyMlirContext &context) {
662  if (symbols.empty())
663  throw std::runtime_error("SymbolRefAttr must be composed of at least "
664  "one symbol.");
665  MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
666  SmallVector<MlirAttribute, 3> referenceAttrs;
667  for (size_t i = 1; i < symbols.size(); ++i) {
668  referenceAttrs.push_back(
669  mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
670  }
671  return mlirSymbolRefAttrGet(context.get(), rootSymbol,
672  referenceAttrs.size(), referenceAttrs.data());
673  }
674 
675  static void bindDerived(ClassTy &c) {
676  c.def_static(
677  "get",
678  [](const std::vector<std::string> &symbols,
679  DefaultingPyMlirContext context) {
680  return PySymbolRefAttribute::fromList(symbols, context.resolve());
681  },
682  nb::arg("symbols"), nb::arg("context").none() = nb::none(),
683  "Gets a uniqued SymbolRef attribute from a list of symbol names");
684  c.def_prop_ro(
685  "value",
686  [](PySymbolRefAttribute &self) {
687  std::vector<std::string> symbols = {
689  for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
690  ++i)
691  symbols.push_back(
694  .str());
695  return symbols;
696  },
697  "Returns the value of the SymbolRef attribute as a list[str]");
698  }
699 };
700 
701 class PyFlatSymbolRefAttribute
702  : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
703 public:
704  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
705  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
706  using PyConcreteAttribute::PyConcreteAttribute;
707 
708  static void bindDerived(ClassTy &c) {
709  c.def_static(
710  "get",
711  [](std::string value, DefaultingPyMlirContext context) {
712  MlirAttribute attr =
713  mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
714  return PyFlatSymbolRefAttribute(context->getRef(), attr);
715  },
716  nb::arg("value"), nb::arg("context").none() = nb::none(),
717  "Gets a uniqued FlatSymbolRef attribute");
718  c.def_prop_ro(
719  "value",
720  [](PyFlatSymbolRefAttribute &self) {
722  return nb::str(stringRef.data, stringRef.length);
723  },
724  "Returns the value of the FlatSymbolRef attribute as a string");
725  }
726 };
727 
728 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
729 public:
730  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
731  static constexpr const char *pyClassName = "OpaqueAttr";
732  using PyConcreteAttribute::PyConcreteAttribute;
733  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
735 
736  static void bindDerived(ClassTy &c) {
737  c.def_static(
738  "get",
739  [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
740  DefaultingPyMlirContext context) {
741  const nb_buffer_info bufferInfo = buffer.request();
742  intptr_t bufferSize = bufferInfo.size;
743  MlirAttribute attr = mlirOpaqueAttrGet(
744  context->get(), toMlirStringRef(dialectNamespace), bufferSize,
745  static_cast<char *>(bufferInfo.ptr), type);
746  return PyOpaqueAttribute(context->getRef(), attr);
747  },
748  nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
749  nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
750  c.def_prop_ro(
751  "dialect_namespace",
752  [](PyOpaqueAttribute &self) {
754  return nb::str(stringRef.data, stringRef.length);
755  },
756  "Returns the dialect namespace for the Opaque attribute as a string");
757  c.def_prop_ro(
758  "data",
759  [](PyOpaqueAttribute &self) {
760  MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
761  return nb::bytes(stringRef.data, stringRef.length);
762  },
763  "Returns the data for the Opaqued attributes as `bytes`");
764  }
765 };
766 
767 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
768 public:
769  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770  static constexpr const char *pyClassName = "StringAttr";
771  using PyConcreteAttribute::PyConcreteAttribute;
772  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
774 
775  static void bindDerived(ClassTy &c) {
776  c.def_static(
777  "get",
778  [](std::string value, DefaultingPyMlirContext context) {
779  MlirAttribute attr =
780  mlirStringAttrGet(context->get(), toMlirStringRef(value));
781  return PyStringAttribute(context->getRef(), attr);
782  },
783  nb::arg("value"), nb::arg("context").none() = nb::none(),
784  "Gets a uniqued string attribute");
785  c.def_static(
786  "get",
787  [](nb::bytes value, DefaultingPyMlirContext context) {
788  MlirAttribute attr =
789  mlirStringAttrGet(context->get(), toMlirStringRef(value));
790  return PyStringAttribute(context->getRef(), attr);
791  },
792  nb::arg("value"), nb::arg("context").none() = nb::none(),
793  "Gets a uniqued string attribute");
794  c.def_static(
795  "get_typed",
796  [](PyType &type, std::string value) {
797  MlirAttribute attr =
799  return PyStringAttribute(type.getContext(), attr);
800  },
801  nb::arg("type"), nb::arg("value"),
802  "Gets a uniqued string attribute associated to a type");
803  c.def_prop_ro(
804  "value",
805  [](PyStringAttribute &self) {
806  MlirStringRef stringRef = mlirStringAttrGetValue(self);
807  return nb::str(stringRef.data, stringRef.length);
808  },
809  "Returns the value of the string attribute");
810  c.def_prop_ro(
811  "value_bytes",
812  [](PyStringAttribute &self) {
813  MlirStringRef stringRef = mlirStringAttrGetValue(self);
814  return nb::bytes(stringRef.data, stringRef.length);
815  },
816  "Returns the value of the string attribute as `bytes`");
817  }
818 };
819 
820 // TODO: Support construction of string elements.
821 class PyDenseElementsAttribute
822  : public PyConcreteAttribute<PyDenseElementsAttribute> {
823 public:
824  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
825  static constexpr const char *pyClassName = "DenseElementsAttr";
826  using PyConcreteAttribute::PyConcreteAttribute;
827 
828  static PyDenseElementsAttribute
829  getFromList(nb::list attributes, std::optional<PyType> explicitType,
830  DefaultingPyMlirContext contextWrapper) {
831  const size_t numAttributes = nb::len(attributes);
832  if (numAttributes == 0)
833  throw nb::value_error("Attributes list must be non-empty.");
834 
835  MlirType shapedType;
836  if (explicitType) {
837  if ((!mlirTypeIsAShaped(*explicitType) ||
838  !mlirShapedTypeHasStaticShape(*explicitType))) {
839 
840  std::string message;
841  llvm::raw_string_ostream os(message);
842  os << "Expected a static ShapedType for the shaped_type parameter: "
843  << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
844  throw nb::value_error(message.c_str());
845  }
846  shapedType = *explicitType;
847  } else {
848  SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
849  shapedType = mlirRankedTensorTypeGet(
850  shape.size(), shape.data(),
851  mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
853  }
854 
855  SmallVector<MlirAttribute> mlirAttributes;
856  mlirAttributes.reserve(numAttributes);
857  for (const nb::handle &attribute : attributes) {
858  MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
859  MlirType attrType = mlirAttributeGetType(mlirAttribute);
860  mlirAttributes.push_back(mlirAttribute);
861 
862  if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
863  std::string message;
864  llvm::raw_string_ostream os(message);
865  os << "All attributes must be of the same type and match "
866  << "the type parameter: expected="
867  << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
868  << ", but got="
869  << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
870  throw nb::value_error(message.c_str());
871  }
872  }
873 
874  MlirAttribute elements = mlirDenseElementsAttrGet(
875  shapedType, mlirAttributes.size(), mlirAttributes.data());
876 
877  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
878  }
879 
880  static PyDenseElementsAttribute
881  getFromBuffer(nb_buffer array, bool signless,
882  std::optional<PyType> explicitType,
883  std::optional<std::vector<int64_t>> explicitShape,
884  DefaultingPyMlirContext contextWrapper) {
885  // Request a contiguous view. In exotic cases, this will cause a copy.
886  int flags = PyBUF_ND;
887  if (!explicitType) {
888  flags |= PyBUF_FORMAT;
889  }
890  Py_buffer view;
891  if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
892  throw nb::python_error();
893  }
894  auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
895 
896  MlirContext context = contextWrapper->get();
897  MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
898  explicitShape, context);
899  if (mlirAttributeIsNull(attr)) {
900  throw std::invalid_argument(
901  "DenseElementsAttr could not be constructed from the given buffer. "
902  "This may mean that the Python buffer layout does not match that "
903  "MLIR expected layout and is a bug.");
904  }
905  return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
906  }
907 
908  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
909  PyAttribute &elementAttr) {
910  auto contextWrapper =
911  PyMlirContext::forContext(mlirTypeGetContext(shapedType));
912  if (!mlirAttributeIsAInteger(elementAttr) &&
913  !mlirAttributeIsAFloat(elementAttr)) {
914  std::string message = "Illegal element type for DenseElementsAttr: ";
915  message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
916  throw nb::value_error(message.c_str());
917  }
918  if (!mlirTypeIsAShaped(shapedType) ||
919  !mlirShapedTypeHasStaticShape(shapedType)) {
920  std::string message =
921  "Expected a static ShapedType for the shaped_type parameter: ";
922  message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
923  throw nb::value_error(message.c_str());
924  }
925  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
926  MlirType attrType = mlirAttributeGetType(elementAttr);
927  if (!mlirTypeEqual(shapedElementType, attrType)) {
928  std::string message =
929  "Shaped element type and attribute type must be equal: shaped=";
930  message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
931  message.append(", element=");
932  message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
933  throw nb::value_error(message.c_str());
934  }
935 
936  MlirAttribute elements =
937  mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
938  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
939  }
940 
941  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
942 
943  std::unique_ptr<nb_buffer_info> accessBuffer() {
944  MlirType shapedType = mlirAttributeGetType(*this);
945  MlirType elementType = mlirShapedTypeGetElementType(shapedType);
946  std::string format;
947 
948  if (mlirTypeIsAF32(elementType)) {
949  // f32
950  return bufferInfo<float>(shapedType);
951  }
952  if (mlirTypeIsAF64(elementType)) {
953  // f64
954  return bufferInfo<double>(shapedType);
955  }
956  if (mlirTypeIsAF16(elementType)) {
957  // f16
958  return bufferInfo<uint16_t>(shapedType, "e");
959  }
960  if (mlirTypeIsAIndex(elementType)) {
961  // Same as IndexType::kInternalStorageBitWidth
962  return bufferInfo<int64_t>(shapedType);
963  }
964  if (mlirTypeIsAInteger(elementType) &&
965  mlirIntegerTypeGetWidth(elementType) == 32) {
966  if (mlirIntegerTypeIsSignless(elementType) ||
967  mlirIntegerTypeIsSigned(elementType)) {
968  // i32
969  return bufferInfo<int32_t>(shapedType);
970  }
971  if (mlirIntegerTypeIsUnsigned(elementType)) {
972  // unsigned i32
973  return bufferInfo<uint32_t>(shapedType);
974  }
975  } else if (mlirTypeIsAInteger(elementType) &&
976  mlirIntegerTypeGetWidth(elementType) == 64) {
977  if (mlirIntegerTypeIsSignless(elementType) ||
978  mlirIntegerTypeIsSigned(elementType)) {
979  // i64
980  return bufferInfo<int64_t>(shapedType);
981  }
982  if (mlirIntegerTypeIsUnsigned(elementType)) {
983  // unsigned i64
984  return bufferInfo<uint64_t>(shapedType);
985  }
986  } else if (mlirTypeIsAInteger(elementType) &&
987  mlirIntegerTypeGetWidth(elementType) == 8) {
988  if (mlirIntegerTypeIsSignless(elementType) ||
989  mlirIntegerTypeIsSigned(elementType)) {
990  // i8
991  return bufferInfo<int8_t>(shapedType);
992  }
993  if (mlirIntegerTypeIsUnsigned(elementType)) {
994  // unsigned i8
995  return bufferInfo<uint8_t>(shapedType);
996  }
997  } else if (mlirTypeIsAInteger(elementType) &&
998  mlirIntegerTypeGetWidth(elementType) == 16) {
999  if (mlirIntegerTypeIsSignless(elementType) ||
1000  mlirIntegerTypeIsSigned(elementType)) {
1001  // i16
1002  return bufferInfo<int16_t>(shapedType);
1003  }
1004  if (mlirIntegerTypeIsUnsigned(elementType)) {
1005  // unsigned i16
1006  return bufferInfo<uint16_t>(shapedType);
1007  }
1008  } else if (mlirTypeIsAInteger(elementType) &&
1009  mlirIntegerTypeGetWidth(elementType) == 1) {
1010  // i1 / bool
1011  // We can not send the buffer directly back to Python, because the i1
1012  // values are bitpacked within MLIR. We call numpy's unpackbits function
1013  // to convert the bytes.
1014  return getBooleanBufferFromBitpackedAttribute();
1015  }
1016 
1017  // TODO: Currently crashes the program.
1018  // Reported as https://github.com/pybind/pybind11/issues/3336
1019  throw std::invalid_argument(
1020  "unsupported data type for conversion to Python buffer");
1021  }
1022 
1023  static void bindDerived(ClassTy &c) {
1024 #if PY_VERSION_HEX < 0x03090000
1025  PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
1026  tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
1027  tp->tp_as_buffer->bf_releasebuffer =
1028  PyDenseElementsAttribute::bf_releasebuffer;
1029 #endif
1030  c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1031  .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1032  nb::arg("array"), nb::arg("signless") = true,
1033  nb::arg("type").none() = nb::none(),
1034  nb::arg("shape").none() = nb::none(),
1035  nb::arg("context").none() = nb::none(),
1037  .def_static("get", PyDenseElementsAttribute::getFromList,
1038  nb::arg("attrs"), nb::arg("type").none() = nb::none(),
1039  nb::arg("context").none() = nb::none(),
1041  .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1042  nb::arg("shaped_type"), nb::arg("element_attr"),
1043  "Gets a DenseElementsAttr where all values are the same")
1044  .def_prop_ro("is_splat",
1045  [](PyDenseElementsAttribute &self) -> bool {
1046  return mlirDenseElementsAttrIsSplat(self);
1047  })
1048  .def("get_splat_value", [](PyDenseElementsAttribute &self) {
1049  if (!mlirDenseElementsAttrIsSplat(self))
1050  throw nb::value_error(
1051  "get_splat_value called on a non-splat attribute");
1053  });
1054  }
1055 
1056  static PyType_Slot slots[];
1057 
1058 private:
1059  static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
1060  static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
1061 
1062  static bool isUnsignedIntegerFormat(std::string_view format) {
1063  if (format.empty())
1064  return false;
1065  char code = format[0];
1066  return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1067  code == 'Q';
1068  }
1069 
1070  static bool isSignedIntegerFormat(std::string_view format) {
1071  if (format.empty())
1072  return false;
1073  char code = format[0];
1074  return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1075  code == 'q';
1076  }
1077 
1078  static MlirType
1079  getShapedType(std::optional<MlirType> bulkLoadElementType,
1080  std::optional<std::vector<int64_t>> explicitShape,
1081  Py_buffer &view) {
1082  SmallVector<int64_t> shape;
1083  if (explicitShape) {
1084  shape.append(explicitShape->begin(), explicitShape->end());
1085  } else {
1086  shape.append(view.shape, view.shape + view.ndim);
1087  }
1088 
1089  if (mlirTypeIsAShaped(*bulkLoadElementType)) {
1090  if (explicitShape) {
1091  throw std::invalid_argument("Shape can only be specified explicitly "
1092  "when the type is not a shaped type.");
1093  }
1094  return *bulkLoadElementType;
1095  } else {
1096  MlirAttribute encodingAttr = mlirAttributeGetNull();
1097  return mlirRankedTensorTypeGet(shape.size(), shape.data(),
1098  *bulkLoadElementType, encodingAttr);
1099  }
1100  }
1101 
1102  static MlirAttribute getAttributeFromBuffer(
1103  Py_buffer &view, bool signless, std::optional<PyType> explicitType,
1104  std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
1105  // Detect format codes that are suitable for bulk loading. This includes
1106  // all byte aligned integer and floating point types up to 8 bytes.
1107  // Notably, this excludes exotics types which do not have a direct
1108  // representation in the buffer protocol (i.e. complex, etc).
1109  std::optional<MlirType> bulkLoadElementType;
1110  if (explicitType) {
1111  bulkLoadElementType = *explicitType;
1112  } else {
1113  std::string_view format(view.format);
1114  if (format == "f") {
1115  // f32
1116  assert(view.itemsize == 4 && "mismatched array itemsize");
1117  bulkLoadElementType = mlirF32TypeGet(context);
1118  } else if (format == "d") {
1119  // f64
1120  assert(view.itemsize == 8 && "mismatched array itemsize");
1121  bulkLoadElementType = mlirF64TypeGet(context);
1122  } else if (format == "e") {
1123  // f16
1124  assert(view.itemsize == 2 && "mismatched array itemsize");
1125  bulkLoadElementType = mlirF16TypeGet(context);
1126  } else if (format == "?") {
1127  // i1
1128  // The i1 type needs to be bit-packed, so we will handle it seperately
1129  return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
1130  context);
1131  } else if (isSignedIntegerFormat(format)) {
1132  if (view.itemsize == 4) {
1133  // i32
1134  bulkLoadElementType = signless
1135  ? mlirIntegerTypeGet(context, 32)
1136  : mlirIntegerTypeSignedGet(context, 32);
1137  } else if (view.itemsize == 8) {
1138  // i64
1139  bulkLoadElementType = signless
1140  ? mlirIntegerTypeGet(context, 64)
1141  : mlirIntegerTypeSignedGet(context, 64);
1142  } else if (view.itemsize == 1) {
1143  // i8
1144  bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1145  : mlirIntegerTypeSignedGet(context, 8);
1146  } else if (view.itemsize == 2) {
1147  // i16
1148  bulkLoadElementType = signless
1149  ? mlirIntegerTypeGet(context, 16)
1150  : mlirIntegerTypeSignedGet(context, 16);
1151  }
1152  } else if (isUnsignedIntegerFormat(format)) {
1153  if (view.itemsize == 4) {
1154  // unsigned i32
1155  bulkLoadElementType = signless
1156  ? mlirIntegerTypeGet(context, 32)
1157  : mlirIntegerTypeUnsignedGet(context, 32);
1158  } else if (view.itemsize == 8) {
1159  // unsigned i64
1160  bulkLoadElementType = signless
1161  ? mlirIntegerTypeGet(context, 64)
1162  : mlirIntegerTypeUnsignedGet(context, 64);
1163  } else if (view.itemsize == 1) {
1164  // i8
1165  bulkLoadElementType = signless
1166  ? mlirIntegerTypeGet(context, 8)
1167  : mlirIntegerTypeUnsignedGet(context, 8);
1168  } else if (view.itemsize == 2) {
1169  // i16
1170  bulkLoadElementType = signless
1171  ? mlirIntegerTypeGet(context, 16)
1172  : mlirIntegerTypeUnsignedGet(context, 16);
1173  }
1174  }
1175  if (!bulkLoadElementType) {
1176  throw std::invalid_argument(
1177  std::string("unimplemented array format conversion from format: ") +
1178  std::string(format));
1179  }
1180  }
1181 
1182  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1183  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1184  }
1185 
1186  // There is a complication for boolean numpy arrays, as numpy represents
1187  // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
1188  // booleans per byte.
1189  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1190  Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1191  MlirContext &context) {
1192  if (llvm::endianness::native != llvm::endianness::little) {
1193  // Given we have no good way of testing the behavior on big-endian
1194  // systems we will throw
1195  throw nb::type_error("Constructing a bit-packed MLIR attribute is "
1196  "unsupported on big-endian systems");
1197  }
1198  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
1199  /*data=*/static_cast<uint8_t *>(view.buf),
1200  /*shape=*/{static_cast<size_t>(view.len)});
1201 
1202  nb::module_ numpy = nb::module_::import_("numpy");
1203  nb::object packbitsFunc = numpy.attr("packbits");
1204  nb::object packedBooleans =
1205  packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
1206  nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
1207 
1208  MlirType bitpackedType =
1209  getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1210  assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1211  // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1212  // packedBooleans, hence the MlirAttribute will remain valid even when
1213  // packedBooleans get reclaimed by the end of the function.
1214  return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1215  pythonBuffer.ptr);
1216  }
1217 
1218  // This does the opposite transformation of
1219  // `getBitpackedAttributeFromBooleanBuffer`
1220  std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
1221  if (llvm::endianness::native != llvm::endianness::little) {
1222  // Given we have no good way of testing the behavior on big-endian
1223  // systems we will throw
1224  throw nb::type_error("Constructing a numpy array from a MLIR attribute "
1225  "is unsupported on big-endian systems");
1226  }
1227 
1228  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1229  int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1230  uint8_t *bitpackedData = static_cast<uint8_t *>(
1231  const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1232  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
1233  /*data=*/bitpackedData,
1234  /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
1235 
1236  nb::module_ numpy = nb::module_::import_("numpy");
1237  nb::object unpackbitsFunc = numpy.attr("unpackbits");
1238  nb::object equalFunc = numpy.attr("equal");
1239  nb::object reshapeFunc = numpy.attr("reshape");
1240  nb::object unpackedBooleans =
1241  unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
1242 
1243  // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1244  // We need to:
1245  // 1. Slice away the padded bits
1246  // 2. Make the boolean array have the correct shape
1247  // 3. Convert the array to a boolean array
1248  unpackedBooleans = unpackedBooleans[nb::slice(
1249  nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
1250  unpackedBooleans = equalFunc(unpackedBooleans, 1);
1251 
1252  MlirType shapedType = mlirAttributeGetType(*this);
1253  intptr_t rank = mlirShapedTypeGetRank(shapedType);
1254  std::vector<intptr_t> shape(rank);
1255  for (intptr_t i = 0; i < rank; ++i) {
1256  shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
1257  }
1258  unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1259 
1260  // Make sure the returned nb::buffer_view claims ownership of the data in
1261  // `pythonBuffer` so it remains valid when Python reads it
1262  nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
1263  return std::make_unique<nb_buffer_info>(pythonBuffer.request());
1264  }
1265 
1266  template <typename Type>
1267  std::unique_ptr<nb_buffer_info>
1268  bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
1269  intptr_t rank = mlirShapedTypeGetRank(shapedType);
1270  // Prepare the data for the buffer_info.
1271  // Buffer is configured for read-only access below.
1272  Type *data = static_cast<Type *>(
1273  const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1274  // Prepare the shape for the buffer_info.
1276  for (intptr_t i = 0; i < rank; ++i)
1277  shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1278  // Prepare the strides for the buffer_info.
1279  SmallVector<intptr_t, 4> strides;
1280  if (mlirDenseElementsAttrIsSplat(*this)) {
1281  // Splats are special, only the single value is stored.
1282  strides.assign(rank, 0);
1283  } else {
1284  for (intptr_t i = 1; i < rank; ++i) {
1285  intptr_t strideFactor = 1;
1286  for (intptr_t j = i; j < rank; ++j)
1287  strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1288  strides.push_back(sizeof(Type) * strideFactor);
1289  }
1290  strides.push_back(sizeof(Type));
1291  }
1292  const char *format;
1293  if (explicitFormat) {
1294  format = explicitFormat;
1295  } else {
1296  format = nb_format_descriptor<Type>::format();
1297  }
1298  return std::make_unique<nb_buffer_info>(
1299  data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
1300  /*readonly=*/true);
1301  }
1302 }; // namespace
1303 
/// Extra type slots for DenseElementsAttr, wiring up the buffer protocol
/// (`bf_getbuffer` / `bf_releasebuffer`). The list is terminated by the
/// `{0, nullptr}` sentinel entry, which must stay last.
PyType_Slot PyDenseElementsAttribute::slots[] = {
// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
// On 3.8 the slots are instead patched into tp_as_buffer by bindDerived.
#if PY_VERSION_HEX >= 0x03090000
    {Py_bf_getbuffer,
     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
    {Py_bf_releasebuffer,
     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
#endif
    {0, nullptr},
};
1314 
1315 /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
1316  Py_buffer *view,
1317  int flags) {
1318  view->obj = nullptr;
1319  std::unique_ptr<nb_buffer_info> info;
1320  try {
1321  auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
1322  info = attr->accessBuffer();
1323  } catch (nb::python_error &e) {
1324  e.restore();
1325  nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
1326  return -1;
1327  }
1328  view->obj = obj;
1329  view->ndim = 1;
1330  view->buf = info->ptr;
1331  view->itemsize = info->itemsize;
1332  view->len = info->itemsize;
1333  for (auto s : info->shape) {
1334  view->len *= s;
1335  }
1336  view->readonly = info->readonly;
1337  if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
1338  view->format = const_cast<char *>(info->format);
1339  }
1340  if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
1341  view->ndim = static_cast<int>(info->ndim);
1342  view->strides = info->strides.data();
1343  view->shape = info->shape.data();
1344  }
1345  view->suboffsets = nullptr;
1346  view->internal = info.release();
1347  Py_INCREF(obj);
1348  return 0;
1349 }
1350 
1351 /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
1352  Py_buffer *view) {
1353  delete reinterpret_cast<nb_buffer_info *>(view->internal);
1354 }
1355 
1356 /// Refinement of the PyDenseElementsAttribute for attributes containing
1357 /// integer (and boolean) values. Supports element access.
1358 class PyDenseIntElementsAttribute
1359  : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1360  PyDenseElementsAttribute> {
1361 public:
1362  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1363  static constexpr const char *pyClassName = "DenseIntElementsAttr";
1364  using PyConcreteAttribute::PyConcreteAttribute;
1365 
1366  /// Returns the element at the given linear position. Asserts if the index
1367  /// is out of range.
1368  nb::object dunderGetItem(intptr_t pos) {
1369  if (pos < 0 || pos >= dunderLen()) {
1370  throw nb::index_error("attempt to access out of bounds element");
1371  }
1372 
1373  MlirType type = mlirAttributeGetType(*this);
1374  type = mlirShapedTypeGetElementType(type);
1375  assert(mlirTypeIsAInteger(type) &&
1376  "expected integer element type in dense int elements attribute");
1377  // Dispatch element extraction to an appropriate C function based on the
1378  // elemental type of the attribute. nb::int_ is implicitly constructible
1379  // from any C++ integral type and handles bitwidth correctly.
1380  // TODO: consider caching the type properties in the constructor to avoid
1381  // querying them on each element access.
1382  unsigned width = mlirIntegerTypeGetWidth(type);
1383  bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1384  if (isUnsigned) {
1385  if (width == 1) {
1386  return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1387  }
1388  if (width == 8) {
1389  return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
1390  }
1391  if (width == 16) {
1392  return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
1393  }
1394  if (width == 32) {
1395  return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
1396  }
1397  if (width == 64) {
1398  return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
1399  }
1400  } else {
1401  if (width == 1) {
1402  return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1403  }
1404  if (width == 8) {
1405  return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
1406  }
1407  if (width == 16) {
1408  return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
1409  }
1410  if (width == 32) {
1411  return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
1412  }
1413  if (width == 64) {
1414  return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
1415  }
1416  }
1417  throw nb::type_error("Unsupported integer type");
1418  }
1419 
1420  static void bindDerived(ClassTy &c) {
1421  c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1422  }
1423 };
1424 
1425 class PyDenseResourceElementsAttribute
1426  : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1427 public:
1428  static constexpr IsAFunctionTy isaFunction =
1430  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1431  using PyConcreteAttribute::PyConcreteAttribute;
1432 
1433  static PyDenseResourceElementsAttribute
1434  getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
1435  std::optional<size_t> alignment, bool isMutable,
1436  DefaultingPyMlirContext contextWrapper) {
1437  if (!mlirTypeIsAShaped(type)) {
1438  throw std::invalid_argument(
1439  "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1440  }
1441 
1442  // Do not request any conversions as we must ensure to use caller
1443  // managed memory.
1444  int flags = PyBUF_STRIDES;
1445  std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1446  if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1447  throw nb::python_error();
1448  }
1449 
1450  // This scope releaser will only release if we haven't yet transferred
1451  // ownership.
1452  auto freeBuffer = llvm::make_scope_exit([&]() {
1453  if (view)
1454  PyBuffer_Release(view.get());
1455  });
1456 
1457  if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1458  throw std::invalid_argument("Contiguous buffer is required.");
1459  }
1460 
1461  // Infer alignment to be the stride of one element if not explicit.
1462  size_t inferredAlignment;
1463  if (alignment)
1464  inferredAlignment = *alignment;
1465  else
1466  inferredAlignment = view->strides[view->ndim - 1];
1467 
1468  // The userData is a Py_buffer* that the deleter owns.
1469  auto deleter = [](void *userData, const void *data, size_t size,
1470  size_t align) {
1471  Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1472  PyBuffer_Release(ownedView);
1473  delete ownedView;
1474  };
1475 
1476  size_t rawBufferSize = view->len;
1477  MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1478  type, toMlirStringRef(name), view->buf, rawBufferSize,
1479  inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1480  if (mlirAttributeIsNull(attr)) {
1481  throw std::invalid_argument(
1482  "DenseResourceElementsAttr could not be constructed from the given "
1483  "buffer. "
1484  "This may mean that the Python buffer layout does not match that "
1485  "MLIR expected layout and is a bug.");
1486  }
1487  view.release();
1488  return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1489  }
1490 
1491  static void bindDerived(ClassTy &c) {
1492  c.def_static(
1493  "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
1494  nb::arg("array"), nb::arg("name"), nb::arg("type"),
1495  nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1496  nb::arg("context").none() = nb::none(),
1498  }
1499 };
1500 
1501 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1502 public:
1503  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1504  static constexpr const char *pyClassName = "DictAttr";
1505  using PyConcreteAttribute::PyConcreteAttribute;
1506  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1508 
1509  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1510 
1511  bool dunderContains(const std::string &name) {
1512  return !mlirAttributeIsNull(
1514  }
1515 
1516  static void bindDerived(ClassTy &c) {
1517  c.def("__contains__", &PyDictAttribute::dunderContains);
1518  c.def("__len__", &PyDictAttribute::dunderLen);
1519  c.def_static(
1520  "get",
1521  [](nb::dict attributes, DefaultingPyMlirContext context) {
1522  SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1523  mlirNamedAttributes.reserve(attributes.size());
1524  for (std::pair<nb::handle, nb::handle> it : attributes) {
1525  auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
1526  auto name = nb::cast<std::string>(it.first);
1527  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1529  toMlirStringRef(name)),
1530  mlirAttr));
1531  }
1532  MlirAttribute attr =
1533  mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1534  mlirNamedAttributes.data());
1535  return PyDictAttribute(context->getRef(), attr);
1536  },
1537  nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
1538  "Gets an uniqued dict attribute");
1539  c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1540  MlirAttribute attr =
1542  if (mlirAttributeIsNull(attr))
1543  throw nb::key_error("attempt to access a non-existent attribute");
1544  return attr;
1545  });
1546  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1547  if (index < 0 || index >= self.dunderLen()) {
1548  throw nb::index_error("attempt to access out of bounds attribute");
1549  }
1550  MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1551  return PyNamedAttribute(
1552  namedAttr.attribute,
1553  std::string(mlirIdentifierStr(namedAttr.name).data));
1554  });
1555  }
1556 };
1557 
1558 /// Refinement of PyDenseElementsAttribute for attributes containing
1559 /// floating-point values. Supports element access.
1560 class PyDenseFPElementsAttribute
1561  : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1562  PyDenseElementsAttribute> {
1563 public:
1564  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1565  static constexpr const char *pyClassName = "DenseFPElementsAttr";
1566  using PyConcreteAttribute::PyConcreteAttribute;
1567 
1568  nb::float_ dunderGetItem(intptr_t pos) {
1569  if (pos < 0 || pos >= dunderLen()) {
1570  throw nb::index_error("attempt to access out of bounds element");
1571  }
1572 
1573  MlirType type = mlirAttributeGetType(*this);
1574  type = mlirShapedTypeGetElementType(type);
1575  // Dispatch element extraction to an appropriate C function based on the
1576  // elemental type of the attribute. nb::float_ is implicitly constructible
1577  // from float and double.
1578  // TODO: consider caching the type properties in the constructor to avoid
1579  // querying them on each element access.
1580  if (mlirTypeIsAF32(type)) {
1581  return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
1582  }
1583  if (mlirTypeIsAF64(type)) {
1584  return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
1585  }
1586  throw nb::type_error("Unsupported floating-point type");
1587  }
1588 
1589  static void bindDerived(ClassTy &c) {
1590  c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1591  }
1592 };
1593 
1594 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1595 public:
1596  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1597  static constexpr const char *pyClassName = "TypeAttr";
1598  using PyConcreteAttribute::PyConcreteAttribute;
1599  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1601 
1602  static void bindDerived(ClassTy &c) {
1603  c.def_static(
1604  "get",
1605  [](PyType value, DefaultingPyMlirContext context) {
1606  MlirAttribute attr = mlirTypeAttrGet(value.get());
1607  return PyTypeAttribute(context->getRef(), attr);
1608  },
1609  nb::arg("value"), nb::arg("context").none() = nb::none(),
1610  "Gets a uniqued Type attribute");
1611  c.def_prop_ro("value", [](PyTypeAttribute &self) {
1612  return mlirTypeAttrGetValue(self.get());
1613  });
1614  }
1615 };
1616 
1617 /// Unit Attribute subclass. Unit attributes don't have values.
1618 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1619 public:
1620  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1621  static constexpr const char *pyClassName = "UnitAttr";
1622  using PyConcreteAttribute::PyConcreteAttribute;
1623  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1625 
1626  static void bindDerived(ClassTy &c) {
1627  c.def_static(
1628  "get",
1629  [](DefaultingPyMlirContext context) {
1630  return PyUnitAttribute(context->getRef(),
1631  mlirUnitAttrGet(context->get()));
1632  },
1633  nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
1634  }
1635 };
1636 
1637 /// Strided layout attribute subclass.
1638 class PyStridedLayoutAttribute
1639  : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1640 public:
1641  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1642  static constexpr const char *pyClassName = "StridedLayoutAttr";
1643  using PyConcreteAttribute::PyConcreteAttribute;
1644  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1646 
1647  static void bindDerived(ClassTy &c) {
1648  c.def_static(
1649  "get",
1650  [](int64_t offset, const std::vector<int64_t> strides,
1652  MlirAttribute attr = mlirStridedLayoutAttrGet(
1653  ctx->get(), offset, strides.size(), strides.data());
1654  return PyStridedLayoutAttribute(ctx->getRef(), attr);
1655  },
1656  nb::arg("offset"), nb::arg("strides"),
1657  nb::arg("context").none() = nb::none(),
1658  "Gets a strided layout attribute.");
1659  c.def_static(
1660  "get_fully_dynamic",
1661  [](int64_t rank, DefaultingPyMlirContext ctx) {
1662  auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1663  std::vector<int64_t> strides(rank);
1664  std::fill(strides.begin(), strides.end(), dynamic);
1665  MlirAttribute attr = mlirStridedLayoutAttrGet(
1666  ctx->get(), dynamic, strides.size(), strides.data());
1667  return PyStridedLayoutAttribute(ctx->getRef(), attr);
1668  },
1669  nb::arg("rank"), nb::arg("context").none() = nb::none(),
1670  "Gets a strided layout attribute with dynamic offset and strides of "
1671  "a "
1672  "given rank.");
1673  c.def_prop_ro(
1674  "offset",
1675  [](PyStridedLayoutAttribute &self) {
1676  return mlirStridedLayoutAttrGetOffset(self);
1677  },
1678  "Returns the value of the float point attribute");
1679  c.def_prop_ro(
1680  "strides",
1681  [](PyStridedLayoutAttribute &self) {
1682  intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1683  std::vector<int64_t> strides(size);
1684  for (intptr_t i = 0; i < size; i++) {
1685  strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1686  }
1687  return strides;
1688  },
1689  "Returns the value of the float point attribute");
1690  }
1691 };
1692 
1693 nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1694  if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1695  return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
1696  if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1697  return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
1698  if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1699  return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
1700  if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1701  return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
1702  if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1703  return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
1704  if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1705  return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
1706  if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1707  return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
1708  std::string msg =
1709  std::string("Can't cast unknown element type DenseArrayAttr (") +
1710  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1711  throw nb::type_error(msg.c_str());
1712 }
1713 
1714 nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1715  if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1716  return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
1717  if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1718  return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
1719  std::string msg =
1720  std::string(
1721  "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1722  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1723  throw nb::type_error(msg.c_str());
1724 }
1725 
1726 nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1727  if (PyBoolAttribute::isaFunction(pyAttribute))
1728  return nb::cast(PyBoolAttribute(pyAttribute));
1729  if (PyIntegerAttribute::isaFunction(pyAttribute))
1730  return nb::cast(PyIntegerAttribute(pyAttribute));
1731  std::string msg =
1732  std::string("Can't cast unknown element type DenseArrayAttr (") +
1733  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1734  throw nb::type_error(msg.c_str());
1735 }
1736 
1737 nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1738  if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1739  return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
1740  if (PySymbolRefAttribute::isaFunction(pyAttribute))
1741  return nb::cast(PySymbolRefAttribute(pyAttribute));
1742  std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1743  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1744  ")";
1745  throw nb::type_error(msg.c_str());
1746 }
1747 
1748 } // namespace
1749 
1750 void mlir::python::populateIRAttributes(nb::module_ &m) {
1751  PyAffineMapAttribute::bind(m);
1752  PyDenseBoolArrayAttribute::bind(m);
1753  PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1754  PyDenseI8ArrayAttribute::bind(m);
1755  PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1756  PyDenseI16ArrayAttribute::bind(m);
1757  PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1758  PyDenseI32ArrayAttribute::bind(m);
1759  PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1760  PyDenseI64ArrayAttribute::bind(m);
1761  PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1762  PyDenseF32ArrayAttribute::bind(m);
1763  PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1764  PyDenseF64ArrayAttribute::bind(m);
1765  PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1766  PyGlobals::get().registerTypeCaster(
1768  nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
1769 
1770  PyArrayAttribute::bind(m);
1771  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1772  PyBoolAttribute::bind(m);
1773  PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
1774  PyDenseFPElementsAttribute::bind(m);
1775  PyDenseIntElementsAttribute::bind(m);
1776  PyGlobals::get().registerTypeCaster(
1778  nb::cast<nb::callable>(
1779  nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
1780  PyDenseResourceElementsAttribute::bind(m);
1781 
1782  PyDictAttribute::bind(m);
1783  PySymbolRefAttribute::bind(m);
1784  PyGlobals::get().registerTypeCaster(
1786  nb::cast<nb::callable>(
1787  nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
1788 
1789  PyFlatSymbolRefAttribute::bind(m);
1790  PyOpaqueAttribute::bind(m);
1791  PyFloatAttribute::bind(m);
1792  PyIntegerAttribute::bind(m);
1793  PyIntegerSetAttribute::bind(m);
1794  PyStringAttribute::bind(m);
1795  PyTypeAttribute::bind(m);
1796  PyGlobals::get().registerTypeCaster(
1798  nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
1799  PyUnitAttribute::bind(m);
1800 
1801  PyStridedLayoutAttribute::bind(m);
1802 }
static const char kDenseElementsAttrGetDocstring[]
static const char kDenseResourceElementsAttrGetFromBufferDocstring[]
static const char kDenseElementsAttrGetFromListDocstring[]
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:210
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
std::string str() const
Converts the diagnostic to a string.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:310
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:517
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:291
static PyMlirContext & resolve()
Definition: IRCore.cpp:794
ReferrentTy * get() const
Definition: NanobindUtils.h:55
MlirAffineMap get() const
Definition: IRModule.h:1206
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1003
MlirAttribute get() const
Definition: IRModule.h:1009
CRTP base classes for Python attributes that subclass Attribute and should be castable from it (i....
Definition: IRModule.h:1053
nanobind::class_< DerivedTy, BaseTy > ClassTy
Definition: IRModule.h:1058
MlirIntegerSet get() const
Definition: IRModule.h:1227
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:186
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1027
Wrapper around the generic MlirType.
Definition: IRModule.h:877
MlirType get() const
Definition: IRModule.h:883
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map)
Creates an affine map attribute wrapping the given map.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type)
Creates an opaque attribute in the given context associated with the dialect identified by its namesp...
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, double value)
Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a construction of a FloatAttr,...
MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr)
MLIR_CAPI_EXPORTED uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr)
Returns the affine map wrapped in the given affine map attribute.
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, const int64_t *strides)
MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void)
Returns the typeID of a String attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr)
Checks whether the given attribute is a unit attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr)
Checks whether the given attribute is a dense elements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr)
Checks whether the given attribute is an integer set attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAAffineMap(MlirAttribute attr)
Checks whether the given attribute is an affine map attribute.
MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol)
Creates a flat symbol reference attribute in the given context referencing a symbol identified by the...
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void)
Returns the typeID of a StridedLayout attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void)
Returns the typeID of an Integer attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseResourceElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr)
Returns the string reference to the root referenced symbol.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAType(MlirAttribute attr)
Checks whether the given attribute is a type attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void)
Returns the typeID of an DenseIntOrFPElements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAArray(MlirAttribute attr)
Checks whether the given attribute is an array attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAInteger(MlirAttribute attr)
Checks whether the given attribute is an integer attribute.
MLIR_CAPI_EXPORTED intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr)
Returns the number of attributes contained in a dictionary attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set)
Creates an integer set attribute wrapping the given set.
MLIR_CAPI_EXPORTED uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr)
Returns the value stored in the given bool attribute.
MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value)
Creates an integer attribute of the given type with the given integer value.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void)
Returns the typeID of an IntegerSet attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos)
Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense e...
MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, MlirNamedAttribute const *elements)
Creates a dictionary attribute containing the given list of elements in the provided context.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, size_t dataAlignment, bool dataIsMutable, void(*deleter)(void *userData, const void *data, size_t size, size_t align), void *userData)
Unlike the typed accessors below, constructs the attribute with a raw data buffer and no type/alignme...
MLIR_CAPI_EXPORTED bool mlirAttributeIsABool(MlirAttribute attr)
Checks whether the given attribute is a bool attribute.
MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos)
Get an element of a dense array.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, int64_t const *values)
MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void)
Returns the typeID of an AffineMap attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos)
Returns pos-th reference nested in the given symbol reference attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void)
Returns the typeID of an Array attribute.
MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, double const *values)
MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of signless type and f...
MLIR_CAPI_EXPORTED intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr)
Returns the number of references nested in the given symbol reference attribute.
MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr)
Returns the type stored in the given type attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrIsSplat(MlirAttribute attr)
Checks whether the given dense elements attribute contains a single replicated value (splat).
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, MlirAttribute const *elements)
Creates a dense elements attribute with the given Shaped type and elements in the same context as the...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr)
Checks whether the given attribute is a dense array attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value)
Creates a bool attribute in the given context with the given value.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void)
Returns the typeID of a Float attribute.
MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of signed type and fit...
MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos)
Returns pos-th element of the given dictionary attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void)
Returns the typeID of a Unit attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, float const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values)
Create a dense array attribute with the given elements.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAOpaque(MlirAttribute attr)
Checks whether the given attribute is an opaque attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos)
Returns pos-th element stored in the given array attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name)
Returns the dictionary attribute element with the given name or NULL if the given name does not exist...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void)
Returns the typeID of an SymbolRef attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr)
Returns the single replicated value (splat) of a specific type contained by the given dense elements ...
MLIR_CAPI_EXPORTED bool mlirAttributeIsASymbolRef(MlirAttribute attr)
Checks whether the given attribute is a symbol reference attribute.
MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAString(MlirAttribute attr)
Checks whether the given attribute is a string attribute.
MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr)
Gets the total number of elements in the given elements attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr)
Returns the namespace of the dialect with which the given opaque attribute is associated.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADictionary(MlirAttribute attr)
Checks whether the given attribute is a dictionary attribute.
MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr)
Returns the attribute values as a string reference.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, int16_t const *values)
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void)
Returns the typeID of an Opaque attribute.
MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr)
Returns the value stored in the given floating point attribute, interpreting the value as double.
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of unsigned type and f...
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, int8_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx)
Creates a unit attribute in the given context.
MLIR_CAPI_EXPORTED double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, double value)
Creates a floating point attribute in the given context with the given double value and double-precis...
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements)
Creates an array attribute containing the given list of elements in the given context.
MLIR_CAPI_EXPORTED const void * mlirDenseElementsAttrGetRawData(MlirAttribute attr)
Returns the raw data of the given dense elements attribute.
MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr)
Get the size of a dense array.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFloat(MlirAttribute attr)
Checks whether the given attribute is a floating point attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, intptr_t numReferences, MlirAttribute const *references)
Creates a symbol reference attribute in the given context referencing a symbol identified by the give...
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void)
Returns the typeID of a Type attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element)
Creates a dense elements attribute with the given Shaped type containing a single replicated element ...
MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void)
Returns the typeID of a Dictionary attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type)
Creates a type attribute wrapping the given type in the same context as the type.
MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr)
Returns the number of elements stored in the given array attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, size_t rawBufferSize, const void *rawBuffer)
Creates a dense elements attribute with the given Shaped type and elements populated from a packed,...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str)
Creates a string attribute of the given type containing the given string.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr)
Checks whether the given attribute is a flat symbol reference attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr)
Returns the referenced symbol as a string reference.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1043
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1128
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1149
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1101
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1066
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1078
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1097
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1137
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
void populateIRAttributes(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1294
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:426
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:436
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.