MLIR  21.0.0git
IRAttributes.cpp
Go to the documentation of this file.
1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cstdint>
10 #include <optional>
11 #include <string>
12 #include <string_view>
13 #include <utility>
14 
15 #include "IRModule.h"
16 #include "NanobindUtils.h"
18 #include "mlir-c/BuiltinTypes.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 namespace nb = nanobind;
25 using namespace nanobind::literals;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
// Python docstring text for building a DenseElementsAttr from a Python buffer
// or array. It documents the `array`, `signless`, `type`, `shape` and
// `context` arguments; the binding that attaches it is outside this excerpt.
// NOTE(review): comments cannot go inside the raw string below — any text
// there becomes part of the user-visible docstring.
35 static const char kDenseElementsAttrGetDocstring[] =
36  R"(Gets a DenseElementsAttr from a Python buffer or array.
37 
38 When `type` is not provided, then some limited type inferencing is done based
39 on the buffer format. Support presently exists for 8/16/32/64 signed and
40 unsigned integers and float16/float32/float64. DenseElementsAttrs of these
41 types can also be converted back to a corresponding buffer.
42 
43 For conversions outside of these types, a `type=` must be explicitly provided
44 and the buffer contents must be bit-castable to the MLIR internal
45 representation:
46 
47  * Integer types (except for i1): the buffer must be byte aligned to the
48  next byte boundary.
49  * Floating point types: Must be bit-castable to the given floating point
50  size.
51  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
52  row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
53  this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
54 
55 If a single element buffer is passed (or for i1, a single byte with value 0
56 or 255), then a splat will be created.
57 
58 Args:
59  array: The array or buffer to convert.
60  signless: If inferring an appropriate MLIR type, use signless types for
61  integers (defaults True).
62  type: Skips inference of the MLIR element type and uses this instead. The
63  storage size must be consistent with the actual contents of the buffer.
64  shape: Overrides the shape of the buffer when constructing the MLIR
65  shaped type. This is needed when the physical and logical shape differ (as
66  for i1).
67  context: Explicit context, if not from context manager.
68 
69 Returns:
70  DenseElementsAttr on success.
71 
72 Raises:
73  ValueError: If the type of the buffer or array cannot be matched to an MLIR
74  type or if the buffer does not meet expectations.
75 )";
76 
// Python docstring text for building a DenseElementsAttr from a Python list of
// attributes (`attrs`, `type`, `context` arguments).
// NOTE(review): the declaration line of this raw string (original line 77) is
// missing from this copy, leaving a dangling literal. Upstream presumably
// declares it as `static const char kDenseElementsAttrGetFromListDocstring[] =`
// — confirm before restoring.
// NOTE(review): the "Raises" section refers to `shaped_type`, but the argument
// is documented above as `type`.
78  R"(Gets a DenseElementsAttr from a Python list of attributes.
79 
80 Note that it can be expensive to construct attributes individually.
81 For a large number of elements, consider using a Python buffer or array instead.
82 
83 Args:
84  attrs: A list of attributes.
85  type: The desired shape and type of the resulting DenseElementsAttr.
86  If not provided, the element type is determined based on the type
87  of the 0th attribute and the shape is `[len(attrs)]`.
88  context: Explicit context, if not from context manager.
89 
90 Returns:
91  DenseElementsAttr on success.
92 
93 Raises:
94  ValueError: If the type of the attributes does not match the type
95  specified by `shaped_type`.
96 )";
97 
// Python docstring text for building a DenseResourceElementsAttr from a Python
// buffer (`buffer`, `name`, `type`, `context` arguments).
// NOTE(review): the declaration line of this raw string (original line 98) is
// missing from this copy, leaving a dangling literal. Upstream presumably
// declares it as
// `static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =`
// — confirm before restoring.
99  R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
100 
101 This function does minimal validation or massaging of the data, and it is
102 up to the caller to ensure that the buffer meets the characteristics
103 implied by the shape.
104 
105 The backing buffer and any user objects will be retained for the lifetime
106 of the resource blob. This is typically bounded to the context but the
107 resource can have a shorter lifespan depending on how it is used in
108 subsequent processing.
109 
110 Args:
111  buffer: The array or buffer to convert.
112  name: Name to provide to the resource (may be changed upon collision).
113  type: The explicit ShapedType to construct the attribute with.
114  context: Explicit context, if not from context manager.
115 
116 Returns:
117  DenseResourceElementsAttr on success.
118 
119 Raises:
120  ValueError: If the type of the buffer or array cannot be matched to an MLIR
121  type or if the buffer does not meet expectations.
122 )";
123 
124 namespace {
125 
126 struct nb_buffer_info {
127  void *ptr = nullptr;
128  ssize_t itemsize = 0;
129  ssize_t size = 0;
130  const char *format = nullptr;
131  ssize_t ndim = 0;
133  SmallVector<ssize_t, 4> strides;
134  bool readonly = false;
135 
136  nb_buffer_info(
137  void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
139  bool readonly = false,
140  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
141  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
142  : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
143  shape(std::move(shape_in)), strides(std::move(strides_in)),
144  readonly(readonly), owned_view(std::move(owned_view_in)) {
145  size = 1;
146  for (ssize_t i = 0; i < ndim; ++i) {
147  size *= shape[i];
148  }
149  }
150 
151  explicit nb_buffer_info(Py_buffer *view)
152  : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
153  {view->shape, view->shape + view->ndim},
154  // TODO(phawkins): check for null strides
155  {view->strides, view->strides + view->ndim},
156  view->readonly != 0,
157  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
158  view, PyBuffer_Release)) {}
159 
160  nb_buffer_info(const nb_buffer_info &) = delete;
161  nb_buffer_info(nb_buffer_info &&) = default;
162  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
163  nb_buffer_info &operator=(nb_buffer_info &&) = default;
164 
165 private:
166  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
167 };
168 
169 class nb_buffer : public nb::object {
170  NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
171 
172  nb_buffer_info request() const {
173  int flags = PyBUF_STRIDES | PyBUF_FORMAT;
174  auto *view = new Py_buffer();
175  if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
176  delete view;
177  throw nb::python_error();
178  }
179  return nb_buffer_info(view);
180  }
181 };
182 
/// Maps a C++ scalar type to its Python `struct`-module format code (the code
/// that appears in `Py_buffer::format`). Only the specializations below define
/// `format()`; any other type has no mapping.
template <typename T>
struct nb_format_descriptor {};

/// Shared implementation for the specializations: exposes the single format
/// character `Code` as a NUL-terminated string with static storage duration.
template <char Code>
struct nb_format_code {
  static const char *format() {
    static constexpr char kCode[] = {Code, '\0'};
    return kCode;
  }
};

template <>
struct nb_format_descriptor<bool> : nb_format_code<'?'> {};
template <>
struct nb_format_descriptor<int8_t> : nb_format_code<'b'> {};
template <>
struct nb_format_descriptor<uint8_t> : nb_format_code<'B'> {};
template <>
struct nb_format_descriptor<int16_t> : nb_format_code<'h'> {};
template <>
struct nb_format_descriptor<uint16_t> : nb_format_code<'H'> {};
template <>
struct nb_format_descriptor<int32_t> : nb_format_code<'i'> {};
template <>
struct nb_format_descriptor<uint32_t> : nb_format_code<'I'> {};
template <>
struct nb_format_descriptor<int64_t> : nb_format_code<'q'> {};
template <>
struct nb_format_descriptor<uint64_t> : nb_format_code<'Q'> {};
template <>
struct nb_format_descriptor<float> : nb_format_code<'f'> {};
template <>
struct nb_format_descriptor<double> : nb_format_code<'d'> {};
230 
231 static MlirStringRef toMlirStringRef(const std::string &s) {
232  return mlirStringRefCreate(s.data(), s.size());
233 }
234 
235 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
236  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
237 }
238 
239 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
240 public:
241  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
242  static constexpr const char *pyClassName = "AffineMapAttr";
243  using PyConcreteAttribute::PyConcreteAttribute;
244  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
246 
247  static void bindDerived(ClassTy &c) {
248  c.def_static(
249  "get",
250  [](PyAffineMap &affineMap) {
251  MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
252  return PyAffineMapAttribute(affineMap.getContext(), attr);
253  },
254  nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255  c.def_prop_ro("value", mlirAffineMapAttrGetValue,
256  "Returns the value of the AffineMap attribute");
257  }
258 };
259 
260 class PyIntegerSetAttribute
261  : public PyConcreteAttribute<PyIntegerSetAttribute> {
262 public:
263  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
264  static constexpr const char *pyClassName = "IntegerSetAttr";
265  using PyConcreteAttribute::PyConcreteAttribute;
266  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
268 
269  static void bindDerived(ClassTy &c) {
270  c.def_static(
271  "get",
272  [](PyIntegerSet &integerSet) {
273  MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
274  return PyIntegerSetAttribute(integerSet.getContext(), attr);
275  },
276  nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
277  }
278 };
279 
280 template <typename T>
281 static T pyTryCast(nb::handle object) {
282  try {
283  return nb::cast<T>(object);
284  } catch (nb::cast_error &err) {
285  std::string msg = std::string("Invalid attribute when attempting to "
286  "create an ArrayAttribute (") +
287  err.what() + ")";
288  throw std::runtime_error(msg.c_str());
289  } catch (std::runtime_error &err) {
290  std::string msg = std::string("Invalid attribute (None?) when attempting "
291  "to create an ArrayAttribute (") +
292  err.what() + ")";
293  throw std::runtime_error(msg.c_str());
294  }
295 }
296 
297 /// A python-wrapped dense array attribute with an element type and a derived
298 /// implementation class.
299 template <typename EltTy, typename DerivedT>
300 class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
301 public:
303 
304  /// Iterator over the integer elements of a dense array.
305  class PyDenseArrayIterator {
306  public:
307  PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
308 
309  /// Return a copy of the iterator.
310  PyDenseArrayIterator dunderIter() { return *this; }
311 
312  /// Return the next element.
313  EltTy dunderNext() {
314  // Throw if the index has reached the end.
316  throw nb::stop_iteration();
317  return DerivedT::getElement(attr.get(), nextIndex++);
318  }
319 
320  /// Bind the iterator class.
321  static void bind(nb::module_ &m) {
322  nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
323  .def("__iter__", &PyDenseArrayIterator::dunderIter)
324  .def("__next__", &PyDenseArrayIterator::dunderNext);
325  }
326 
327  private:
328  /// The referenced dense array attribute.
329  PyAttribute attr;
330  /// The next index to read.
331  int nextIndex = 0;
332  };
333 
334  /// Get the element at the given index.
335  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
336 
337  /// Bind the attribute class.
338  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
339  // Bind the constructor.
340  if constexpr (std::is_same_v<EltTy, bool>) {
341  c.def_static(
342  "get",
343  [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
344  std::vector<bool> values;
345  for (nb::handle py_value : py_values) {
346  int is_true = PyObject_IsTrue(py_value.ptr());
347  if (is_true < 0) {
348  throw nb::python_error();
349  }
350  values.push_back(is_true);
351  }
352  return getAttribute(values, ctx->getRef());
353  },
354  nb::arg("values"), nb::arg("context").none() = nb::none(),
355  "Gets a uniqued dense array attribute");
356  } else {
357  c.def_static(
358  "get",
359  [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
360  return getAttribute(values, ctx->getRef());
361  },
362  nb::arg("values"), nb::arg("context").none() = nb::none(),
363  "Gets a uniqued dense array attribute");
364  }
365  // Bind the array methods.
366  c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
367  if (i >= mlirDenseArrayGetNumElements(arr))
368  throw nb::index_error("DenseArray index out of range");
369  return arr.getItem(i);
370  });
371  c.def("__len__", [](const DerivedT &arr) {
372  return mlirDenseArrayGetNumElements(arr);
373  });
374  c.def("__iter__",
375  [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
376  c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
377  std::vector<EltTy> values;
378  intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
379  values.reserve(numOldElements + nb::len(extras));
380  for (intptr_t i = 0; i < numOldElements; ++i)
381  values.push_back(arr.getItem(i));
382  for (nb::handle attr : extras)
383  values.push_back(pyTryCast<EltTy>(attr));
384  return getAttribute(values, arr.getContext());
385  });
386  }
387 
388 private:
389  static DerivedT getAttribute(const std::vector<EltTy> &values,
390  PyMlirContextRef ctx) {
391  if constexpr (std::is_same_v<EltTy, bool>) {
392  std::vector<int> intValues(values.begin(), values.end());
393  MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
394  intValues.data());
395  return DerivedT(ctx, attr);
396  } else {
397  MlirAttribute attr =
398  DerivedT::getAttribute(ctx->get(), values.size(), values.data());
399  return DerivedT(ctx, attr);
400  }
401  }
402 };
403 
404 /// Instantiate the python dense array classes.
405 struct PyDenseBoolArrayAttribute
406  : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
407  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
408  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
409  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
410  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
411  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
412  using PyDenseArrayAttribute::PyDenseArrayAttribute;
413 };
414 struct PyDenseI8ArrayAttribute
415  : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
416  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
417  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
418  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
419  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
420  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
421  using PyDenseArrayAttribute::PyDenseArrayAttribute;
422 };
423 struct PyDenseI16ArrayAttribute
424  : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
425  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
426  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
427  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
428  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
429  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
430  using PyDenseArrayAttribute::PyDenseArrayAttribute;
431 };
432 struct PyDenseI32ArrayAttribute
433  : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
434  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
435  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
436  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
437  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
438  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
439  using PyDenseArrayAttribute::PyDenseArrayAttribute;
440 };
441 struct PyDenseI64ArrayAttribute
442  : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
443  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
444  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
445  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
446  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
447  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
448  using PyDenseArrayAttribute::PyDenseArrayAttribute;
449 };
450 struct PyDenseF32ArrayAttribute
451  : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
452  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
453  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
454  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
455  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
456  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
457  using PyDenseArrayAttribute::PyDenseArrayAttribute;
458 };
459 struct PyDenseF64ArrayAttribute
460  : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
461  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
462  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
463  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
464  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
465  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
466  using PyDenseArrayAttribute::PyDenseArrayAttribute;
467 };
468 
469 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
470 public:
471  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
472  static constexpr const char *pyClassName = "ArrayAttr";
473  using PyConcreteAttribute::PyConcreteAttribute;
474  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
476 
477  class PyArrayAttributeIterator {
478  public:
479  PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
480 
481  PyArrayAttributeIterator &dunderIter() { return *this; }
482 
483  MlirAttribute dunderNext() {
484  // TODO: Throw is an inefficient way to stop iteration.
486  throw nb::stop_iteration();
487  return mlirArrayAttrGetElement(attr.get(), nextIndex++);
488  }
489 
490  static void bind(nb::module_ &m) {
491  nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
492  .def("__iter__", &PyArrayAttributeIterator::dunderIter)
493  .def("__next__", &PyArrayAttributeIterator::dunderNext);
494  }
495 
496  private:
497  PyAttribute attr;
498  int nextIndex = 0;
499  };
500 
501  MlirAttribute getItem(intptr_t i) {
502  return mlirArrayAttrGetElement(*this, i);
503  }
504 
505  static void bindDerived(ClassTy &c) {
506  c.def_static(
507  "get",
508  [](nb::list attributes, DefaultingPyMlirContext context) {
509  SmallVector<MlirAttribute> mlirAttributes;
510  mlirAttributes.reserve(nb::len(attributes));
511  for (auto attribute : attributes) {
512  mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
513  }
514  MlirAttribute attr = mlirArrayAttrGet(
515  context->get(), mlirAttributes.size(), mlirAttributes.data());
516  return PyArrayAttribute(context->getRef(), attr);
517  },
518  nb::arg("attributes"), nb::arg("context").none() = nb::none(),
519  "Gets a uniqued Array attribute");
520  c.def("__getitem__",
521  [](PyArrayAttribute &arr, intptr_t i) {
522  if (i >= mlirArrayAttrGetNumElements(arr))
523  throw nb::index_error("ArrayAttribute index out of range");
524  return arr.getItem(i);
525  })
526  .def("__len__",
527  [](const PyArrayAttribute &arr) {
528  return mlirArrayAttrGetNumElements(arr);
529  })
530  .def("__iter__", [](const PyArrayAttribute &arr) {
531  return PyArrayAttributeIterator(arr);
532  });
533  c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
534  std::vector<MlirAttribute> attributes;
535  intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
536  attributes.reserve(numOldElements + nb::len(extras));
537  for (intptr_t i = 0; i < numOldElements; ++i)
538  attributes.push_back(arr.getItem(i));
539  for (nb::handle attr : extras)
540  attributes.push_back(pyTryCast<PyAttribute>(attr));
541  MlirAttribute arrayAttr = mlirArrayAttrGet(
542  arr.getContext()->get(), attributes.size(), attributes.data());
543  return PyArrayAttribute(arr.getContext(), arrayAttr);
544  });
545  }
546 };
547 
548 /// Float Point Attribute subclass - FloatAttr.
549 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
550 public:
551  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
552  static constexpr const char *pyClassName = "FloatAttr";
553  using PyConcreteAttribute::PyConcreteAttribute;
554  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
556 
557  static void bindDerived(ClassTy &c) {
558  c.def_static(
559  "get",
560  [](PyType &type, double value, DefaultingPyLocation loc) {
561  PyMlirContext::ErrorCapture errors(loc->getContext());
562  MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
563  if (mlirAttributeIsNull(attr))
564  throw MLIRError("Invalid attribute", errors.take());
565  return PyFloatAttribute(type.getContext(), attr);
566  },
567  nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
568  "Gets an uniqued float point attribute associated to a type");
569  c.def_static(
570  "get_f32",
571  [](double value, DefaultingPyMlirContext context) {
572  MlirAttribute attr = mlirFloatAttrDoubleGet(
573  context->get(), mlirF32TypeGet(context->get()), value);
574  return PyFloatAttribute(context->getRef(), attr);
575  },
576  nb::arg("value"), nb::arg("context").none() = nb::none(),
577  "Gets an uniqued float point attribute associated to a f32 type");
578  c.def_static(
579  "get_f64",
580  [](double value, DefaultingPyMlirContext context) {
581  MlirAttribute attr = mlirFloatAttrDoubleGet(
582  context->get(), mlirF64TypeGet(context->get()), value);
583  return PyFloatAttribute(context->getRef(), attr);
584  },
585  nb::arg("value"), nb::arg("context").none() = nb::none(),
586  "Gets an uniqued float point attribute associated to a f64 type");
587  c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
588  "Returns the value of the float attribute");
589  c.def("__float__", mlirFloatAttrGetValueDouble,
590  "Converts the value of the float attribute to a Python float");
591  }
592 };
593 
594 /// Integer Attribute subclass - IntegerAttr.
595 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
596 public:
597  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
598  static constexpr const char *pyClassName = "IntegerAttr";
599  using PyConcreteAttribute::PyConcreteAttribute;
600 
601  static void bindDerived(ClassTy &c) {
602  c.def_static(
603  "get",
604  [](PyType &type, int64_t value) {
605  MlirAttribute attr = mlirIntegerAttrGet(type, value);
606  return PyIntegerAttribute(type.getContext(), attr);
607  },
608  nb::arg("type"), nb::arg("value"),
609  "Gets an uniqued integer attribute associated to a type");
610  c.def_prop_ro("value", toPyInt,
611  "Returns the value of the integer attribute");
612  c.def("__int__", toPyInt,
613  "Converts the value of the integer attribute to a Python int");
614  c.def_prop_ro_static("static_typeid",
615  [](nb::object & /*class*/) -> MlirTypeID {
616  return mlirIntegerAttrGetTypeID();
617  });
618  }
619 
620 private:
621  static int64_t toPyInt(PyIntegerAttribute &self) {
622  MlirType type = mlirAttributeGetType(self);
623  if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
624  return mlirIntegerAttrGetValueInt(self);
625  if (mlirIntegerTypeIsSigned(type))
626  return mlirIntegerAttrGetValueSInt(self);
627  return mlirIntegerAttrGetValueUInt(self);
628  }
629 };
630 
631 /// Bool Attribute subclass - BoolAttr.
632 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
633 public:
634  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
635  static constexpr const char *pyClassName = "BoolAttr";
636  using PyConcreteAttribute::PyConcreteAttribute;
637 
638  static void bindDerived(ClassTy &c) {
639  c.def_static(
640  "get",
641  [](bool value, DefaultingPyMlirContext context) {
642  MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
643  return PyBoolAttribute(context->getRef(), attr);
644  },
645  nb::arg("value"), nb::arg("context").none() = nb::none(),
646  "Gets an uniqued bool attribute");
647  c.def_prop_ro("value", mlirBoolAttrGetValue,
648  "Returns the value of the bool attribute");
649  c.def("__bool__", mlirBoolAttrGetValue,
650  "Converts the value of the bool attribute to a Python bool");
651  }
652 };
653 
654 class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
655 public:
656  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
657  static constexpr const char *pyClassName = "SymbolRefAttr";
658  using PyConcreteAttribute::PyConcreteAttribute;
659 
660  static MlirAttribute fromList(const std::vector<std::string> &symbols,
661  PyMlirContext &context) {
662  if (symbols.empty())
663  throw std::runtime_error("SymbolRefAttr must be composed of at least "
664  "one symbol.");
665  MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
666  SmallVector<MlirAttribute, 3> referenceAttrs;
667  for (size_t i = 1; i < symbols.size(); ++i) {
668  referenceAttrs.push_back(
669  mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
670  }
671  return mlirSymbolRefAttrGet(context.get(), rootSymbol,
672  referenceAttrs.size(), referenceAttrs.data());
673  }
674 
675  static void bindDerived(ClassTy &c) {
676  c.def_static(
677  "get",
678  [](const std::vector<std::string> &symbols,
679  DefaultingPyMlirContext context) {
680  return PySymbolRefAttribute::fromList(symbols, context.resolve());
681  },
682  nb::arg("symbols"), nb::arg("context").none() = nb::none(),
683  "Gets a uniqued SymbolRef attribute from a list of symbol names");
684  c.def_prop_ro(
685  "value",
686  [](PySymbolRefAttribute &self) {
687  std::vector<std::string> symbols = {
689  for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
690  ++i)
691  symbols.push_back(
694  .str());
695  return symbols;
696  },
697  "Returns the value of the SymbolRef attribute as a list[str]");
698  }
699 };
700 
701 class PyFlatSymbolRefAttribute
702  : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
703 public:
704  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
705  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
706  using PyConcreteAttribute::PyConcreteAttribute;
707 
708  static void bindDerived(ClassTy &c) {
709  c.def_static(
710  "get",
711  [](std::string value, DefaultingPyMlirContext context) {
712  MlirAttribute attr =
713  mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
714  return PyFlatSymbolRefAttribute(context->getRef(), attr);
715  },
716  nb::arg("value"), nb::arg("context").none() = nb::none(),
717  "Gets a uniqued FlatSymbolRef attribute");
718  c.def_prop_ro(
719  "value",
720  [](PyFlatSymbolRefAttribute &self) {
722  return nb::str(stringRef.data, stringRef.length);
723  },
724  "Returns the value of the FlatSymbolRef attribute as a string");
725  }
726 };
727 
728 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
729 public:
730  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
731  static constexpr const char *pyClassName = "OpaqueAttr";
732  using PyConcreteAttribute::PyConcreteAttribute;
733  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
735 
736  static void bindDerived(ClassTy &c) {
737  c.def_static(
738  "get",
739  [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
740  DefaultingPyMlirContext context) {
741  const nb_buffer_info bufferInfo = buffer.request();
742  intptr_t bufferSize = bufferInfo.size;
743  MlirAttribute attr = mlirOpaqueAttrGet(
744  context->get(), toMlirStringRef(dialectNamespace), bufferSize,
745  static_cast<char *>(bufferInfo.ptr), type);
746  return PyOpaqueAttribute(context->getRef(), attr);
747  },
748  nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
749  nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
750  c.def_prop_ro(
751  "dialect_namespace",
752  [](PyOpaqueAttribute &self) {
754  return nb::str(stringRef.data, stringRef.length);
755  },
756  "Returns the dialect namespace for the Opaque attribute as a string");
757  c.def_prop_ro(
758  "data",
759  [](PyOpaqueAttribute &self) {
760  MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
761  return nb::bytes(stringRef.data, stringRef.length);
762  },
763  "Returns the data for the Opaqued attributes as `bytes`");
764  }
765 };
766 
767 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
768 public:
769  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770  static constexpr const char *pyClassName = "StringAttr";
771  using PyConcreteAttribute::PyConcreteAttribute;
772  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
774 
775  static void bindDerived(ClassTy &c) {
776  c.def_static(
777  "get",
778  [](std::string value, DefaultingPyMlirContext context) {
779  MlirAttribute attr =
780  mlirStringAttrGet(context->get(), toMlirStringRef(value));
781  return PyStringAttribute(context->getRef(), attr);
782  },
783  nb::arg("value"), nb::arg("context").none() = nb::none(),
784  "Gets a uniqued string attribute");
785  c.def_static(
786  "get",
787  [](nb::bytes value, DefaultingPyMlirContext context) {
788  MlirAttribute attr =
789  mlirStringAttrGet(context->get(), toMlirStringRef(value));
790  return PyStringAttribute(context->getRef(), attr);
791  },
792  nb::arg("value"), nb::arg("context").none() = nb::none(),
793  "Gets a uniqued string attribute");
794  c.def_static(
795  "get_typed",
796  [](PyType &type, std::string value) {
797  MlirAttribute attr =
799  return PyStringAttribute(type.getContext(), attr);
800  },
801  nb::arg("type"), nb::arg("value"),
802  "Gets a uniqued string attribute associated to a type");
803  c.def_prop_ro(
804  "value",
805  [](PyStringAttribute &self) {
806  MlirStringRef stringRef = mlirStringAttrGetValue(self);
807  return nb::str(stringRef.data, stringRef.length);
808  },
809  "Returns the value of the string attribute");
810  c.def_prop_ro(
811  "value_bytes",
812  [](PyStringAttribute &self) {
813  MlirStringRef stringRef = mlirStringAttrGetValue(self);
814  return nb::bytes(stringRef.data, stringRef.length);
815  },
816  "Returns the value of the string attribute as `bytes`");
817  }
818 };
819 
820 // TODO: Support construction of string elements.
821 class PyDenseElementsAttribute
822  : public PyConcreteAttribute<PyDenseElementsAttribute> {
823 public:
824  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
825  static constexpr const char *pyClassName = "DenseElementsAttr";
826  using PyConcreteAttribute::PyConcreteAttribute;
827 
828  static PyDenseElementsAttribute
829  getFromList(nb::list attributes, std::optional<PyType> explicitType,
830  DefaultingPyMlirContext contextWrapper) {
831  const size_t numAttributes = nb::len(attributes);
832  if (numAttributes == 0)
833  throw nb::value_error("Attributes list must be non-empty.");
834 
835  MlirType shapedType;
836  if (explicitType) {
837  if ((!mlirTypeIsAShaped(*explicitType) ||
838  !mlirShapedTypeHasStaticShape(*explicitType))) {
839 
840  std::string message;
841  llvm::raw_string_ostream os(message);
842  os << "Expected a static ShapedType for the shaped_type parameter: "
843  << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
844  throw nb::value_error(message.c_str());
845  }
846  shapedType = *explicitType;
847  } else {
848  SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
849  shapedType = mlirRankedTensorTypeGet(
850  shape.size(), shape.data(),
851  mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
853  }
854 
855  SmallVector<MlirAttribute> mlirAttributes;
856  mlirAttributes.reserve(numAttributes);
857  for (const nb::handle &attribute : attributes) {
858  MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
859  MlirType attrType = mlirAttributeGetType(mlirAttribute);
860  mlirAttributes.push_back(mlirAttribute);
861 
862  if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
863  std::string message;
864  llvm::raw_string_ostream os(message);
865  os << "All attributes must be of the same type and match "
866  << "the type parameter: expected="
867  << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
868  << ", but got="
869  << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
870  throw nb::value_error(message.c_str());
871  }
872  }
873 
874  MlirAttribute elements = mlirDenseElementsAttrGet(
875  shapedType, mlirAttributes.size(), mlirAttributes.data());
876 
877  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
878  }
879 
880  static PyDenseElementsAttribute
881  getFromBuffer(nb_buffer array, bool signless,
882  std::optional<PyType> explicitType,
883  std::optional<std::vector<int64_t>> explicitShape,
884  DefaultingPyMlirContext contextWrapper) {
885  // Request a contiguous view. In exotic cases, this will cause a copy.
886  int flags = PyBUF_ND;
887  if (!explicitType) {
888  flags |= PyBUF_FORMAT;
889  }
890  Py_buffer view;
891  if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
892  throw nb::python_error();
893  }
894  auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
895 
896  MlirContext context = contextWrapper->get();
897  MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
898  explicitShape, context);
899  if (mlirAttributeIsNull(attr)) {
900  throw std::invalid_argument(
901  "DenseElementsAttr could not be constructed from the given buffer. "
902  "This may mean that the Python buffer layout does not match that "
903  "MLIR expected layout and is a bug.");
904  }
905  return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
906  }
907 
908  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
909  PyAttribute &elementAttr) {
910  auto contextWrapper =
911  PyMlirContext::forContext(mlirTypeGetContext(shapedType));
912  if (!mlirAttributeIsAInteger(elementAttr) &&
913  !mlirAttributeIsAFloat(elementAttr)) {
914  std::string message = "Illegal element type for DenseElementsAttr: ";
915  message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
916  throw nb::value_error(message.c_str());
917  }
918  if (!mlirTypeIsAShaped(shapedType) ||
919  !mlirShapedTypeHasStaticShape(shapedType)) {
920  std::string message =
921  "Expected a static ShapedType for the shaped_type parameter: ";
922  message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
923  throw nb::value_error(message.c_str());
924  }
925  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
926  MlirType attrType = mlirAttributeGetType(elementAttr);
927  if (!mlirTypeEqual(shapedElementType, attrType)) {
928  std::string message =
929  "Shaped element type and attribute type must be equal: shaped=";
930  message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
931  message.append(", element=");
932  message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
933  throw nb::value_error(message.c_str());
934  }
935 
936  MlirAttribute elements =
937  mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
938  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
939  }
940 
941  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
942 
943  std::unique_ptr<nb_buffer_info> accessBuffer() {
944  MlirType shapedType = mlirAttributeGetType(*this);
945  MlirType elementType = mlirShapedTypeGetElementType(shapedType);
946  std::string format;
947 
948  if (mlirTypeIsAF32(elementType)) {
949  // f32
950  return bufferInfo<float>(shapedType);
951  }
952  if (mlirTypeIsAF64(elementType)) {
953  // f64
954  return bufferInfo<double>(shapedType);
955  }
956  if (mlirTypeIsAF16(elementType)) {
957  // f16
958  return bufferInfo<uint16_t>(shapedType, "e");
959  }
960  if (mlirTypeIsAIndex(elementType)) {
961  // Same as IndexType::kInternalStorageBitWidth
962  return bufferInfo<int64_t>(shapedType);
963  }
964  if (mlirTypeIsAInteger(elementType) &&
965  mlirIntegerTypeGetWidth(elementType) == 32) {
966  if (mlirIntegerTypeIsSignless(elementType) ||
967  mlirIntegerTypeIsSigned(elementType)) {
968  // i32
969  return bufferInfo<int32_t>(shapedType);
970  }
971  if (mlirIntegerTypeIsUnsigned(elementType)) {
972  // unsigned i32
973  return bufferInfo<uint32_t>(shapedType);
974  }
975  } else if (mlirTypeIsAInteger(elementType) &&
976  mlirIntegerTypeGetWidth(elementType) == 64) {
977  if (mlirIntegerTypeIsSignless(elementType) ||
978  mlirIntegerTypeIsSigned(elementType)) {
979  // i64
980  return bufferInfo<int64_t>(shapedType);
981  }
982  if (mlirIntegerTypeIsUnsigned(elementType)) {
983  // unsigned i64
984  return bufferInfo<uint64_t>(shapedType);
985  }
986  } else if (mlirTypeIsAInteger(elementType) &&
987  mlirIntegerTypeGetWidth(elementType) == 8) {
988  if (mlirIntegerTypeIsSignless(elementType) ||
989  mlirIntegerTypeIsSigned(elementType)) {
990  // i8
991  return bufferInfo<int8_t>(shapedType);
992  }
993  if (mlirIntegerTypeIsUnsigned(elementType)) {
994  // unsigned i8
995  return bufferInfo<uint8_t>(shapedType);
996  }
997  } else if (mlirTypeIsAInteger(elementType) &&
998  mlirIntegerTypeGetWidth(elementType) == 16) {
999  if (mlirIntegerTypeIsSignless(elementType) ||
1000  mlirIntegerTypeIsSigned(elementType)) {
1001  // i16
1002  return bufferInfo<int16_t>(shapedType);
1003  }
1004  if (mlirIntegerTypeIsUnsigned(elementType)) {
1005  // unsigned i16
1006  return bufferInfo<uint16_t>(shapedType);
1007  }
1008  } else if (mlirTypeIsAInteger(elementType) &&
1009  mlirIntegerTypeGetWidth(elementType) == 1) {
1010  // i1 / bool
1011  // We can not send the buffer directly back to Python, because the i1
1012  // values are bitpacked within MLIR. We call numpy's unpackbits function
1013  // to convert the bytes.
1014  return getBooleanBufferFromBitpackedAttribute();
1015  }
1016 
1017  // TODO: Currently crashes the program.
1018  // Reported as https://github.com/pybind/pybind11/issues/3336
1019  throw std::invalid_argument(
1020  "unsupported data type for conversion to Python buffer");
1021  }
1022 
1023  static void bindDerived(ClassTy &c) {
1024 #if PY_VERSION_HEX < 0x03090000
1025  PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
1026  tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
1027  tp->tp_as_buffer->bf_releasebuffer =
1028  PyDenseElementsAttribute::bf_releasebuffer;
1029 #endif
1030  c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1031  .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1032  nb::arg("array"), nb::arg("signless") = true,
1033  nb::arg("type").none() = nb::none(),
1034  nb::arg("shape").none() = nb::none(),
1035  nb::arg("context").none() = nb::none(),
1037  .def_static("get", PyDenseElementsAttribute::getFromList,
1038  nb::arg("attrs"), nb::arg("type").none() = nb::none(),
1039  nb::arg("context").none() = nb::none(),
1041  .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1042  nb::arg("shaped_type"), nb::arg("element_attr"),
1043  "Gets a DenseElementsAttr where all values are the same")
1044  .def_prop_ro("is_splat",
1045  [](PyDenseElementsAttribute &self) -> bool {
1046  return mlirDenseElementsAttrIsSplat(self);
1047  })
1048  .def("get_splat_value", [](PyDenseElementsAttribute &self) {
1049  if (!mlirDenseElementsAttrIsSplat(self))
1050  throw nb::value_error(
1051  "get_splat_value called on a non-splat attribute");
1053  });
1054  }
1055 
1056  static PyType_Slot slots[];
1057 
1058 private:
1059  static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
1060  static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
1061 
1062  static bool isUnsignedIntegerFormat(std::string_view format) {
1063  if (format.empty())
1064  return false;
1065  char code = format[0];
1066  return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1067  code == 'Q';
1068  }
1069 
1070  static bool isSignedIntegerFormat(std::string_view format) {
1071  if (format.empty())
1072  return false;
1073  char code = format[0];
1074  return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1075  code == 'q';
1076  }
1077 
1078  static MlirType
1079  getShapedType(std::optional<MlirType> bulkLoadElementType,
1080  std::optional<std::vector<int64_t>> explicitShape,
1081  Py_buffer &view) {
1082  SmallVector<int64_t> shape;
1083  if (explicitShape) {
1084  shape.append(explicitShape->begin(), explicitShape->end());
1085  } else {
1086  shape.append(view.shape, view.shape + view.ndim);
1087  }
1088 
1089  if (mlirTypeIsAShaped(*bulkLoadElementType)) {
1090  if (explicitShape) {
1091  throw std::invalid_argument("Shape can only be specified explicitly "
1092  "when the type is not a shaped type.");
1093  }
1094  return *bulkLoadElementType;
1095  } else {
1096  MlirAttribute encodingAttr = mlirAttributeGetNull();
1097  return mlirRankedTensorTypeGet(shape.size(), shape.data(),
1098  *bulkLoadElementType, encodingAttr);
1099  }
1100  }
1101 
1102  static MlirAttribute getAttributeFromBuffer(
1103  Py_buffer &view, bool signless, std::optional<PyType> explicitType,
1104  std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
1105  // Detect format codes that are suitable for bulk loading. This includes
1106  // all byte aligned integer and floating point types up to 8 bytes.
1107  // Notably, this excludes exotics types which do not have a direct
1108  // representation in the buffer protocol (i.e. complex, etc).
1109  std::optional<MlirType> bulkLoadElementType;
1110  if (explicitType) {
1111  bulkLoadElementType = *explicitType;
1112  } else {
1113  std::string_view format(view.format);
1114  if (format == "f") {
1115  // f32
1116  assert(view.itemsize == 4 && "mismatched array itemsize");
1117  bulkLoadElementType = mlirF32TypeGet(context);
1118  } else if (format == "d") {
1119  // f64
1120  assert(view.itemsize == 8 && "mismatched array itemsize");
1121  bulkLoadElementType = mlirF64TypeGet(context);
1122  } else if (format == "e") {
1123  // f16
1124  assert(view.itemsize == 2 && "mismatched array itemsize");
1125  bulkLoadElementType = mlirF16TypeGet(context);
1126  } else if (format == "?") {
1127  // i1
1128  // The i1 type needs to be bit-packed, so we will handle it seperately
1129  return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
1130  context);
1131  } else if (isSignedIntegerFormat(format)) {
1132  if (view.itemsize == 4) {
1133  // i32
1134  bulkLoadElementType = signless
1135  ? mlirIntegerTypeGet(context, 32)
1136  : mlirIntegerTypeSignedGet(context, 32);
1137  } else if (view.itemsize == 8) {
1138  // i64
1139  bulkLoadElementType = signless
1140  ? mlirIntegerTypeGet(context, 64)
1141  : mlirIntegerTypeSignedGet(context, 64);
1142  } else if (view.itemsize == 1) {
1143  // i8
1144  bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1145  : mlirIntegerTypeSignedGet(context, 8);
1146  } else if (view.itemsize == 2) {
1147  // i16
1148  bulkLoadElementType = signless
1149  ? mlirIntegerTypeGet(context, 16)
1150  : mlirIntegerTypeSignedGet(context, 16);
1151  }
1152  } else if (isUnsignedIntegerFormat(format)) {
1153  if (view.itemsize == 4) {
1154  // unsigned i32
1155  bulkLoadElementType = signless
1156  ? mlirIntegerTypeGet(context, 32)
1157  : mlirIntegerTypeUnsignedGet(context, 32);
1158  } else if (view.itemsize == 8) {
1159  // unsigned i64
1160  bulkLoadElementType = signless
1161  ? mlirIntegerTypeGet(context, 64)
1162  : mlirIntegerTypeUnsignedGet(context, 64);
1163  } else if (view.itemsize == 1) {
1164  // i8
1165  bulkLoadElementType = signless
1166  ? mlirIntegerTypeGet(context, 8)
1167  : mlirIntegerTypeUnsignedGet(context, 8);
1168  } else if (view.itemsize == 2) {
1169  // i16
1170  bulkLoadElementType = signless
1171  ? mlirIntegerTypeGet(context, 16)
1172  : mlirIntegerTypeUnsignedGet(context, 16);
1173  }
1174  }
1175  if (!bulkLoadElementType) {
1176  throw std::invalid_argument(
1177  std::string("unimplemented array format conversion from format: ") +
1178  std::string(format));
1179  }
1180  }
1181 
1182  MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1183  return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1184  }
1185 
1186  // There is a complication for boolean numpy arrays, as numpy represents
1187  // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
1188  // booleans per byte.
1189  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1190  Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1191  MlirContext &context) {
1192  if (llvm::endianness::native != llvm::endianness::little) {
1193  // Given we have no good way of testing the behavior on big-endian
1194  // systems we will throw
1195  throw nb::type_error("Constructing a bit-packed MLIR attribute is "
1196  "unsupported on big-endian systems");
1197  }
1198  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
1199  /*data=*/static_cast<uint8_t *>(view.buf),
1200  /*shape=*/{static_cast<size_t>(view.len)});
1201 
1202  nb::module_ numpy = nb::module_::import_("numpy");
1203  nb::object packbitsFunc = numpy.attr("packbits");
1204  nb::object packedBooleans =
1205  packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
1206  nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
1207 
1208  MlirType bitpackedType =
1209  getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1210  assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1211  // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1212  // packedBooleans, hence the MlirAttribute will remain valid even when
1213  // packedBooleans get reclaimed by the end of the function.
1214  return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1215  pythonBuffer.ptr);
1216  }
1217 
1218  // This does the opposite transformation of
1219  // `getBitpackedAttributeFromBooleanBuffer`
1220  std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
1221  if (llvm::endianness::native != llvm::endianness::little) {
1222  // Given we have no good way of testing the behavior on big-endian
1223  // systems we will throw
1224  throw nb::type_error("Constructing a numpy array from a MLIR attribute "
1225  "is unsupported on big-endian systems");
1226  }
1227 
1228  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1229  int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1230  uint8_t *bitpackedData = static_cast<uint8_t *>(
1231  const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1232  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
1233  /*data=*/bitpackedData,
1234  /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
1235 
1236  nb::module_ numpy = nb::module_::import_("numpy");
1237  nb::object unpackbitsFunc = numpy.attr("unpackbits");
1238  nb::object equalFunc = numpy.attr("equal");
1239  nb::object reshapeFunc = numpy.attr("reshape");
1240  nb::object unpackedBooleans =
1241  unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
1242 
1243  // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1244  // We need to:
1245  // 1. Slice away the padded bits
1246  // 2. Make the boolean array have the correct shape
1247  // 3. Convert the array to a boolean array
1248  unpackedBooleans = unpackedBooleans[nb::slice(
1249  nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
1250  unpackedBooleans = equalFunc(unpackedBooleans, 1);
1251 
1252  MlirType shapedType = mlirAttributeGetType(*this);
1253  intptr_t rank = mlirShapedTypeGetRank(shapedType);
1254  std::vector<intptr_t> shape(rank);
1255  for (intptr_t i = 0; i < rank; ++i) {
1256  shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
1257  }
1258  unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1259 
1260  // Make sure the returned nb::buffer_view claims ownership of the data in
1261  // `pythonBuffer` so it remains valid when Python reads it
1262  nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
1263  return std::make_unique<nb_buffer_info>(pythonBuffer.request());
1264  }
1265 
1266  template <typename Type>
1267  std::unique_ptr<nb_buffer_info>
1268  bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
1269  intptr_t rank = mlirShapedTypeGetRank(shapedType);
1270  // Prepare the data for the buffer_info.
1271  // Buffer is configured for read-only access below.
1272  Type *data = static_cast<Type *>(
1273  const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1274  // Prepare the shape for the buffer_info.
1276  for (intptr_t i = 0; i < rank; ++i)
1277  shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1278  // Prepare the strides for the buffer_info.
1279  SmallVector<intptr_t, 4> strides;
1280  if (mlirDenseElementsAttrIsSplat(*this)) {
1281  // Splats are special, only the single value is stored.
1282  strides.assign(rank, 0);
1283  } else {
1284  for (intptr_t i = 1; i < rank; ++i) {
1285  intptr_t strideFactor = 1;
1286  for (intptr_t j = i; j < rank; ++j)
1287  strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1288  strides.push_back(sizeof(Type) * strideFactor);
1289  }
1290  strides.push_back(sizeof(Type));
1291  }
1292  const char *format;
1293  if (explicitFormat) {
1294  format = explicitFormat;
1295  } else {
1296  format = nb_format_descriptor<Type>::format();
1297  }
1298  return std::make_unique<nb_buffer_info>(
1299  data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
1300  /*readonly=*/true);
1301  }
1302 }; // namespace
1303 
1304 PyType_Slot PyDenseElementsAttribute::slots[] = {
1305 // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
1306 #if PY_VERSION_HEX >= 0x03090000
1307  {Py_bf_getbuffer,
1308  reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
1309  {Py_bf_releasebuffer,
1310  reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
1311 #endif
1312  {0, nullptr},
1313 };
1314 
1315 /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
1316  Py_buffer *view,
1317  int flags) {
1318  view->obj = nullptr;
1319  std::unique_ptr<nb_buffer_info> info;
1320  try {
1321  auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
1322  info = attr->accessBuffer();
1323  } catch (nb::python_error &e) {
1324  e.restore();
1325  nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
1326  return -1;
1327  }
1328  view->obj = obj;
1329  view->ndim = 1;
1330  view->buf = info->ptr;
1331  view->itemsize = info->itemsize;
1332  view->len = info->itemsize;
1333  for (auto s : info->shape) {
1334  view->len *= s;
1335  }
1336  view->readonly = info->readonly;
1337  if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
1338  view->format = const_cast<char *>(info->format);
1339  }
1340  if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
1341  view->ndim = static_cast<int>(info->ndim);
1342  view->strides = info->strides.data();
1343  view->shape = info->shape.data();
1344  }
1345  view->suboffsets = nullptr;
1346  view->internal = info.release();
1347  Py_INCREF(obj);
1348  return 0;
1349 }
1350 
1351 /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
1352  Py_buffer *view) {
1353  delete reinterpret_cast<nb_buffer_info *>(view->internal);
1354 }
1355 
1356 /// Refinement of the PyDenseElementsAttribute for attributes containing
1357 /// integer (and boolean) values. Supports element access.
1358 class PyDenseIntElementsAttribute
1359  : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1360  PyDenseElementsAttribute> {
1361 public:
1362  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1363  static constexpr const char *pyClassName = "DenseIntElementsAttr";
1364  using PyConcreteAttribute::PyConcreteAttribute;
1365 
1366  /// Returns the element at the given linear position. Asserts if the index
1367  /// is out of range.
1368  nb::object dunderGetItem(intptr_t pos) {
1369  if (pos < 0 || pos >= dunderLen()) {
1370  throw nb::index_error("attempt to access out of bounds element");
1371  }
1372 
1373  MlirType type = mlirAttributeGetType(*this);
1374  type = mlirShapedTypeGetElementType(type);
1375  // Index type can also appear as a DenseIntElementsAttr and therefore can be
1376  // casted to integer.
1377  assert(mlirTypeIsAInteger(type) ||
1378  mlirTypeIsAIndex(type) && "expected integer/index element type in "
1379  "dense int elements attribute");
1380  // Dispatch element extraction to an appropriate C function based on the
1381  // elemental type of the attribute. nb::int_ is implicitly constructible
1382  // from any C++ integral type and handles bitwidth correctly.
1383  // TODO: consider caching the type properties in the constructor to avoid
1384  // querying them on each element access.
1385  if (mlirTypeIsAIndex(type)) {
1386  return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
1387  }
1388  unsigned width = mlirIntegerTypeGetWidth(type);
1389  bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1390  if (isUnsigned) {
1391  if (width == 1) {
1392  return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1393  }
1394  if (width == 8) {
1395  return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
1396  }
1397  if (width == 16) {
1398  return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
1399  }
1400  if (width == 32) {
1401  return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
1402  }
1403  if (width == 64) {
1404  return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
1405  }
1406  } else {
1407  if (width == 1) {
1408  return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1409  }
1410  if (width == 8) {
1411  return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
1412  }
1413  if (width == 16) {
1414  return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
1415  }
1416  if (width == 32) {
1417  return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
1418  }
1419  if (width == 64) {
1420  return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
1421  }
1422  }
1423  throw nb::type_error("Unsupported integer type");
1424  }
1425 
1426  static void bindDerived(ClassTy &c) {
1427  c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1428  }
1429 };
1430 
1431 class PyDenseResourceElementsAttribute
1432  : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1433 public:
1434  static constexpr IsAFunctionTy isaFunction =
1436  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1437  using PyConcreteAttribute::PyConcreteAttribute;
1438 
1439  static PyDenseResourceElementsAttribute
1440  getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
1441  std::optional<size_t> alignment, bool isMutable,
1442  DefaultingPyMlirContext contextWrapper) {
1443  if (!mlirTypeIsAShaped(type)) {
1444  throw std::invalid_argument(
1445  "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1446  }
1447 
1448  // Do not request any conversions as we must ensure to use caller
1449  // managed memory.
1450  int flags = PyBUF_STRIDES;
1451  std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1452  if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1453  throw nb::python_error();
1454  }
1455 
1456  // This scope releaser will only release if we haven't yet transferred
1457  // ownership.
1458  auto freeBuffer = llvm::make_scope_exit([&]() {
1459  if (view)
1460  PyBuffer_Release(view.get());
1461  });
1462 
1463  if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1464  throw std::invalid_argument("Contiguous buffer is required.");
1465  }
1466 
1467  // Infer alignment to be the stride of one element if not explicit.
1468  size_t inferredAlignment;
1469  if (alignment)
1470  inferredAlignment = *alignment;
1471  else
1472  inferredAlignment = view->strides[view->ndim - 1];
1473 
1474  // The userData is a Py_buffer* that the deleter owns.
1475  auto deleter = [](void *userData, const void *data, size_t size,
1476  size_t align) {
1477  if (!Py_IsInitialized())
1478  Py_Initialize();
1479  Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1480  nb::gil_scoped_acquire gil;
1481  PyBuffer_Release(ownedView);
1482  delete ownedView;
1483  };
1484 
1485  size_t rawBufferSize = view->len;
1486  MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1487  type, toMlirStringRef(name), view->buf, rawBufferSize,
1488  inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1489  if (mlirAttributeIsNull(attr)) {
1490  throw std::invalid_argument(
1491  "DenseResourceElementsAttr could not be constructed from the given "
1492  "buffer. "
1493  "This may mean that the Python buffer layout does not match that "
1494  "MLIR expected layout and is a bug.");
1495  }
1496  view.release();
1497  return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1498  }
1499 
1500  static void bindDerived(ClassTy &c) {
1501  c.def_static(
1502  "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
1503  nb::arg("array"), nb::arg("name"), nb::arg("type"),
1504  nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1505  nb::arg("context").none() = nb::none(),
1507  }
1508 };
1509 
1510 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1511 public:
1512  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1513  static constexpr const char *pyClassName = "DictAttr";
1514  using PyConcreteAttribute::PyConcreteAttribute;
1515  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1517 
1518  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1519 
1520  bool dunderContains(const std::string &name) {
1521  return !mlirAttributeIsNull(
1523  }
1524 
1525  static void bindDerived(ClassTy &c) {
1526  c.def("__contains__", &PyDictAttribute::dunderContains);
1527  c.def("__len__", &PyDictAttribute::dunderLen);
1528  c.def_static(
1529  "get",
1530  [](nb::dict attributes, DefaultingPyMlirContext context) {
1531  SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1532  mlirNamedAttributes.reserve(attributes.size());
1533  for (std::pair<nb::handle, nb::handle> it : attributes) {
1534  auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
1535  auto name = nb::cast<std::string>(it.first);
1536  mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1538  toMlirStringRef(name)),
1539  mlirAttr));
1540  }
1541  MlirAttribute attr =
1542  mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1543  mlirNamedAttributes.data());
1544  return PyDictAttribute(context->getRef(), attr);
1545  },
1546  nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
1547  "Gets an uniqued dict attribute");
1548  c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1549  MlirAttribute attr =
1551  if (mlirAttributeIsNull(attr))
1552  throw nb::key_error("attempt to access a non-existent attribute");
1553  return attr;
1554  });
1555  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1556  if (index < 0 || index >= self.dunderLen()) {
1557  throw nb::index_error("attempt to access out of bounds attribute");
1558  }
1559  MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1560  return PyNamedAttribute(
1561  namedAttr.attribute,
1562  std::string(mlirIdentifierStr(namedAttr.name).data));
1563  });
1564  }
1565 };
1566 
1567 /// Refinement of PyDenseElementsAttribute for attributes containing
1568 /// floating-point values. Supports element access.
1569 class PyDenseFPElementsAttribute
1570  : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1571  PyDenseElementsAttribute> {
1572 public:
1573  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1574  static constexpr const char *pyClassName = "DenseFPElementsAttr";
1575  using PyConcreteAttribute::PyConcreteAttribute;
1576 
1577  nb::float_ dunderGetItem(intptr_t pos) {
1578  if (pos < 0 || pos >= dunderLen()) {
1579  throw nb::index_error("attempt to access out of bounds element");
1580  }
1581 
1582  MlirType type = mlirAttributeGetType(*this);
1583  type = mlirShapedTypeGetElementType(type);
1584  // Dispatch element extraction to an appropriate C function based on the
1585  // elemental type of the attribute. nb::float_ is implicitly constructible
1586  // from float and double.
1587  // TODO: consider caching the type properties in the constructor to avoid
1588  // querying them on each element access.
1589  if (mlirTypeIsAF32(type)) {
1590  return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
1591  }
1592  if (mlirTypeIsAF64(type)) {
1593  return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
1594  }
1595  throw nb::type_error("Unsupported floating-point type");
1596  }
1597 
1598  static void bindDerived(ClassTy &c) {
1599  c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1600  }
1601 };
1602 
1603 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1604 public:
1605  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1606  static constexpr const char *pyClassName = "TypeAttr";
1607  using PyConcreteAttribute::PyConcreteAttribute;
1608  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1610 
1611  static void bindDerived(ClassTy &c) {
1612  c.def_static(
1613  "get",
1614  [](PyType value, DefaultingPyMlirContext context) {
1615  MlirAttribute attr = mlirTypeAttrGet(value.get());
1616  return PyTypeAttribute(context->getRef(), attr);
1617  },
1618  nb::arg("value"), nb::arg("context").none() = nb::none(),
1619  "Gets a uniqued Type attribute");
1620  c.def_prop_ro("value", [](PyTypeAttribute &self) {
1621  return mlirTypeAttrGetValue(self.get());
1622  });
1623  }
1624 };
1625 
1626 /// Unit Attribute subclass. Unit attributes don't have values.
1627 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1628 public:
1629  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1630  static constexpr const char *pyClassName = "UnitAttr";
1631  using PyConcreteAttribute::PyConcreteAttribute;
1632  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1634 
1635  static void bindDerived(ClassTy &c) {
1636  c.def_static(
1637  "get",
1638  [](DefaultingPyMlirContext context) {
1639  return PyUnitAttribute(context->getRef(),
1640  mlirUnitAttrGet(context->get()));
1641  },
1642  nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
1643  }
1644 };
1645 
1646 /// Strided layout attribute subclass.
1647 class PyStridedLayoutAttribute
1648  : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1649 public:
1650  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1651  static constexpr const char *pyClassName = "StridedLayoutAttr";
1652  using PyConcreteAttribute::PyConcreteAttribute;
1653  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1655 
1656  static void bindDerived(ClassTy &c) {
1657  c.def_static(
1658  "get",
1659  [](int64_t offset, const std::vector<int64_t> strides,
1661  MlirAttribute attr = mlirStridedLayoutAttrGet(
1662  ctx->get(), offset, strides.size(), strides.data());
1663  return PyStridedLayoutAttribute(ctx->getRef(), attr);
1664  },
1665  nb::arg("offset"), nb::arg("strides"),
1666  nb::arg("context").none() = nb::none(),
1667  "Gets a strided layout attribute.");
1668  c.def_static(
1669  "get_fully_dynamic",
1670  [](int64_t rank, DefaultingPyMlirContext ctx) {
1671  auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1672  std::vector<int64_t> strides(rank);
1673  std::fill(strides.begin(), strides.end(), dynamic);
1674  MlirAttribute attr = mlirStridedLayoutAttrGet(
1675  ctx->get(), dynamic, strides.size(), strides.data());
1676  return PyStridedLayoutAttribute(ctx->getRef(), attr);
1677  },
1678  nb::arg("rank"), nb::arg("context").none() = nb::none(),
1679  "Gets a strided layout attribute with dynamic offset and strides of "
1680  "a "
1681  "given rank.");
1682  c.def_prop_ro(
1683  "offset",
1684  [](PyStridedLayoutAttribute &self) {
1685  return mlirStridedLayoutAttrGetOffset(self);
1686  },
1687  "Returns the value of the float point attribute");
1688  c.def_prop_ro(
1689  "strides",
1690  [](PyStridedLayoutAttribute &self) {
1691  intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1692  std::vector<int64_t> strides(size);
1693  for (intptr_t i = 0; i < size; i++) {
1694  strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1695  }
1696  return strides;
1697  },
1698  "Returns the value of the float point attribute");
1699  }
1700 };
1701 
1702 nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1703  if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1704  return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
1705  if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1706  return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
1707  if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1708  return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
1709  if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1710  return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
1711  if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1712  return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
1713  if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1714  return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
1715  if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1716  return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
1717  std::string msg =
1718  std::string("Can't cast unknown element type DenseArrayAttr (") +
1719  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1720  throw nb::type_error(msg.c_str());
1721 }
1722 
1723 nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1724  if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1725  return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
1726  if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1727  return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
1728  std::string msg =
1729  std::string(
1730  "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1731  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1732  throw nb::type_error(msg.c_str());
1733 }
1734 
1735 nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1736  if (PyBoolAttribute::isaFunction(pyAttribute))
1737  return nb::cast(PyBoolAttribute(pyAttribute));
1738  if (PyIntegerAttribute::isaFunction(pyAttribute))
1739  return nb::cast(PyIntegerAttribute(pyAttribute));
1740  std::string msg =
1741  std::string("Can't cast unknown element type DenseArrayAttr (") +
1742  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1743  throw nb::type_error(msg.c_str());
1744 }
1745 
1746 nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1747  if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1748  return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
1749  if (PySymbolRefAttribute::isaFunction(pyAttribute))
1750  return nb::cast(PySymbolRefAttribute(pyAttribute));
1751  std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1752  nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1753  ")";
1754  throw nb::type_error(msg.c_str());
1755 }
1756 
1757 } // namespace
1758 
1759 void mlir::python::populateIRAttributes(nb::module_ &m) {
1760  PyAffineMapAttribute::bind(m);
1761  PyDenseBoolArrayAttribute::bind(m);
1762  PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1763  PyDenseI8ArrayAttribute::bind(m);
1764  PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1765  PyDenseI16ArrayAttribute::bind(m);
1766  PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1767  PyDenseI32ArrayAttribute::bind(m);
1768  PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1769  PyDenseI64ArrayAttribute::bind(m);
1770  PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1771  PyDenseF32ArrayAttribute::bind(m);
1772  PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1773  PyDenseF64ArrayAttribute::bind(m);
1774  PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1775  PyGlobals::get().registerTypeCaster(
1777  nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
1778 
1779  PyArrayAttribute::bind(m);
1780  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1781  PyBoolAttribute::bind(m);
1782  PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
1783  PyDenseFPElementsAttribute::bind(m);
1784  PyDenseIntElementsAttribute::bind(m);
1785  PyGlobals::get().registerTypeCaster(
1787  nb::cast<nb::callable>(
1788  nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
1789  PyDenseResourceElementsAttribute::bind(m);
1790 
1791  PyDictAttribute::bind(m);
1792  PySymbolRefAttribute::bind(m);
1793  PyGlobals::get().registerTypeCaster(
1795  nb::cast<nb::callable>(
1796  nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
1797 
1798  PyFlatSymbolRefAttribute::bind(m);
1799  PyOpaqueAttribute::bind(m);
1800  PyFloatAttribute::bind(m);
1801  PyIntegerAttribute::bind(m);
1802  PyIntegerSetAttribute::bind(m);
1803  PyStringAttribute::bind(m);
1804  PyTypeAttribute::bind(m);
1805  PyGlobals::get().registerTypeCaster(
1807  nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
1808  PyUnitAttribute::bind(m);
1809 
1810  PyStridedLayoutAttribute::bind(m);
1811 }
static const char kDenseElementsAttrGetDocstring[]
static const char kDenseResourceElementsAttrGetFromBufferDocstring[]
static const char kDenseElementsAttrGetFromListDocstring[]
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:216
static LogicalResult nextIndex(ArrayRef< int64_t > shape, MutableArrayRef< int64_t > index)
Walks over the indices of the elements of a tensor of a given shape by updating index in place to the...
std::string str() const
Converts the diagnostic to a string.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:339
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:546
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:320
static PyMlirContext & resolve()
Definition: IRCore.cpp:856
ReferrentTy * get() const
Definition: NanobindUtils.h:60
MlirAffineMap get() const
Definition: IRModule.h:1241
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1038
MlirAttribute get() const
Definition: IRModule.h:1044
CRTP base classes for Python attributes that subclass Attribute and should be castable from it (i....
Definition: IRModule.h:1088
nanobind::class_< DerivedTy, BaseTy > ClassTy
Definition: IRModule.h:1093
MlirIntegerSet get() const
Definition: IRModule.h:1262
MlirContext get()
Accesses the underlying MlirContext.
Definition: IRModule.h:211
Represents a Python MlirNamedAttr, carrying an optional owned name.
Definition: IRModule.h:1062
Wrapper around the generic MlirType.
Definition: IRModule.h:912
MlirType get() const
Definition: IRModule.h:918
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map)
Creates an affine map attribute wrapping the given map.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type)
Creates an opaque attribute in the given context associated with the dialect identified by its namesp...
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, double value)
Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a construction of a FloatAttr,...
MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr)
MLIR_CAPI_EXPORTED uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr)
Returns the affine map wrapped in the given affine map attribute.
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, const int64_t *strides)
MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void)
Returns the typeID of a String attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr)
Checks whether the given attribute is a unit attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr)
Checks whether the given attribute is a dense elements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr)
Checks whether the given attribute is an integer set attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAAffineMap(MlirAttribute attr)
Checks whether the given attribute is an affine map attribute.
MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol)
Creates a flat symbol reference attribute in the given context referencing a symbol identified by the...
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void)
Returns the typeID of a StridedLayout attribute.
MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void)
Returns the typeID of an Integer attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseResourceElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr)
Returns the string reference to the root referenced symbol.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAType(MlirAttribute attr)
Checks whether the given attribute is a type attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void)
Returns the typeID of an DenseIntOrFPElements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAArray(MlirAttribute attr)
Checks whether the given attribute is an array attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAInteger(MlirAttribute attr)
Checks whether the given attribute is an integer attribute.
MLIR_CAPI_EXPORTED intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr)
Returns the number of attributes contained in a dictionary attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set)
Creates an integer set attribute wrapping the given set.
MLIR_CAPI_EXPORTED uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr)
Returns the value stored in the given bool attribute.
MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value)
Creates an integer attribute of the given type with the given integer value.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void)
Returns the typeID of an IntegerSet attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos)
Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense e...
MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, MlirNamedAttribute const *elements)
Creates a dictionary attribute containing the given list of elements in the provided context.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, size_t dataAlignment, bool dataIsMutable, void(*deleter)(void *userData, const void *data, size_t size, size_t align), void *userData)
Unlike the typed accessors below, constructs the attribute with a raw data buffer and no type/alignme...
MLIR_CAPI_EXPORTED bool mlirAttributeIsABool(MlirAttribute attr)
Checks whether the given attribute is a bool attribute.
MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos)
Get an element of a dense array.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, int64_t const *values)
MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void)
Returns the typeID of an AffineMap attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos)
Returns pos-th reference nested in the given symbol reference attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void)
Returns the typeID of an Array attribute.
MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, double const *values)
MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of signless type and f...
MLIR_CAPI_EXPORTED intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr)
Returns the number of references nested in the given symbol reference attribute.
MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr)
Returns the type stored in the given type attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrIsSplat(MlirAttribute attr)
Checks whether the given dense elements attribute contains a single replicated value (splat).
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, MlirAttribute const *elements)
Creates a dense elements attribute with the given Shaped type and elements in the same context as the...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr)
Checks whether the given attribute is a dense array attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value)
Creates a bool attribute in the given context with the given value.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void)
Returns the typeID of a Float attribute.
MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of signed type and fit...
MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values)
MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos)
Returns pos-th element of the given dictionary attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void)
Returns the typeID of a Unit attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, float const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values)
Create a dense array attribute with the given elements.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAOpaque(MlirAttribute attr)
Checks whether the given attribute is an opaque attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos)
Returns pos-th element stored in the given array attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name)
Returns the dictionary attribute element with the given name or NULL if the given name does not exist...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void)
Returns the typeID of an SymbolRef attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void)
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr)
Returns the single replicated value (splat) of a specific type contained by the given dense elements ...
MLIR_CAPI_EXPORTED bool mlirAttributeIsASymbolRef(MlirAttribute attr)
Checks whether the given attribute is a symbol reference attribute.
MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool mlirAttributeIsAString(MlirAttribute attr)
Checks whether the given attribute is a string attribute.
MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr)
Gets the total number of elements in the given elements attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr)
Returns the namespace of the dialect with which the given opaque attribute is associated.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADictionary(MlirAttribute attr)
Checks whether the given attribute is a dictionary attribute.
MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr)
Returns the attribute values as a string reference.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, int16_t const *values)
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void)
Returns the typeID of an Opaque attribute.
MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr)
Returns the value stored in the given floating point attribute, interpreting the value as double.
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr)
Returns the value stored in the given integer attribute, assuming the value is of unsigned type and f...
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, int8_t const *values)
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx)
Creates a unit attribute in the given context.
MLIR_CAPI_EXPORTED double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, double value)
Creates a floating point attribute in the given context with the given double value and double-precis...
MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements)
Creates an array attribute containing the given list of elements in the given context.
MLIR_CAPI_EXPORTED const void * mlirDenseElementsAttrGetRawData(MlirAttribute attr)
Returns the raw data of the given dense elements attribute.
MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr)
Get the size of a dense array.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFloat(MlirAttribute attr)
Checks whether the given attribute is a floating point attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, intptr_t numReferences, MlirAttribute const *references)
Creates a symbol reference attribute in the given context referencing a symbol identified by the give...
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void)
Returns the typeID of a Type attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element)
Creates a dense elements attribute with the given Shaped type containing a single replicated element ...
MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void)
Returns the typeID of a Dictionary attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str)
Creates a string attribute in the given context containing the given string.
MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type)
Creates a type attribute wrapping the given type in the same context as the type.
MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr)
Returns the number of elements stored in the given array attribute.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, size_t rawBufferSize, const void *rawBuffer)
Creates a dense elements attribute with the given Shaped type and elements populated from a packed,...
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str)
Creates a string attribute of the given type containing the given string.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr)
Checks whether the given attribute is a flat symbol reference attribute.
MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr)
Returns the referenced symbol as a string reference.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr)
MLIR_CAPI_EXPORTED uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1145
MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, MlirAttribute attr)
Associates an attribute with the name. Takes ownership of neither.
Definition: IR.cpp:1261
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident)
Gets the string value of the identifier.
Definition: IR.cpp:1282
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute)
Gets the type of this attribute.
Definition: IR.cpp:1234
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type)
Gets the context that a type was created with.
Definition: IR.cpp:1199
MLIR_CAPI_EXPORTED bool mlirTypeEqual(MlirType t1, MlirType t2)
Checks if two types are equal.
Definition: IR.cpp:1211
MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute)
Gets the context that an attribute was created with.
Definition: IR.cpp:1230
MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str)
Gets an identifier with the given string value.
Definition: IR.cpp:1270
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
void populateIRAttributes(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Named MLIR attribute.
Definition: IR.h:76
MlirAttribute attribute
Definition: IR.h:78
MlirIdentifier name
Definition: IR.h:77
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1329
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:455
std::vector< PyDiagnostic::DiagnosticInfo > take()
Definition: IRModule.h:465
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.