MLIR  22.0.0git
IRTypes.cpp
Go to the documentation of this file.
1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // clang-format off
10 #include "IRModule.h"
12 // clang-format on
13 
14 #include <optional>
15 
16 #include "IRModule.h"
17 #include "NanobindUtils.h"
19 #include "mlir-c/BuiltinTypes.h"
20 #include "mlir-c/Support.h"
21 
22 namespace nb = nanobind;
23 using namespace mlir;
24 using namespace mlir::python;
25 
26 using llvm::SmallVector;
27 using llvm::Twine;
28 
29 namespace {
30 
31 /// Checks whether the given type is an integer or float type.
32 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34  mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35 }
36 
37 class PyIntegerType : public PyConcreteType<PyIntegerType> {
38 public:
39  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
42  static constexpr const char *pyClassName = "IntegerType";
44 
45  static void bindDerived(ClassTy &c) {
46  c.def_static(
47  "get_signless",
48  [](unsigned width, DefaultingPyMlirContext context) {
49  MlirType t = mlirIntegerTypeGet(context->get(), width);
50  return PyIntegerType(context->getRef(), t);
51  },
52  nb::arg("width"), nb::arg("context") = nb::none(),
53  "Create a signless integer type");
54  c.def_static(
55  "get_signed",
56  [](unsigned width, DefaultingPyMlirContext context) {
57  MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58  return PyIntegerType(context->getRef(), t);
59  },
60  nb::arg("width"), nb::arg("context") = nb::none(),
61  "Create a signed integer type");
62  c.def_static(
63  "get_unsigned",
64  [](unsigned width, DefaultingPyMlirContext context) {
65  MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66  return PyIntegerType(context->getRef(), t);
67  },
68  nb::arg("width"), nb::arg("context") = nb::none(),
69  "Create an unsigned integer type");
70  c.def_prop_ro(
71  "width",
72  [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73  "Returns the width of the integer type");
74  c.def_prop_ro(
75  "is_signless",
76  [](PyIntegerType &self) -> bool {
77  return mlirIntegerTypeIsSignless(self);
78  },
79  "Returns whether this is a signless integer");
80  c.def_prop_ro(
81  "is_signed",
82  [](PyIntegerType &self) -> bool {
83  return mlirIntegerTypeIsSigned(self);
84  },
85  "Returns whether this is a signed integer");
86  c.def_prop_ro(
87  "is_unsigned",
88  [](PyIntegerType &self) -> bool {
89  return mlirIntegerTypeIsUnsigned(self);
90  },
91  "Returns whether this is an unsigned integer");
92  }
93 };
94 
95 /// Index Type subclass - IndexType.
96 class PyIndexType : public PyConcreteType<PyIndexType> {
97 public:
98  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
101  static constexpr const char *pyClassName = "IndexType";
103 
104  static void bindDerived(ClassTy &c) {
105  c.def_static(
106  "get",
107  [](DefaultingPyMlirContext context) {
108  MlirType t = mlirIndexTypeGet(context->get());
109  return PyIndexType(context->getRef(), t);
110  },
111  nb::arg("context") = nb::none(), "Create a index type.");
112  }
113 };
114 
115 class PyFloatType : public PyConcreteType<PyFloatType> {
116 public:
117  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118  static constexpr const char *pyClassName = "FloatType";
120 
121  static void bindDerived(ClassTy &c) {
122  c.def_prop_ro(
123  "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124  "Returns the width of the floating-point type");
125  }
126 };
127 
128 /// Floating Point Type subclass - Float4E2M1FNType.
129 class PyFloat4E2M1FNType
130  : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131 public:
132  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
135  static constexpr const char *pyClassName = "Float4E2M1FNType";
137 
138  static void bindDerived(ClassTy &c) {
139  c.def_static(
140  "get",
141  [](DefaultingPyMlirContext context) {
142  MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143  return PyFloat4E2M1FNType(context->getRef(), t);
144  },
145  nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
146  }
147 };
148 
149 /// Floating Point Type subclass - Float6E2M3FNType.
150 class PyFloat6E2M3FNType
151  : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152 public:
153  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
156  static constexpr const char *pyClassName = "Float6E2M3FNType";
158 
159  static void bindDerived(ClassTy &c) {
160  c.def_static(
161  "get",
162  [](DefaultingPyMlirContext context) {
163  MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164  return PyFloat6E2M3FNType(context->getRef(), t);
165  },
166  nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
167  }
168 };
169 
170 /// Floating Point Type subclass - Float6E3M2FNType.
171 class PyFloat6E3M2FNType
172  : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173 public:
174  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
177  static constexpr const char *pyClassName = "Float6E3M2FNType";
179 
180  static void bindDerived(ClassTy &c) {
181  c.def_static(
182  "get",
183  [](DefaultingPyMlirContext context) {
184  MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185  return PyFloat6E3M2FNType(context->getRef(), t);
186  },
187  nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
188  }
189 };
190 
191 /// Floating Point Type subclass - Float8E4M3FNType.
192 class PyFloat8E4M3FNType
193  : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194 public:
195  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
198  static constexpr const char *pyClassName = "Float8E4M3FNType";
200 
201  static void bindDerived(ClassTy &c) {
202  c.def_static(
203  "get",
204  [](DefaultingPyMlirContext context) {
205  MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206  return PyFloat8E4M3FNType(context->getRef(), t);
207  },
208  nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
209  }
210 };
211 
212 /// Floating Point Type subclass - Float8E5M2Type.
213 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214 public:
215  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
218  static constexpr const char *pyClassName = "Float8E5M2Type";
220 
221  static void bindDerived(ClassTy &c) {
222  c.def_static(
223  "get",
224  [](DefaultingPyMlirContext context) {
225  MlirType t = mlirFloat8E5M2TypeGet(context->get());
226  return PyFloat8E5M2Type(context->getRef(), t);
227  },
228  nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
229  }
230 };
231 
232 /// Floating Point Type subclass - Float8E4M3Type.
233 class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234 public:
235  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
238  static constexpr const char *pyClassName = "Float8E4M3Type";
240 
241  static void bindDerived(ClassTy &c) {
242  c.def_static(
243  "get",
244  [](DefaultingPyMlirContext context) {
245  MlirType t = mlirFloat8E4M3TypeGet(context->get());
246  return PyFloat8E4M3Type(context->getRef(), t);
247  },
248  nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
249  }
250 };
251 
252 /// Floating Point Type subclass - Float8E4M3FNUZ.
253 class PyFloat8E4M3FNUZType
254  : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255 public:
256  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
259  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
261 
262  static void bindDerived(ClassTy &c) {
263  c.def_static(
264  "get",
265  [](DefaultingPyMlirContext context) {
266  MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267  return PyFloat8E4M3FNUZType(context->getRef(), t);
268  },
269  nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
270  }
271 };
272 
273 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
274 class PyFloat8E4M3B11FNUZType
275  : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
276 public:
277  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
278  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
280  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
282 
283  static void bindDerived(ClassTy &c) {
284  c.def_static(
285  "get",
286  [](DefaultingPyMlirContext context) {
287  MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
288  return PyFloat8E4M3B11FNUZType(context->getRef(), t);
289  },
290  nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
291  }
292 };
293 
294 /// Floating Point Type subclass - Float8E5M2FNUZ.
295 class PyFloat8E5M2FNUZType
296  : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
297 public:
298  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
299  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
301  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
303 
304  static void bindDerived(ClassTy &c) {
305  c.def_static(
306  "get",
307  [](DefaultingPyMlirContext context) {
308  MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
309  return PyFloat8E5M2FNUZType(context->getRef(), t);
310  },
311  nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
312  }
313 };
314 
315 /// Floating Point Type subclass - Float8E3M4Type.
316 class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
317 public:
318  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
319  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
321  static constexpr const char *pyClassName = "Float8E3M4Type";
323 
324  static void bindDerived(ClassTy &c) {
325  c.def_static(
326  "get",
327  [](DefaultingPyMlirContext context) {
328  MlirType t = mlirFloat8E3M4TypeGet(context->get());
329  return PyFloat8E3M4Type(context->getRef(), t);
330  },
331  nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
332  }
333 };
334 
335 /// Floating Point Type subclass - Float8E8M0FNUType.
336 class PyFloat8E8M0FNUType
337  : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
338 public:
339  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
340  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
342  static constexpr const char *pyClassName = "Float8E8M0FNUType";
344 
345  static void bindDerived(ClassTy &c) {
346  c.def_static(
347  "get",
348  [](DefaultingPyMlirContext context) {
349  MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
350  return PyFloat8E8M0FNUType(context->getRef(), t);
351  },
352  nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
353  }
354 };
355 
356 /// Floating Point Type subclass - BF16Type.
357 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
358 public:
359  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
360  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
362  static constexpr const char *pyClassName = "BF16Type";
364 
365  static void bindDerived(ClassTy &c) {
366  c.def_static(
367  "get",
368  [](DefaultingPyMlirContext context) {
369  MlirType t = mlirBF16TypeGet(context->get());
370  return PyBF16Type(context->getRef(), t);
371  },
372  nb::arg("context") = nb::none(), "Create a bf16 type.");
373  }
374 };
375 
376 /// Floating Point Type subclass - F16Type.
377 class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
378 public:
379  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
380  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
382  static constexpr const char *pyClassName = "F16Type";
384 
385  static void bindDerived(ClassTy &c) {
386  c.def_static(
387  "get",
388  [](DefaultingPyMlirContext context) {
389  MlirType t = mlirF16TypeGet(context->get());
390  return PyF16Type(context->getRef(), t);
391  },
392  nb::arg("context") = nb::none(), "Create a f16 type.");
393  }
394 };
395 
396 /// Floating Point Type subclass - TF32Type.
397 class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
398 public:
399  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
400  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
402  static constexpr const char *pyClassName = "FloatTF32Type";
404 
405  static void bindDerived(ClassTy &c) {
406  c.def_static(
407  "get",
408  [](DefaultingPyMlirContext context) {
409  MlirType t = mlirTF32TypeGet(context->get());
410  return PyTF32Type(context->getRef(), t);
411  },
412  nb::arg("context") = nb::none(), "Create a tf32 type.");
413  }
414 };
415 
416 /// Floating Point Type subclass - F32Type.
417 class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
418 public:
419  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
420  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
422  static constexpr const char *pyClassName = "F32Type";
424 
425  static void bindDerived(ClassTy &c) {
426  c.def_static(
427  "get",
428  [](DefaultingPyMlirContext context) {
429  MlirType t = mlirF32TypeGet(context->get());
430  return PyF32Type(context->getRef(), t);
431  },
432  nb::arg("context") = nb::none(), "Create a f32 type.");
433  }
434 };
435 
436 /// Floating Point Type subclass - F64Type.
437 class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
438 public:
439  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
440  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
442  static constexpr const char *pyClassName = "F64Type";
444 
445  static void bindDerived(ClassTy &c) {
446  c.def_static(
447  "get",
448  [](DefaultingPyMlirContext context) {
449  MlirType t = mlirF64TypeGet(context->get());
450  return PyF64Type(context->getRef(), t);
451  },
452  nb::arg("context") = nb::none(), "Create a f64 type.");
453  }
454 };
455 
456 /// None Type subclass - NoneType.
457 class PyNoneType : public PyConcreteType<PyNoneType> {
458 public:
459  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
460  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
462  static constexpr const char *pyClassName = "NoneType";
464 
465  static void bindDerived(ClassTy &c) {
466  c.def_static(
467  "get",
468  [](DefaultingPyMlirContext context) {
469  MlirType t = mlirNoneTypeGet(context->get());
470  return PyNoneType(context->getRef(), t);
471  },
472  nb::arg("context") = nb::none(), "Create a none type.");
473  }
474 };
475 
476 /// Complex Type subclass - ComplexType.
477 class PyComplexType : public PyConcreteType<PyComplexType> {
478 public:
479  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
480  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
482  static constexpr const char *pyClassName = "ComplexType";
484 
485  static void bindDerived(ClassTy &c) {
486  c.def_static(
487  "get",
488  [](PyType &elementType) {
489  // The element must be a floating point or integer scalar type.
490  if (mlirTypeIsAIntegerOrFloat(elementType)) {
491  MlirType t = mlirComplexTypeGet(elementType);
492  return PyComplexType(elementType.getContext(), t);
493  }
494  throw nb::value_error(
495  (Twine("invalid '") +
496  nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
497  "' and expected floating point or integer type.")
498  .str()
499  .c_str());
500  },
501  "Create a complex type");
502  c.def_prop_ro(
503  "element_type",
504  [](PyComplexType &self) {
505  return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
506  .maybeDownCast();
507  },
508  "Returns element type.");
509  }
510 };
511 
512 } // namespace
513 
514 // Shaped Type Interface - ShapedType
516  c.def_prop_ro(
517  "element_type",
518  [](PyShapedType &self) {
519  return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
520  .maybeDownCast();
521  },
522  "Returns the element type of the shaped type.");
523  c.def_prop_ro(
524  "has_rank",
525  [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
526  "Returns whether the given shaped type is ranked.");
527  c.def_prop_ro(
528  "rank",
529  [](PyShapedType &self) {
530  self.requireHasRank();
531  return mlirShapedTypeGetRank(self);
532  },
533  "Returns the rank of the given ranked shaped type.");
534  c.def_prop_ro(
535  "has_static_shape",
536  [](PyShapedType &self) -> bool {
537  return mlirShapedTypeHasStaticShape(self);
538  },
539  "Returns whether the given shaped type has a static shape.");
540  c.def(
541  "is_dynamic_dim",
542  [](PyShapedType &self, intptr_t dim) -> bool {
543  self.requireHasRank();
544  return mlirShapedTypeIsDynamicDim(self, dim);
545  },
546  nb::arg("dim"),
547  "Returns whether the dim-th dimension of the given shaped type is "
548  "dynamic.");
549  c.def(
550  "is_static_dim",
551  [](PyShapedType &self, intptr_t dim) -> bool {
552  self.requireHasRank();
553  return mlirShapedTypeIsStaticDim(self, dim);
554  },
555  nb::arg("dim"),
556  "Returns whether the dim-th dimension of the given shaped type is "
557  "static.");
558  c.def(
559  "get_dim_size",
560  [](PyShapedType &self, intptr_t dim) {
561  self.requireHasRank();
562  return mlirShapedTypeGetDimSize(self, dim);
563  },
564  nb::arg("dim"),
565  "Returns the dim-th dimension of the given ranked shaped type.");
566  c.def_static(
567  "is_dynamic_size",
568  [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
569  nb::arg("dim_size"),
570  "Returns whether the given dimension size indicates a dynamic "
571  "dimension.");
572  c.def_static(
573  "is_static_size",
574  [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
575  nb::arg("dim_size"),
576  "Returns whether the given dimension size indicates a static "
577  "dimension.");
578  c.def(
579  "is_dynamic_stride_or_offset",
580  [](PyShapedType &self, int64_t val) -> bool {
581  self.requireHasRank();
583  },
584  nb::arg("dim_size"),
585  "Returns whether the given value is used as a placeholder for dynamic "
586  "strides and offsets in shaped types.");
587  c.def(
588  "is_static_stride_or_offset",
589  [](PyShapedType &self, int64_t val) -> bool {
590  self.requireHasRank();
592  },
593  nb::arg("dim_size"),
594  "Returns whether the given shaped type stride or offset value is "
595  "statically-sized.");
596  c.def_prop_ro(
597  "shape",
598  [](PyShapedType &self) {
599  self.requireHasRank();
600 
601  std::vector<int64_t> shape;
602  int64_t rank = mlirShapedTypeGetRank(self);
603  shape.reserve(rank);
604  for (int64_t i = 0; i < rank; ++i)
605  shape.push_back(mlirShapedTypeGetDimSize(self, i));
606  return shape;
607  },
608  "Returns the shape of the ranked shaped type as a list of integers.");
609  c.def_static(
610  "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
611  "Returns the value used to indicate dynamic dimensions in shaped "
612  "types.");
613  c.def_static(
614  "get_dynamic_stride_or_offset",
615  []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
616  "Returns the value used to indicate dynamic strides or offsets in "
617  "shaped types.");
618 }
619 
620 void mlir::PyShapedType::requireHasRank() {
621  if (!mlirShapedTypeHasRank(*this)) {
622  throw nb::value_error(
623  "calling this method requires that the type has a rank.");
624  }
625 }
626 
629 
630 namespace {
631 
632 /// Vector Type subclass - VectorType.
633 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
634 public:
635  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
636  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
638  static constexpr const char *pyClassName = "VectorType";
640 
641  static void bindDerived(ClassTy &c) {
642  c.def_static("get", &PyVectorType::get, nb::arg("shape"),
643  nb::arg("element_type"), nb::kw_only(),
644  nb::arg("scalable") = nb::none(),
645  nb::arg("scalable_dims") = nb::none(),
646  nb::arg("loc") = nb::none(), "Create a vector type")
647  .def_prop_ro(
648  "scalable",
649  [](MlirType self) { return mlirVectorTypeIsScalable(self); })
650  .def_prop_ro("scalable_dims", [](MlirType self) {
651  std::vector<bool> scalableDims;
652  size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
653  scalableDims.reserve(rank);
654  for (size_t i = 0; i < rank; ++i)
655  scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
656  return scalableDims;
657  });
658  }
659 
660 private:
661  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
662  std::optional<nb::list> scalable,
663  std::optional<std::vector<int64_t>> scalableDims,
664  DefaultingPyLocation loc) {
665  if (scalable && scalableDims) {
666  throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
667  "are mutually exclusive.");
668  }
669 
670  PyMlirContext::ErrorCapture errors(loc->getContext());
671  MlirType type;
672  if (scalable) {
673  if (scalable->size() != shape.size())
674  throw nb::value_error("Expected len(scalable) == len(shape).");
675 
676  SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
677  *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
678  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
679  scalableDimFlags.data(),
680  elementType);
681  } else if (scalableDims) {
682  SmallVector<bool> scalableDimFlags(shape.size(), false);
683  for (int64_t dim : *scalableDims) {
684  if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
685  throw nb::value_error("Scalable dimension index out of bounds.");
686  scalableDimFlags[dim] = true;
687  }
688  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
689  scalableDimFlags.data(),
690  elementType);
691  } else {
692  type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
693  elementType);
694  }
695  if (mlirTypeIsNull(type))
696  throw MLIRError("Invalid type", errors.take());
697  return PyVectorType(elementType.getContext(), type);
698  }
699 };
700 
701 /// Ranked Tensor Type subclass - RankedTensorType.
702 class PyRankedTensorType
703  : public PyConcreteType<PyRankedTensorType, PyShapedType> {
704 public:
705  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
706  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
708  static constexpr const char *pyClassName = "RankedTensorType";
710 
711  static void bindDerived(ClassTy &c) {
712  c.def_static(
713  "get",
714  [](std::vector<int64_t> shape, PyType &elementType,
715  std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
716  PyMlirContext::ErrorCapture errors(loc->getContext());
717  MlirType t = mlirRankedTensorTypeGetChecked(
718  loc, shape.size(), shape.data(), elementType,
719  encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
720  if (mlirTypeIsNull(t))
721  throw MLIRError("Invalid type", errors.take());
722  return PyRankedTensorType(elementType.getContext(), t);
723  },
724  nb::arg("shape"), nb::arg("element_type"),
725  nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
726  "Create a ranked tensor type");
727  c.def_prop_ro(
728  "encoding",
729  [](PyRankedTensorType &self)
730  -> std::optional<nb::typed<nb::object, PyAttribute>> {
731  MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
732  if (mlirAttributeIsNull(encoding))
733  return std::nullopt;
734  return nb::cast<nb::typed<nb::object, PyAttribute>>(
735  PyAttribute(self.getContext(), encoding).maybeDownCast());
736  });
737  }
738 };
739 
740 /// Unranked Tensor Type subclass - UnrankedTensorType.
741 class PyUnrankedTensorType
742  : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
743 public:
744  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
745  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
747  static constexpr const char *pyClassName = "UnrankedTensorType";
749 
750  static void bindDerived(ClassTy &c) {
751  c.def_static(
752  "get",
753  [](PyType &elementType, DefaultingPyLocation loc) {
754  PyMlirContext::ErrorCapture errors(loc->getContext());
755  MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
756  if (mlirTypeIsNull(t))
757  throw MLIRError("Invalid type", errors.take());
758  return PyUnrankedTensorType(elementType.getContext(), t);
759  },
760  nb::arg("element_type"), nb::arg("loc") = nb::none(),
761  "Create a unranked tensor type");
762  }
763 };
764 
765 /// Ranked MemRef Type subclass - MemRefType.
766 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
767 public:
768  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
769  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
771  static constexpr const char *pyClassName = "MemRefType";
773 
774  static void bindDerived(ClassTy &c) {
775  c.def_static(
776  "get",
777  [](std::vector<int64_t> shape, PyType &elementType,
778  PyAttribute *layout, PyAttribute *memorySpace,
779  DefaultingPyLocation loc) {
780  PyMlirContext::ErrorCapture errors(loc->getContext());
781  MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
782  MlirAttribute memSpaceAttr =
783  memorySpace ? *memorySpace : mlirAttributeGetNull();
784  MlirType t =
785  mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
786  shape.data(), layoutAttr, memSpaceAttr);
787  if (mlirTypeIsNull(t))
788  throw MLIRError("Invalid type", errors.take());
789  return PyMemRefType(elementType.getContext(), t);
790  },
791  nb::arg("shape"), nb::arg("element_type"),
792  nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
793  nb::arg("loc") = nb::none(), "Create a memref type")
794  .def_prop_ro(
795  "layout",
796  [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
797  return nb::cast<nb::typed<nb::object, PyAttribute>>(
799  .maybeDownCast());
800  },
801  "The layout of the MemRef type.")
802  .def(
803  "get_strides_and_offset",
804  [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
805  std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
806  int64_t offset;
808  self, strides.data(), &offset)))
809  throw std::runtime_error(
810  "Failed to extract strides and offset from memref.");
811  return {strides, offset};
812  },
813  "The strides and offset of the MemRef type.")
814  .def_prop_ro(
815  "affine_map",
816  [](PyMemRefType &self) -> PyAffineMap {
817  MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
818  return PyAffineMap(self.getContext(), map);
819  },
820  "The layout of the MemRef type as an affine map.")
821  .def_prop_ro(
822  "memory_space",
823  [](PyMemRefType &self)
824  -> std::optional<nb::typed<nb::object, PyAttribute>> {
825  MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
826  if (mlirAttributeIsNull(a))
827  return std::nullopt;
828  return nb::cast<nb::typed<nb::object, PyAttribute>>(
829  PyAttribute(self.getContext(), a).maybeDownCast());
830  },
831  "Returns the memory space of the given MemRef type.");
832  }
833 };
834 
835 /// Unranked MemRef Type subclass - UnrankedMemRefType.
836 class PyUnrankedMemRefType
837  : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
838 public:
839  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
840  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
842  static constexpr const char *pyClassName = "UnrankedMemRefType";
844 
845  static void bindDerived(ClassTy &c) {
846  c.def_static(
847  "get",
848  [](PyType &elementType, PyAttribute *memorySpace,
849  DefaultingPyLocation loc) {
850  PyMlirContext::ErrorCapture errors(loc->getContext());
851  MlirAttribute memSpaceAttr = {};
852  if (memorySpace)
853  memSpaceAttr = *memorySpace;
854 
855  MlirType t =
856  mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
857  if (mlirTypeIsNull(t))
858  throw MLIRError("Invalid type", errors.take());
859  return PyUnrankedMemRefType(elementType.getContext(), t);
860  },
861  nb::arg("element_type"), nb::arg("memory_space").none(),
862  nb::arg("loc") = nb::none(), "Create a unranked memref type")
863  .def_prop_ro(
864  "memory_space",
865  [](PyUnrankedMemRefType &self)
866  -> std::optional<nb::typed<nb::object, PyAttribute>> {
867  MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
868  if (mlirAttributeIsNull(a))
869  return std::nullopt;
870  return nb::cast<nb::typed<nb::object, PyAttribute>>(
871  PyAttribute(self.getContext(), a).maybeDownCast());
872  },
873  "Returns the memory space of the given Unranked MemRef type.");
874  }
875 };
876 
877 /// Tuple Type subclass - TupleType.
878 class PyTupleType : public PyConcreteType<PyTupleType> {
879 public:
880  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
881  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
883  static constexpr const char *pyClassName = "TupleType";
885 
886  static void bindDerived(ClassTy &c) {
887  c.def_static(
888  "get_tuple",
889  [](const std::vector<PyType> &elements,
890  DefaultingPyMlirContext context) {
891  std::vector<MlirType> mlirElements;
892  mlirElements.reserve(elements.size());
893  for (const auto &element : elements)
894  mlirElements.push_back(element.get());
895  MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
896  mlirElements.data());
897  return PyTupleType(context->getRef(), t);
898  },
899  nb::arg("elements"), nb::arg("context") = nb::none(),
900  "Create a tuple type");
901  c.def(
902  "get_type",
903  [](PyTupleType &self, intptr_t pos) {
904  return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
905  .maybeDownCast();
906  },
907  nb::arg("pos"), "Returns the pos-th type in the tuple type.");
908  c.def_prop_ro(
909  "num_types",
910  [](PyTupleType &self) -> intptr_t {
911  return mlirTupleTypeGetNumTypes(self);
912  },
913  "Returns the number of types contained in a tuple.");
914  }
915 };
916 
917 /// Function type.
918 class PyFunctionType : public PyConcreteType<PyFunctionType> {
919 public:
920  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
921  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
923  static constexpr const char *pyClassName = "FunctionType";
925 
926  static void bindDerived(ClassTy &c) {
927  c.def_static(
928  "get",
929  [](std::vector<PyType> inputs, std::vector<PyType> results,
930  DefaultingPyMlirContext context) {
931  std::vector<MlirType> mlirInputs;
932  mlirInputs.reserve(inputs.size());
933  for (const auto &input : inputs)
934  mlirInputs.push_back(input.get());
935  std::vector<MlirType> mlirResults;
936  mlirResults.reserve(results.size());
937  for (const auto &result : results)
938  mlirResults.push_back(result.get());
939 
940  MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
941  mlirInputs.data(), results.size(),
942  mlirResults.data());
943  return PyFunctionType(context->getRef(), t);
944  },
945  nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
946  "Gets a FunctionType from a list of input and result types");
947  c.def_prop_ro(
948  "inputs",
949  [](PyFunctionType &self) {
950  MlirType t = self;
951  nb::list types;
952  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
953  ++i) {
954  types.append(mlirFunctionTypeGetInput(t, i));
955  }
956  return types;
957  },
958  "Returns the list of input types in the FunctionType.");
959  c.def_prop_ro(
960  "results",
961  [](PyFunctionType &self) {
962  nb::list types;
963  for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
964  ++i) {
965  types.append(mlirFunctionTypeGetResult(self, i));
966  }
967  return types;
968  },
969  "Returns the list of result types in the FunctionType.");
970  }
971 };
972 
973 static MlirStringRef toMlirStringRef(const std::string &s) {
974  return mlirStringRefCreate(s.data(), s.size());
975 }
976 
977 /// Opaque Type subclass - OpaqueType.
978 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
979 public:
980  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
981  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
983  static constexpr const char *pyClassName = "OpaqueType";
985 
986  static void bindDerived(ClassTy &c) {
987  c.def_static(
988  "get",
989  [](const std::string &dialectNamespace, const std::string &typeData,
990  DefaultingPyMlirContext context) {
991  MlirType type = mlirOpaqueTypeGet(context->get(),
992  toMlirStringRef(dialectNamespace),
993  toMlirStringRef(typeData));
994  return PyOpaqueType(context->getRef(), type);
995  },
996  nb::arg("dialect_namespace"), nb::arg("buffer"),
997  nb::arg("context") = nb::none(),
998  "Create an unregistered (opaque) dialect type.");
999  c.def_prop_ro(
1000  "dialect_namespace",
1001  [](PyOpaqueType &self) {
1003  return nb::str(stringRef.data, stringRef.length);
1004  },
1005  "Returns the dialect namespace for the Opaque type as a string.");
1006  c.def_prop_ro(
1007  "data",
1008  [](PyOpaqueType &self) {
1009  MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
1010  return nb::str(stringRef.data, stringRef.length);
1011  },
1012  "Returns the data for the Opaque type as a string.");
1013  }
1014 };
1015 
1016 } // namespace
1017 
1018 void mlir::python::populateIRTypes(nb::module_ &m) {
1019  PyIntegerType::bind(m);
1020  PyFloatType::bind(m);
1021  PyIndexType::bind(m);
1022  PyFloat4E2M1FNType::bind(m);
1023  PyFloat6E2M3FNType::bind(m);
1024  PyFloat6E3M2FNType::bind(m);
1025  PyFloat8E4M3FNType::bind(m);
1026  PyFloat8E5M2Type::bind(m);
1027  PyFloat8E4M3Type::bind(m);
1028  PyFloat8E4M3FNUZType::bind(m);
1029  PyFloat8E4M3B11FNUZType::bind(m);
1030  PyFloat8E5M2FNUZType::bind(m);
1031  PyFloat8E3M4Type::bind(m);
1032  PyFloat8E8M0FNUType::bind(m);
1033  PyBF16Type::bind(m);
1034  PyF16Type::bind(m);
1035  PyTF32Type::bind(m);
1036  PyF32Type::bind(m);
1037  PyF64Type::bind(m);
1038  PyNoneType::bind(m);
1039  PyComplexType::bind(m);
1040  PyShapedType::bind(m);
1041  PyVectorType::bind(m);
1042  PyRankedTensorType::bind(m);
1043  PyUnrankedTensorType::bind(m);
1044  PyMemRefType::bind(m);
1045  PyUnrankedMemRefType::bind(m);
1046  PyTupleType::bind(m);
1047  PyFunctionType::bind(m);
1048  PyOpaqueType::bind(m);
1049 }
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:223
static MLIRContext * getContext(OpFoldResult val)
Shaped Type Interface - ShapedType.
Definition: IRTypes.h:17
static const IsAFunctionTy isaFunction
Definition: IRTypes.h:19
static void bindDerived(ClassTy &c)
Definition: IRTypes.cpp:515
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:293
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:500
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:273
ReferrentTy * get() const
Definition: NanobindUtils.h:60
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1008
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2136
CRTP base classes for Python types that subclass Type and should be castable from it (i....
Definition: IRModule.h:930
static void bind(nanobind::module_ &m)
Definition: IRModule.h:959
Wrapper around the generic MlirType.
Definition: IRModule.h:878
nanobind::object maybeDownCast()
Definition: IRCore.cpp:2182
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void)
Returns the typeID of a Float8E8M0FNU type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void)
Returns the typeID of a Float16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type)
Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void)
Returns the typeID of a Float4E2M1FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void)
Returns the typeID of a TF32 type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type)
Checks whether the given type is an f8E8M0FNU type.
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void)
Returns the typeID of a None type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void)
Returns the typeID of a Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void)
Returns the typeID of a Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type)
Checks whether the given type is an f6E3M2FN type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is static.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx)
Creates an f8E3M4 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type)
Checks whether the given type is an f8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val)
Checks whether the given dimension value of a stride or an offset is statically-sized.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void)
Returns the typeID of an UnrankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx)
Creates an f8E8M0FNU type in the given context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void)
Returns the typeID of a MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates an f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type)
Checks whether the given type is an f8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx)
Creates an f6E2M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void)
Returns the typeID of a Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void)
Returns the typeID of a Float64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void)
Returns the typeID of a Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void)
Returns the typeID of a Tuple type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx)
Creates an f8E4M3 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void)
Returns the typeID of a Float8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void)
Returns the typeID of a Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type)
Checks whether the given type is an f6E2M3FN type.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides, int64_t *offset)
Returns the strides of the MemRef if the layout map is in strided form.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void)
Returns the typeID of a Float8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type)
Checks whether the given type is a floating-point type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void)
Returns the typeID of a BFloat16 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void)
Returns the typeID of a RankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim)
Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type)
Checks whether the given type is a TF32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void)
Returns the typeID of a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void)
Returns the typeID of an Integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx)
Creates an f6E3M2FN type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void)
Returns the typeID of an Opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void)
Returns the typeID of an Index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void)
Returns the typeID of an UnrankedMemRef type.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void)
Returns the typeID of a Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void)
Returns the typeID of a Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size)
Checks whether the given shaped type dimension value is statically-sized.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type)
Checks whether the given vector type is scalable, i.e., has at least one scalable dimension.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type)
Checks whether the given type is an f4E2M1FN type.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void)
Returns the typeID of a Function type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void)
Returns the value indicating a dynamic size in a shaped type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void)
Returns the typeID of a Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void)
Returns the typeID of a Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx)
Creates a TF32 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx)
Creates an f4E2M1FN type in the given context.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1179
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1144
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populateIRTypes(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1316
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:409