MLIR  22.0.0git
IRTypes.cpp
Go to the documentation of this file.
1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // clang-format off
10 #include "IRModule.h"
12 // clang-format on
13 
14 #include <optional>
15 
16 #include "IRModule.h"
17 #include "NanobindUtils.h"
19 #include "mlir-c/BuiltinTypes.h"
20 #include "mlir-c/Support.h"
21 
22 namespace nb = nanobind;
23 using namespace mlir;
24 using namespace mlir::python;
25 
26 using llvm::SmallVector;
27 using llvm::Twine;
28 
29 namespace {
30 
31 /// Checks whether the given type is an integer or float type.
32 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34  mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35 }
36 
37 class PyIntegerType : public PyConcreteType<PyIntegerType> {
38 public:
39  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
42  static constexpr const char *pyClassName = "IntegerType";
44 
45  static void bindDerived(ClassTy &c) {
46  c.def_static(
47  "get_signless",
48  [](unsigned width, DefaultingPyMlirContext context) {
49  MlirType t = mlirIntegerTypeGet(context->get(), width);
50  return PyIntegerType(context->getRef(), t);
51  },
52  nb::arg("width"), nb::arg("context") = nb::none(),
53  "Create a signless integer type");
54  c.def_static(
55  "get_signed",
56  [](unsigned width, DefaultingPyMlirContext context) {
57  MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58  return PyIntegerType(context->getRef(), t);
59  },
60  nb::arg("width"), nb::arg("context") = nb::none(),
61  "Create a signed integer type");
62  c.def_static(
63  "get_unsigned",
64  [](unsigned width, DefaultingPyMlirContext context) {
65  MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66  return PyIntegerType(context->getRef(), t);
67  },
68  nb::arg("width"), nb::arg("context") = nb::none(),
69  "Create an unsigned integer type");
70  c.def_prop_ro(
71  "width",
72  [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73  "Returns the width of the integer type");
74  c.def_prop_ro(
75  "is_signless",
76  [](PyIntegerType &self) -> bool {
77  return mlirIntegerTypeIsSignless(self);
78  },
79  "Returns whether this is a signless integer");
80  c.def_prop_ro(
81  "is_signed",
82  [](PyIntegerType &self) -> bool {
83  return mlirIntegerTypeIsSigned(self);
84  },
85  "Returns whether this is a signed integer");
86  c.def_prop_ro(
87  "is_unsigned",
88  [](PyIntegerType &self) -> bool {
89  return mlirIntegerTypeIsUnsigned(self);
90  },
91  "Returns whether this is an unsigned integer");
92  }
93 };
94 
95 /// Index Type subclass - IndexType.
96 class PyIndexType : public PyConcreteType<PyIndexType> {
97 public:
98  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
101  static constexpr const char *pyClassName = "IndexType";
103 
104  static void bindDerived(ClassTy &c) {
105  c.def_static(
106  "get",
107  [](DefaultingPyMlirContext context) {
108  MlirType t = mlirIndexTypeGet(context->get());
109  return PyIndexType(context->getRef(), t);
110  },
111  nb::arg("context") = nb::none(), "Create a index type.");
112  }
113 };
114 
115 class PyFloatType : public PyConcreteType<PyFloatType> {
116 public:
117  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118  static constexpr const char *pyClassName = "FloatType";
120 
121  static void bindDerived(ClassTy &c) {
122  c.def_prop_ro(
123  "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124  "Returns the width of the floating-point type");
125  }
126 };
127 
128 /// Floating Point Type subclass - Float4E2M1FNType.
129 class PyFloat4E2M1FNType
130  : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131 public:
132  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
135  static constexpr const char *pyClassName = "Float4E2M1FNType";
137 
138  static void bindDerived(ClassTy &c) {
139  c.def_static(
140  "get",
141  [](DefaultingPyMlirContext context) {
142  MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143  return PyFloat4E2M1FNType(context->getRef(), t);
144  },
145  nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
146  }
147 };
148 
149 /// Floating Point Type subclass - Float6E2M3FNType.
150 class PyFloat6E2M3FNType
151  : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152 public:
153  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
156  static constexpr const char *pyClassName = "Float6E2M3FNType";
158 
159  static void bindDerived(ClassTy &c) {
160  c.def_static(
161  "get",
162  [](DefaultingPyMlirContext context) {
163  MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164  return PyFloat6E2M3FNType(context->getRef(), t);
165  },
166  nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
167  }
168 };
169 
170 /// Floating Point Type subclass - Float6E3M2FNType.
171 class PyFloat6E3M2FNType
172  : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173 public:
174  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
177  static constexpr const char *pyClassName = "Float6E3M2FNType";
179 
180  static void bindDerived(ClassTy &c) {
181  c.def_static(
182  "get",
183  [](DefaultingPyMlirContext context) {
184  MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185  return PyFloat6E3M2FNType(context->getRef(), t);
186  },
187  nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
188  }
189 };
190 
191 /// Floating Point Type subclass - Float8E4M3FNType.
192 class PyFloat8E4M3FNType
193  : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194 public:
195  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
198  static constexpr const char *pyClassName = "Float8E4M3FNType";
200 
201  static void bindDerived(ClassTy &c) {
202  c.def_static(
203  "get",
204  [](DefaultingPyMlirContext context) {
205  MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206  return PyFloat8E4M3FNType(context->getRef(), t);
207  },
208  nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
209  }
210 };
211 
212 /// Floating Point Type subclass - Float8E5M2Type.
213 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214 public:
215  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
218  static constexpr const char *pyClassName = "Float8E5M2Type";
220 
221  static void bindDerived(ClassTy &c) {
222  c.def_static(
223  "get",
224  [](DefaultingPyMlirContext context) {
225  MlirType t = mlirFloat8E5M2TypeGet(context->get());
226  return PyFloat8E5M2Type(context->getRef(), t);
227  },
228  nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
229  }
230 };
231 
232 /// Floating Point Type subclass - Float8E4M3Type.
233 class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234 public:
235  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
238  static constexpr const char *pyClassName = "Float8E4M3Type";
240 
241  static void bindDerived(ClassTy &c) {
242  c.def_static(
243  "get",
244  [](DefaultingPyMlirContext context) {
245  MlirType t = mlirFloat8E4M3TypeGet(context->get());
246  return PyFloat8E4M3Type(context->getRef(), t);
247  },
248  nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
249  }
250 };
251 
252 /// Floating Point Type subclass - Float8E4M3FNUZ.
253 class PyFloat8E4M3FNUZType
254  : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255 public:
256  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
259  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
261 
262  static void bindDerived(ClassTy &c) {
263  c.def_static(
264  "get",
265  [](DefaultingPyMlirContext context) {
266  MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267  return PyFloat8E4M3FNUZType(context->getRef(), t);
268  },
269  nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
270  }
271 };
272 
273 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
274 class PyFloat8E4M3B11FNUZType
275  : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
276 public:
277  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
278  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
280  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
282 
283  static void bindDerived(ClassTy &c) {
284  c.def_static(
285  "get",
286  [](DefaultingPyMlirContext context) {
287  MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
288  return PyFloat8E4M3B11FNUZType(context->getRef(), t);
289  },
290  nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
291  }
292 };
293 
294 /// Floating Point Type subclass - Float8E5M2FNUZ.
295 class PyFloat8E5M2FNUZType
296  : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
297 public:
298  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
299  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
301  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
303 
304  static void bindDerived(ClassTy &c) {
305  c.def_static(
306  "get",
307  [](DefaultingPyMlirContext context) {
308  MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
309  return PyFloat8E5M2FNUZType(context->getRef(), t);
310  },
311  nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
312  }
313 };
314 
315 /// Floating Point Type subclass - Float8E3M4Type.
316 class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
317 public:
318  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
319  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
321  static constexpr const char *pyClassName = "Float8E3M4Type";
323 
324  static void bindDerived(ClassTy &c) {
325  c.def_static(
326  "get",
327  [](DefaultingPyMlirContext context) {
328  MlirType t = mlirFloat8E3M4TypeGet(context->get());
329  return PyFloat8E3M4Type(context->getRef(), t);
330  },
331  nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
332  }
333 };
334 
335 /// Floating Point Type subclass - Float8E8M0FNUType.
336 class PyFloat8E8M0FNUType
337  : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
338 public:
339  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
340  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
342  static constexpr const char *pyClassName = "Float8E8M0FNUType";
344 
345  static void bindDerived(ClassTy &c) {
346  c.def_static(
347  "get",
348  [](DefaultingPyMlirContext context) {
349  MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
350  return PyFloat8E8M0FNUType(context->getRef(), t);
351  },
352  nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
353  }
354 };
355 
356 /// Floating Point Type subclass - BF16Type.
357 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
358 public:
359  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
360  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
362  static constexpr const char *pyClassName = "BF16Type";
364 
365  static void bindDerived(ClassTy &c) {
366  c.def_static(
367  "get",
368  [](DefaultingPyMlirContext context) {
369  MlirType t = mlirBF16TypeGet(context->get());
370  return PyBF16Type(context->getRef(), t);
371  },
372  nb::arg("context") = nb::none(), "Create a bf16 type.");
373  }
374 };
375 
376 /// Floating Point Type subclass - F16Type.
377 class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
378 public:
379  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
380  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
382  static constexpr const char *pyClassName = "F16Type";
384 
385  static void bindDerived(ClassTy &c) {
386  c.def_static(
387  "get",
388  [](DefaultingPyMlirContext context) {
389  MlirType t = mlirF16TypeGet(context->get());
390  return PyF16Type(context->getRef(), t);
391  },
392  nb::arg("context") = nb::none(), "Create a f16 type.");
393  }
394 };
395 
396 /// Floating Point Type subclass - TF32Type.
397 class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
398 public:
399  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
400  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
402  static constexpr const char *pyClassName = "FloatTF32Type";
404 
405  static void bindDerived(ClassTy &c) {
406  c.def_static(
407  "get",
408  [](DefaultingPyMlirContext context) {
409  MlirType t = mlirTF32TypeGet(context->get());
410  return PyTF32Type(context->getRef(), t);
411  },
412  nb::arg("context") = nb::none(), "Create a tf32 type.");
413  }
414 };
415 
416 /// Floating Point Type subclass - F32Type.
417 class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
418 public:
419  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
420  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
422  static constexpr const char *pyClassName = "F32Type";
424 
425  static void bindDerived(ClassTy &c) {
426  c.def_static(
427  "get",
428  [](DefaultingPyMlirContext context) {
429  MlirType t = mlirF32TypeGet(context->get());
430  return PyF32Type(context->getRef(), t);
431  },
432  nb::arg("context") = nb::none(), "Create a f32 type.");
433  }
434 };
435 
436 /// Floating Point Type subclass - F64Type.
437 class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
438 public:
439  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
440  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
442  static constexpr const char *pyClassName = "F64Type";
444 
445  static void bindDerived(ClassTy &c) {
446  c.def_static(
447  "get",
448  [](DefaultingPyMlirContext context) {
449  MlirType t = mlirF64TypeGet(context->get());
450  return PyF64Type(context->getRef(), t);
451  },
452  nb::arg("context") = nb::none(), "Create a f64 type.");
453  }
454 };
455 
456 /// None Type subclass - NoneType.
457 class PyNoneType : public PyConcreteType<PyNoneType> {
458 public:
459  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
460  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
462  static constexpr const char *pyClassName = "NoneType";
464 
465  static void bindDerived(ClassTy &c) {
466  c.def_static(
467  "get",
468  [](DefaultingPyMlirContext context) {
469  MlirType t = mlirNoneTypeGet(context->get());
470  return PyNoneType(context->getRef(), t);
471  },
472  nb::arg("context") = nb::none(), "Create a none type.");
473  }
474 };
475 
476 /// Complex Type subclass - ComplexType.
477 class PyComplexType : public PyConcreteType<PyComplexType> {
478 public:
479  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
480  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
482  static constexpr const char *pyClassName = "ComplexType";
484 
485  static void bindDerived(ClassTy &c) {
486  c.def_static(
487  "get",
488  [](PyType &elementType) {
489  // The element must be a floating point or integer scalar type.
490  if (mlirTypeIsAIntegerOrFloat(elementType)) {
491  MlirType t = mlirComplexTypeGet(elementType);
492  return PyComplexType(elementType.getContext(), t);
493  }
494  throw nb::value_error(
495  (Twine("invalid '") +
496  nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
497  "' and expected floating point or integer type.")
498  .str()
499  .c_str());
500  },
501  "Create a complex type");
502  c.def_prop_ro(
503  "element_type",
504  [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
505  return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
506  .maybeDownCast();
507  },
508  "Returns element type.");
509  }
510 };
511 
512 } // namespace
513 
514 // Shaped Type Interface - ShapedType
516  c.def_prop_ro(
517  "element_type",
518  [](PyShapedType &self) -> nb::typed<nb::object, PyType> {
519  return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
520  .maybeDownCast();
521  },
522  "Returns the element type of the shaped type.");
523  c.def_prop_ro(
524  "has_rank",
525  [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
526  "Returns whether the given shaped type is ranked.");
527  c.def_prop_ro(
528  "rank",
529  [](PyShapedType &self) {
530  self.requireHasRank();
531  return mlirShapedTypeGetRank(self);
532  },
533  "Returns the rank of the given ranked shaped type.");
534  c.def_prop_ro(
535  "has_static_shape",
536  [](PyShapedType &self) -> bool {
537  return mlirShapedTypeHasStaticShape(self);
538  },
539  "Returns whether the given shaped type has a static shape.");
540  c.def(
541  "is_dynamic_dim",
542  [](PyShapedType &self, intptr_t dim) -> bool {
543  self.requireHasRank();
544  return mlirShapedTypeIsDynamicDim(self, dim);
545  },
546  nb::arg("dim"),
547  "Returns whether the dim-th dimension of the given shaped type is "
548  "dynamic.");
549  c.def(
550  "is_static_dim",
551  [](PyShapedType &self, intptr_t dim) -> bool {
552  self.requireHasRank();
553  return mlirShapedTypeIsStaticDim(self, dim);
554  },
555  nb::arg("dim"),
556  "Returns whether the dim-th dimension of the given shaped type is "
557  "static.");
558  c.def(
559  "get_dim_size",
560  [](PyShapedType &self, intptr_t dim) {
561  self.requireHasRank();
562  return mlirShapedTypeGetDimSize(self, dim);
563  },
564  nb::arg("dim"),
565  "Returns the dim-th dimension of the given ranked shaped type.");
566  c.def_static(
567  "is_dynamic_size",
568  [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
569  nb::arg("dim_size"),
570  "Returns whether the given dimension size indicates a dynamic "
571  "dimension.");
572  c.def_static(
573  "is_static_size",
574  [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
575  nb::arg("dim_size"),
576  "Returns whether the given dimension size indicates a static "
577  "dimension.");
578  c.def(
579  "is_dynamic_stride_or_offset",
580  [](PyShapedType &self, int64_t val) -> bool {
581  self.requireHasRank();
583  },
584  nb::arg("dim_size"),
585  "Returns whether the given value is used as a placeholder for dynamic "
586  "strides and offsets in shaped types.");
587  c.def(
588  "is_static_stride_or_offset",
589  [](PyShapedType &self, int64_t val) -> bool {
590  self.requireHasRank();
592  },
593  nb::arg("dim_size"),
594  "Returns whether the given shaped type stride or offset value is "
595  "statically-sized.");
596  c.def_prop_ro(
597  "shape",
598  [](PyShapedType &self) {
599  self.requireHasRank();
600 
601  std::vector<int64_t> shape;
602  int64_t rank = mlirShapedTypeGetRank(self);
603  shape.reserve(rank);
604  for (int64_t i = 0; i < rank; ++i)
605  shape.push_back(mlirShapedTypeGetDimSize(self, i));
606  return shape;
607  },
608  "Returns the shape of the ranked shaped type as a list of integers.");
609  c.def_static(
610  "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
611  "Returns the value used to indicate dynamic dimensions in shaped "
612  "types.");
613  c.def_static(
614  "get_dynamic_stride_or_offset",
615  []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
616  "Returns the value used to indicate dynamic strides or offsets in "
617  "shaped types.");
618 }
619 
620 void mlir::PyShapedType::requireHasRank() {
621  if (!mlirShapedTypeHasRank(*this)) {
622  throw nb::value_error(
623  "calling this method requires that the type has a rank.");
624  }
625 }
626 
629 
630 namespace {
631 
632 /// Vector Type subclass - VectorType.
633 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
634 public:
635  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
636  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
638  static constexpr const char *pyClassName = "VectorType";
640 
641  static void bindDerived(ClassTy &c) {
642  c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
643  nb::arg("element_type"), nb::kw_only(),
644  nb::arg("scalable") = nb::none(),
645  nb::arg("scalable_dims") = nb::none(),
646  nb::arg("loc") = nb::none(), "Create a vector type")
647  .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
648  nb::arg("element_type"), nb::kw_only(),
649  nb::arg("scalable") = nb::none(),
650  nb::arg("scalable_dims") = nb::none(),
651  nb::arg("context") = nb::none(), "Create a vector type")
652  .def_prop_ro(
653  "scalable",
654  [](MlirType self) { return mlirVectorTypeIsScalable(self); })
655  .def_prop_ro("scalable_dims", [](MlirType self) {
656  std::vector<bool> scalableDims;
657  size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
658  scalableDims.reserve(rank);
659  for (size_t i = 0; i < rank; ++i)
660  scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
661  return scalableDims;
662  });
663  }
664 
665 private:
666  static PyVectorType
667  getChecked(std::vector<int64_t> shape, PyType &elementType,
668  std::optional<nb::list> scalable,
669  std::optional<std::vector<int64_t>> scalableDims,
670  DefaultingPyLocation loc) {
671  if (scalable && scalableDims) {
672  throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
673  "are mutually exclusive.");
674  }
675 
676  PyMlirContext::ErrorCapture errors(loc->getContext());
677  MlirType type;
678  if (scalable) {
679  if (scalable->size() != shape.size())
680  throw nb::value_error("Expected len(scalable) == len(shape).");
681 
682  SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
683  *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
684  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
685  scalableDimFlags.data(),
686  elementType);
687  } else if (scalableDims) {
688  SmallVector<bool> scalableDimFlags(shape.size(), false);
689  for (int64_t dim : *scalableDims) {
690  if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
691  throw nb::value_error("Scalable dimension index out of bounds.");
692  scalableDimFlags[dim] = true;
693  }
694  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
695  scalableDimFlags.data(),
696  elementType);
697  } else {
698  type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
699  elementType);
700  }
701  if (mlirTypeIsNull(type))
702  throw MLIRError("Invalid type", errors.take());
703  return PyVectorType(elementType.getContext(), type);
704  }
705 
706  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
707  std::optional<nb::list> scalable,
708  std::optional<std::vector<int64_t>> scalableDims,
709  DefaultingPyMlirContext context) {
710  if (scalable && scalableDims) {
711  throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
712  "are mutually exclusive.");
713  }
714 
715  PyMlirContext::ErrorCapture errors(context->getRef());
716  MlirType type;
717  if (scalable) {
718  if (scalable->size() != shape.size())
719  throw nb::value_error("Expected len(scalable) == len(shape).");
720 
721  SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
722  *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
723  type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
724  scalableDimFlags.data(), elementType);
725  } else if (scalableDims) {
726  SmallVector<bool> scalableDimFlags(shape.size(), false);
727  for (int64_t dim : *scalableDims) {
728  if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
729  throw nb::value_error("Scalable dimension index out of bounds.");
730  scalableDimFlags[dim] = true;
731  }
732  type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
733  scalableDimFlags.data(), elementType);
734  } else {
735  type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
736  }
737  if (mlirTypeIsNull(type))
738  throw MLIRError("Invalid type", errors.take());
739  return PyVectorType(elementType.getContext(), type);
740  }
741 };
742 
743 /// Ranked Tensor Type subclass - RankedTensorType.
744 class PyRankedTensorType
745  : public PyConcreteType<PyRankedTensorType, PyShapedType> {
746 public:
747  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
748  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
750  static constexpr const char *pyClassName = "RankedTensorType";
752 
753  static void bindDerived(ClassTy &c) {
754  c.def_static(
755  "get",
756  [](std::vector<int64_t> shape, PyType &elementType,
757  std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
758  PyMlirContext::ErrorCapture errors(loc->getContext());
759  MlirType t = mlirRankedTensorTypeGetChecked(
760  loc, shape.size(), shape.data(), elementType,
761  encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
762  if (mlirTypeIsNull(t))
763  throw MLIRError("Invalid type", errors.take());
764  return PyRankedTensorType(elementType.getContext(), t);
765  },
766  nb::arg("shape"), nb::arg("element_type"),
767  nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
768  "Create a ranked tensor type");
769  c.def_static(
770  "get_unchecked",
771  [](std::vector<int64_t> shape, PyType &elementType,
772  std::optional<PyAttribute> &encodingAttr,
773  DefaultingPyMlirContext context) {
774  PyMlirContext::ErrorCapture errors(context->getRef());
775  MlirType t = mlirRankedTensorTypeGet(
776  shape.size(), shape.data(), elementType,
777  encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
778  if (mlirTypeIsNull(t))
779  throw MLIRError("Invalid type", errors.take());
780  return PyRankedTensorType(elementType.getContext(), t);
781  },
782  nb::arg("shape"), nb::arg("element_type"),
783  nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
784  "Create a ranked tensor type");
785  c.def_prop_ro(
786  "encoding",
787  [](PyRankedTensorType &self)
788  -> std::optional<nb::typed<nb::object, PyAttribute>> {
789  MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
790  if (mlirAttributeIsNull(encoding))
791  return std::nullopt;
792  return PyAttribute(self.getContext(), encoding).maybeDownCast();
793  });
794  }
795 };
796 
797 /// Unranked Tensor Type subclass - UnrankedTensorType.
798 class PyUnrankedTensorType
799  : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
800 public:
801  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
802  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
804  static constexpr const char *pyClassName = "UnrankedTensorType";
806 
807  static void bindDerived(ClassTy &c) {
808  c.def_static(
809  "get",
810  [](PyType &elementType, DefaultingPyLocation loc) {
811  PyMlirContext::ErrorCapture errors(loc->getContext());
812  MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
813  if (mlirTypeIsNull(t))
814  throw MLIRError("Invalid type", errors.take());
815  return PyUnrankedTensorType(elementType.getContext(), t);
816  },
817  nb::arg("element_type"), nb::arg("loc") = nb::none(),
818  "Create a unranked tensor type");
819  c.def_static(
820  "get_unchecked",
821  [](PyType &elementType, DefaultingPyMlirContext context) {
822  PyMlirContext::ErrorCapture errors(context->getRef());
823  MlirType t = mlirUnrankedTensorTypeGet(elementType);
824  if (mlirTypeIsNull(t))
825  throw MLIRError("Invalid type", errors.take());
826  return PyUnrankedTensorType(elementType.getContext(), t);
827  },
828  nb::arg("element_type"), nb::arg("context") = nb::none(),
829  "Create a unranked tensor type");
830  }
831 };
832 
833 /// Ranked MemRef Type subclass - MemRefType.
834 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
835 public:
836  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
837  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
839  static constexpr const char *pyClassName = "MemRefType";
841 
842  static void bindDerived(ClassTy &c) {
843  c.def_static(
844  "get",
845  [](std::vector<int64_t> shape, PyType &elementType,
846  PyAttribute *layout, PyAttribute *memorySpace,
847  DefaultingPyLocation loc) {
848  PyMlirContext::ErrorCapture errors(loc->getContext());
849  MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
850  MlirAttribute memSpaceAttr =
851  memorySpace ? *memorySpace : mlirAttributeGetNull();
852  MlirType t =
853  mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
854  shape.data(), layoutAttr, memSpaceAttr);
855  if (mlirTypeIsNull(t))
856  throw MLIRError("Invalid type", errors.take());
857  return PyMemRefType(elementType.getContext(), t);
858  },
859  nb::arg("shape"), nb::arg("element_type"),
860  nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
861  nb::arg("loc") = nb::none(), "Create a memref type")
862  .def_static(
863  "get_unchecked",
864  [](std::vector<int64_t> shape, PyType &elementType,
865  PyAttribute *layout, PyAttribute *memorySpace,
866  DefaultingPyMlirContext context) {
867  PyMlirContext::ErrorCapture errors(context->getRef());
868  MlirAttribute layoutAttr =
869  layout ? *layout : mlirAttributeGetNull();
870  MlirAttribute memSpaceAttr =
871  memorySpace ? *memorySpace : mlirAttributeGetNull();
872  MlirType t =
873  mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
874  layoutAttr, memSpaceAttr);
875  if (mlirTypeIsNull(t))
876  throw MLIRError("Invalid type", errors.take());
877  return PyMemRefType(elementType.getContext(), t);
878  },
879  nb::arg("shape"), nb::arg("element_type"),
880  nb::arg("layout") = nb::none(),
881  nb::arg("memory_space") = nb::none(),
882  nb::arg("context") = nb::none(), "Create a memref type")
883  .def_prop_ro(
884  "layout",
885  [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
886  return PyAttribute(self.getContext(),
888  .maybeDownCast();
889  },
890  "The layout of the MemRef type.")
891  .def(
892  "get_strides_and_offset",
893  [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
894  std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
895  int64_t offset;
897  self, strides.data(), &offset)))
898  throw std::runtime_error(
899  "Failed to extract strides and offset from memref.");
900  return {strides, offset};
901  },
902  "The strides and offset of the MemRef type.")
903  .def_prop_ro(
904  "affine_map",
905  [](PyMemRefType &self) -> PyAffineMap {
906  MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
907  return PyAffineMap(self.getContext(), map);
908  },
909  "The layout of the MemRef type as an affine map.")
910  .def_prop_ro(
911  "memory_space",
912  [](PyMemRefType &self)
913  -> std::optional<nb::typed<nb::object, PyAttribute>> {
914  MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
915  if (mlirAttributeIsNull(a))
916  return std::nullopt;
917  return PyAttribute(self.getContext(), a).maybeDownCast();
918  },
919  "Returns the memory space of the given MemRef type.");
920  }
921 };
922 
923 /// Unranked MemRef Type subclass - UnrankedMemRefType.
924 class PyUnrankedMemRefType
925  : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
926 public:
927  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
928  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
930  static constexpr const char *pyClassName = "UnrankedMemRefType";
932 
933  static void bindDerived(ClassTy &c) {
934  c.def_static(
935  "get",
936  [](PyType &elementType, PyAttribute *memorySpace,
937  DefaultingPyLocation loc) {
938  PyMlirContext::ErrorCapture errors(loc->getContext());
939  MlirAttribute memSpaceAttr = {};
940  if (memorySpace)
941  memSpaceAttr = *memorySpace;
942 
943  MlirType t =
944  mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
945  if (mlirTypeIsNull(t))
946  throw MLIRError("Invalid type", errors.take());
947  return PyUnrankedMemRefType(elementType.getContext(), t);
948  },
949  nb::arg("element_type"), nb::arg("memory_space").none(),
950  nb::arg("loc") = nb::none(), "Create a unranked memref type")
951  .def_static(
952  "get_unchecked",
953  [](PyType &elementType, PyAttribute *memorySpace,
954  DefaultingPyMlirContext context) {
955  PyMlirContext::ErrorCapture errors(context->getRef());
956  MlirAttribute memSpaceAttr = {};
957  if (memorySpace)
958  memSpaceAttr = *memorySpace;
959 
960  MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
961  if (mlirTypeIsNull(t))
962  throw MLIRError("Invalid type", errors.take());
963  return PyUnrankedMemRefType(elementType.getContext(), t);
964  },
965  nb::arg("element_type"), nb::arg("memory_space").none(),
966  nb::arg("context") = nb::none(), "Create a unranked memref type")
967  .def_prop_ro(
968  "memory_space",
969  [](PyUnrankedMemRefType &self)
970  -> std::optional<nb::typed<nb::object, PyAttribute>> {
971  MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
972  if (mlirAttributeIsNull(a))
973  return std::nullopt;
974  return PyAttribute(self.getContext(), a).maybeDownCast();
975  },
976  "Returns the memory space of the given Unranked MemRef type.");
977  }
978 };
979 
980 /// Tuple Type subclass - TupleType.
981 class PyTupleType : public PyConcreteType<PyTupleType> {
982 public:
983  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
984  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
986  static constexpr const char *pyClassName = "TupleType";
988 
989  static void bindDerived(ClassTy &c) {
990  c.def_static(
991  "get_tuple",
992  [](const std::vector<PyType> &elements,
993  DefaultingPyMlirContext context) {
994  std::vector<MlirType> mlirElements;
995  mlirElements.reserve(elements.size());
996  for (const auto &element : elements)
997  mlirElements.push_back(element.get());
998  MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
999  mlirElements.data());
1000  return PyTupleType(context->getRef(), t);
1001  },
1002  nb::arg("elements"), nb::arg("context") = nb::none(),
1003  "Create a tuple type");
1004  c.def_static(
1005  "get_tuple",
1006  [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
1007  MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
1008  elements.data());
1009  return PyTupleType(context->getRef(), t);
1010  },
1011  nb::arg("elements"), nb::arg("context") = nb::none(),
1012  // clang-format off
1013  nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
1014  // clang-format on
1015  "Create a tuple type");
1016  c.def(
1017  "get_type",
1018  [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
1019  return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
1020  .maybeDownCast();
1021  },
1022  nb::arg("pos"), "Returns the pos-th type in the tuple type.");
1023  c.def_prop_ro(
1024  "num_types",
1025  [](PyTupleType &self) -> intptr_t {
1026  return mlirTupleTypeGetNumTypes(self);
1027  },
1028  "Returns the number of types contained in a tuple.");
1029  }
1030 };
1031 
1032 /// Function type.
1033 class PyFunctionType : public PyConcreteType<PyFunctionType> {
1034 public:
1035  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
1036  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1038  static constexpr const char *pyClassName = "FunctionType";
1040 
1041  static void bindDerived(ClassTy &c) {
1042  c.def_static(
1043  "get",
1044  [](std::vector<PyType> inputs, std::vector<PyType> results,
1045  DefaultingPyMlirContext context) {
1046  std::vector<MlirType> mlirInputs;
1047  mlirInputs.reserve(inputs.size());
1048  for (const auto &input : inputs)
1049  mlirInputs.push_back(input.get());
1050  std::vector<MlirType> mlirResults;
1051  mlirResults.reserve(results.size());
1052  for (const auto &result : results)
1053  mlirResults.push_back(result.get());
1054 
1055  MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
1056  mlirInputs.data(), results.size(),
1057  mlirResults.data());
1058  return PyFunctionType(context->getRef(), t);
1059  },
1060  nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
1061  "Gets a FunctionType from a list of input and result types");
1062  c.def_static(
1063  "get",
1064  [](std::vector<MlirType> inputs, std::vector<MlirType> results,
1065  DefaultingPyMlirContext context) {
1066  MlirType t =
1067  mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
1068  results.size(), results.data());
1069  return PyFunctionType(context->getRef(), t);
1070  },
1071  nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
1072  // clang-format off
1073  nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
1074  // clang-format on
1075  "Gets a FunctionType from a list of input and result types");
1076  c.def_prop_ro(
1077  "inputs",
1078  [](PyFunctionType &self) {
1079  MlirType t = self;
1080  nb::list types;
1081  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
1082  ++i) {
1083  types.append(mlirFunctionTypeGetInput(t, i));
1084  }
1085  return types;
1086  },
1087  "Returns the list of input types in the FunctionType.");
1088  c.def_prop_ro(
1089  "results",
1090  [](PyFunctionType &self) {
1091  nb::list types;
1092  for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
1093  ++i) {
1094  types.append(mlirFunctionTypeGetResult(self, i));
1095  }
1096  return types;
1097  },
1098  "Returns the list of result types in the FunctionType.");
1099  }
1100 };
1101 
1102 static MlirStringRef toMlirStringRef(const std::string &s) {
1103  return mlirStringRefCreate(s.data(), s.size());
1104 }
1105 
1106 /// Opaque Type subclass - OpaqueType.
1107 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
1108 public:
1109  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
1110  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1112  static constexpr const char *pyClassName = "OpaqueType";
1114 
1115  static void bindDerived(ClassTy &c) {
1116  c.def_static(
1117  "get",
1118  [](const std::string &dialectNamespace, const std::string &typeData,
1119  DefaultingPyMlirContext context) {
1120  MlirType type = mlirOpaqueTypeGet(context->get(),
1121  toMlirStringRef(dialectNamespace),
1122  toMlirStringRef(typeData));
1123  return PyOpaqueType(context->getRef(), type);
1124  },
1125  nb::arg("dialect_namespace"), nb::arg("buffer"),
1126  nb::arg("context") = nb::none(),
1127  "Create an unregistered (opaque) dialect type.");
1128  c.def_prop_ro(
1129  "dialect_namespace",
1130  [](PyOpaqueType &self) {
1132  return nb::str(stringRef.data, stringRef.length);
1133  },
1134  "Returns the dialect namespace for the Opaque type as a string.");
1135  c.def_prop_ro(
1136  "data",
1137  [](PyOpaqueType &self) {
1138  MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
1139  return nb::str(stringRef.data, stringRef.length);
1140  },
1141  "Returns the data for the Opaque type as a string.");
1142  }
1143 };
1144 
1145 } // namespace
1146 
1147 void mlir::python::populateIRTypes(nb::module_ &m) {
1148  PyIntegerType::bind(m);
1149  PyFloatType::bind(m);
1150  PyIndexType::bind(m);
1151  PyFloat4E2M1FNType::bind(m);
1152  PyFloat6E2M3FNType::bind(m);
1153  PyFloat6E3M2FNType::bind(m);
1154  PyFloat8E4M3FNType::bind(m);
1155  PyFloat8E5M2Type::bind(m);
1156  PyFloat8E4M3Type::bind(m);
1157  PyFloat8E4M3FNUZType::bind(m);
1158  PyFloat8E4M3B11FNUZType::bind(m);
1159  PyFloat8E5M2FNUZType::bind(m);
1160  PyFloat8E3M4Type::bind(m);
1161  PyFloat8E8M0FNUType::bind(m);
1162  PyBF16Type::bind(m);
1163  PyF16Type::bind(m);
1164  PyTF32Type::bind(m);
1165  PyF32Type::bind(m);
1166  PyF64Type::bind(m);
1167  PyNoneType::bind(m);
1168  PyComplexType::bind(m);
1169  PyShapedType::bind(m);
1170  PyVectorType::bind(m);
1171  PyRankedTensorType::bind(m);
1172  PyUnrankedTensorType::bind(m);
1173  PyMemRefType::bind(m);
1174  PyUnrankedMemRefType::bind(m);
1175  PyTupleType::bind(m);
1176  PyFunctionType::bind(m);
1177  PyOpaqueType::bind(m);
1178 }
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:223
static MLIRContext * getContext(OpFoldResult val)
Shaped Type Interface - ShapedType.
Definition: IRTypes.h:17
static const IsAFunctionTy isaFunction
Definition: IRTypes.h:19
static void bindDerived(ClassTy &c)
Definition: IRTypes.cpp:515
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:292
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:499
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:273
ReferrentTy * get() const
Definition: NanobindUtils.h:60
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1008
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2143
CRTP base classes for Python types that subclass Type and should be castable from it (i....
Definition: IRModule.h:930
static void bind(nanobind::module_ &m)
Definition: IRModule.h:959
Wrapper around the generic MlirType.
Definition: IRModule.h:878
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2189
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void)
Returns the typeID of an Float8E8M0FNU type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void)
Returns the typeID of an Float16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type)
Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void)
Returns the typeID of an Float4E2M1FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void)
Returns the typeID of a TF32 type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type)
Checks whether the given type is an f8E8M0FNU type.
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void)
Returns the typeID of an None type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void)
Returns the typeID of an Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type)
Checks whether the given type is an f6E3M2FN type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is static.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx)
Creates an f8E3M4 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type)
Checks whether the given type is an f8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val)
Checks whether the given dimension value of a stride or an offset is statically-sized.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void)
Returns the typeID of an UnrankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx)
Creates an f8E8M0FNU type in the given context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void)
Returns the typeID of an MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type)
Checks whether the given type is an f8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx)
Creates an f6E2M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Creates a scalable vector type with the shape identified by its rank and dimensions.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void)
Returns the typeID of an Float64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void)
Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void)
Returns the typeID of an Tuple type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx)
Creates an f8E4M3 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void)
Returns the typeID of an Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type)
Checks whether the given type is an f6E2M3FN type.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides, int64_t *offset)
Returns the strides of the MemRef if the layout map is in strided form.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void)
Returns the typeID of an Float8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType)
Creates a vector type of the shape identified by its rank and dimensions, with the given element type...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type)
Checks whether the given type is a floating-point type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void)
Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void)
Returns the typeID of an RankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim)
Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type)
Checks whether the given type is an TF32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void)
Returns the typeID of an Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void)
Returns the typeID of an Integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx)
Creates an f6E3M2FN type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void)
Returns the typeID of an Opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void)
Returns the typeID of an Index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void)
Returns the typeID of an UnrankedMemRef type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace)
Creates an Unranked MemRef type with the given element type and in the given memory space.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void)
Returns the typeID of an Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void)
Returns the typeID of an Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size)
Checks whether the given shaped type dimension value is statically-sized.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps,...
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType)
Creates an unranked tensor type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type)
Checks whether the given vector type is scalable, i.e., has at least one scalable dimension.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type)
Checks whether the given type is an f4E2M1FN type.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void)
Returns the typeID of an Function type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void)
Returns the value indicating a dynamic size in a shaped type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void)
Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx)
Creates a TF32 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx)
Creates an f4E2M1FN type in the given context.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1183
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1148
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populateIRTypes(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1318
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:408