MLIR  20.0.0git
IRTypes.cpp
Go to the documentation of this file.
1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // clang-format off
10 #include "IRModule.h"
12 // clang-format on
13 
14 #include <optional>
15 
16 #include "IRModule.h"
17 #include "NanobindUtils.h"
19 #include "mlir-c/BuiltinTypes.h"
20 #include "mlir-c/Support.h"
21 
22 namespace nb = nanobind;
23 using namespace mlir;
24 using namespace mlir::python;
25 
26 using llvm::SmallVector;
27 using llvm::Twine;
28 
29 namespace {
30 
31 /// Checks whether the given type is an integer or float type.
32 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34  mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35 }
36 
/// Python binding for the builtin `IntegerType`.
///
/// Exposes static factories for signless / signed / unsigned integers of a
/// given bit width, plus a read-only `width` and three signedness predicates.
/// Passing `context=None` resolves to the ambient context through
/// DefaultingPyMlirContext.
37 class PyIntegerType : public PyConcreteType<PyIntegerType> {
38 public:
39  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
42  static constexpr const char *pyClassName = "IntegerType";
44 
// Registers the Python-visible static factories and properties on `c`.
45  static void bindDerived(ClassTy &c) {
46  c.def_static(
47  "get_signless",
48  [](unsigned width, DefaultingPyMlirContext context) {
49  MlirType t = mlirIntegerTypeGet(context->get(), width);
50  return PyIntegerType(context->getRef(), t);
51  },
52  nb::arg("width"), nb::arg("context").none() = nb::none(),
53  "Create a signless integer type");
54  c.def_static(
55  "get_signed",
56  [](unsigned width, DefaultingPyMlirContext context) {
57  MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58  return PyIntegerType(context->getRef(), t);
59  },
60  nb::arg("width"), nb::arg("context").none() = nb::none(),
61  "Create a signed integer type");
62  c.def_static(
63  "get_unsigned",
64  [](unsigned width, DefaultingPyMlirContext context) {
65  MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66  return PyIntegerType(context->getRef(), t);
67  },
68  nb::arg("width"), nb::arg("context").none() = nb::none(),
69  "Create an unsigned integer type");
70  c.def_prop_ro(
71  "width",
72  [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73  "Returns the width of the integer type");
// The three signedness predicates are exposed as read-only properties; the
// explicit `-> bool` return makes each surface as a Python bool.
74  c.def_prop_ro(
75  "is_signless",
76  [](PyIntegerType &self) -> bool {
77  return mlirIntegerTypeIsSignless(self);
78  },
79  "Returns whether this is a signless integer");
80  c.def_prop_ro(
81  "is_signed",
82  [](PyIntegerType &self) -> bool {
83  return mlirIntegerTypeIsSigned(self);
84  },
85  "Returns whether this is a signed integer");
86  c.def_prop_ro(
87  "is_unsigned",
88  [](PyIntegerType &self) -> bool {
89  return mlirIntegerTypeIsUnsigned(self);
90  },
91  "Returns whether this is an unsigned integer");
92  }
93 };
94 
95 /// Index Type subclass - IndexType.
/// Binds `IndexType.get(context=None)`, a thin wrapper over mlirIndexTypeGet.
96 class PyIndexType : public PyConcreteType<PyIndexType> {
97 public:
98  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
101  static constexpr const char *pyClassName = "IndexType";
103 
104  static void bindDerived(ClassTy &c) {
// NOTE(review): the docstring below reads "Create a index type."; "an index
// type" was probably intended (changing it is a runtime-string change).
105  c.def_static(
106  "get",
107  [](DefaultingPyMlirContext context) {
108  MlirType t = mlirIndexTypeGet(context->get());
109  return PyIndexType(context->getRef(), t);
110  },
111  nb::arg("context").none() = nb::none(), "Create a index type.");
112  }
113 };
114 
/// Python binding for the builtin floating-point types: any type for which
/// mlirTypeIsAFloat holds. It only adds a read-only `width` property; the
/// concrete float classes in this file pass PyFloatType as their base type.
115 class PyFloatType : public PyConcreteType<PyFloatType> {
116 public:
117  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118  static constexpr const char *pyClassName = "FloatType";
120 
121  static void bindDerived(ClassTy &c) {
122  c.def_prop_ro(
123  "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124  "Returns the width of the floating-point type");
125  }
126 };
127 
128 /// Floating Point Type subclass - Float4E2M1FNType.
/// Binds `Float4E2M1FNType.get(context=None)`, a thin wrapper over
/// mlirFloat4E2M1FNTypeGet; PyFloatType is passed as the base type.
129 class PyFloat4E2M1FNType
130  : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131 public:
132  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
135  static constexpr const char *pyClassName = "Float4E2M1FNType";
137 
138  static void bindDerived(ClassTy &c) {
139  c.def_static(
140  "get",
141  [](DefaultingPyMlirContext context) {
142  MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143  return PyFloat4E2M1FNType(context->getRef(), t);
144  },
145  nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
146  }
147 };
148 
149 /// Floating Point Type subclass - Float6E2M3FNType.
/// Binds `Float6E2M3FNType.get(context=None)`, a thin wrapper over
/// mlirFloat6E2M3FNTypeGet; PyFloatType is passed as the base type.
150 class PyFloat6E2M3FNType
151  : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152 public:
153  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
156  static constexpr const char *pyClassName = "Float6E2M3FNType";
158 
159  static void bindDerived(ClassTy &c) {
160  c.def_static(
161  "get",
162  [](DefaultingPyMlirContext context) {
163  MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164  return PyFloat6E2M3FNType(context->getRef(), t);
165  },
166  nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
167  }
168 };
169 
170 /// Floating Point Type subclass - Float6E3M2FNType.
/// Binds `Float6E3M2FNType.get(context=None)`, a thin wrapper over
/// mlirFloat6E3M2FNTypeGet; PyFloatType is passed as the base type.
171 class PyFloat6E3M2FNType
172  : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173 public:
174  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
177  static constexpr const char *pyClassName = "Float6E3M2FNType";
179 
180  static void bindDerived(ClassTy &c) {
181  c.def_static(
182  "get",
183  [](DefaultingPyMlirContext context) {
184  MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185  return PyFloat6E3M2FNType(context->getRef(), t);
186  },
187  nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
188  }
189 };
190 
191 /// Floating Point Type subclass - Float8E4M3FNType.
/// Binds `Float8E4M3FNType.get(context=None)`, a thin wrapper over
/// mlirFloat8E4M3FNTypeGet; PyFloatType is passed as the base type.
192 class PyFloat8E4M3FNType
193  : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194 public:
195  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
198  static constexpr const char *pyClassName = "Float8E4M3FNType";
200 
201  static void bindDerived(ClassTy &c) {
202  c.def_static(
203  "get",
204  [](DefaultingPyMlirContext context) {
205  MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206  return PyFloat8E4M3FNType(context->getRef(), t);
207  },
208  nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
209  }
210 };
211 
212 /// Floating Point Type subclass - Float8E5M2Type.
/// Binds `Float8E5M2Type.get(context=None)`, a thin wrapper over
/// mlirFloat8E5M2TypeGet; PyFloatType is passed as the base type.
213 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214 public:
215  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
218  static constexpr const char *pyClassName = "Float8E5M2Type";
220 
221  static void bindDerived(ClassTy &c) {
222  c.def_static(
223  "get",
224  [](DefaultingPyMlirContext context) {
225  MlirType t = mlirFloat8E5M2TypeGet(context->get());
226  return PyFloat8E5M2Type(context->getRef(), t);
227  },
228  nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
229  }
230 };
231 
232 /// Floating Point Type subclass - Float8E4M3Type.
/// Binds `Float8E4M3Type.get(context=None)`, a thin wrapper over
/// mlirFloat8E4M3TypeGet; PyFloatType is passed as the base type.
233 class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234 public:
235  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
238  static constexpr const char *pyClassName = "Float8E4M3Type";
240 
241  static void bindDerived(ClassTy &c) {
242  c.def_static(
243  "get",
244  [](DefaultingPyMlirContext context) {
245  MlirType t = mlirFloat8E4M3TypeGet(context->get());
246  return PyFloat8E4M3Type(context->getRef(), t);
247  },
248  nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
249  }
250 };
251 
252 /// Floating Point Type subclass - Float8E4M3FNUZ.
/// Binds `Float8E4M3FNUZType.get(context=None)`, a thin wrapper over
/// mlirFloat8E4M3FNUZTypeGet; PyFloatType is passed as the base type.
253 class PyFloat8E4M3FNUZType
254  : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255 public:
256  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
259  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
261 
262  static void bindDerived(ClassTy &c) {
263  c.def_static(
264  "get",
265  [](DefaultingPyMlirContext context) {
266  MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267  return PyFloat8E4M3FNUZType(context->getRef(), t);
268  },
269  nb::arg("context").none() = nb::none(),
270  "Create a float8_e4m3fnuz type.");
271  }
272 };
273 
274 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
/// Binds `Float8E4M3B11FNUZType.get(context=None)`, a thin wrapper over
/// mlirFloat8E4M3B11FNUZTypeGet; PyFloatType is passed as the base type.
275 class PyFloat8E4M3B11FNUZType
276  : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
277 public:
278  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
279  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
281  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
283 
284  static void bindDerived(ClassTy &c) {
285  c.def_static(
286  "get",
287  [](DefaultingPyMlirContext context) {
288  MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
289  return PyFloat8E4M3B11FNUZType(context->getRef(), t);
290  },
291  nb::arg("context").none() = nb::none(),
292  "Create a float8_e4m3b11fnuz type.");
293  }
294 };
295 
296 /// Floating Point Type subclass - Float8E5M2FNUZ.
/// Binds `Float8E5M2FNUZType.get(context=None)`, a thin wrapper over
/// mlirFloat8E5M2FNUZTypeGet; PyFloatType is passed as the base type.
297 class PyFloat8E5M2FNUZType
298  : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
299 public:
300  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
301  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
303  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
305 
306  static void bindDerived(ClassTy &c) {
307  c.def_static(
308  "get",
309  [](DefaultingPyMlirContext context) {
310  MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
311  return PyFloat8E5M2FNUZType(context->getRef(), t);
312  },
313  nb::arg("context").none() = nb::none(),
314  "Create a float8_e5m2fnuz type.");
315  }
316 };
317 
318 /// Floating Point Type subclass - Float8E3M4Type.
/// Binds `Float8E3M4Type.get(context=None)`, a thin wrapper over
/// mlirFloat8E3M4TypeGet; PyFloatType is passed as the base type.
319 class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
320 public:
321  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
322  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
324  static constexpr const char *pyClassName = "Float8E3M4Type";
326 
327  static void bindDerived(ClassTy &c) {
328  c.def_static(
329  "get",
330  [](DefaultingPyMlirContext context) {
331  MlirType t = mlirFloat8E3M4TypeGet(context->get());
332  return PyFloat8E3M4Type(context->getRef(), t);
333  },
334  nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
335  }
336 };
337 
338 /// Floating Point Type subclass - Float8E8M0FNUType.
/// Binds `Float8E8M0FNUType.get(context=None)`, a thin wrapper over
/// mlirFloat8E8M0FNUTypeGet; PyFloatType is passed as the base type.
339 class PyFloat8E8M0FNUType
340  : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
341 public:
342  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
343  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
345  static constexpr const char *pyClassName = "Float8E8M0FNUType";
347 
348  static void bindDerived(ClassTy &c) {
349  c.def_static(
350  "get",
351  [](DefaultingPyMlirContext context) {
352  MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
353  return PyFloat8E8M0FNUType(context->getRef(), t);
354  },
355  nb::arg("context").none() = nb::none(),
356  "Create a float8_e8m0fnu type.");
357  }
358 };
359 
360 /// Floating Point Type subclass - BF16Type.
/// Binds `BF16Type.get(context=None)`, a thin wrapper over mlirBF16TypeGet;
/// PyFloatType is passed as the base type.
361 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
362 public:
363  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
364  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
366  static constexpr const char *pyClassName = "BF16Type";
368 
369  static void bindDerived(ClassTy &c) {
370  c.def_static(
371  "get",
372  [](DefaultingPyMlirContext context) {
373  MlirType t = mlirBF16TypeGet(context->get());
374  return PyBF16Type(context->getRef(), t);
375  },
376  nb::arg("context").none() = nb::none(), "Create a bf16 type.");
377  }
378 };
379 
380 /// Floating Point Type subclass - F16Type.
/// Binds `F16Type.get(context=None)`, a thin wrapper over mlirF16TypeGet;
/// PyFloatType is passed as the base type.
381 class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
382 public:
383  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
384  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
386  static constexpr const char *pyClassName = "F16Type";
388 
389  static void bindDerived(ClassTy &c) {
390  c.def_static(
391  "get",
392  [](DefaultingPyMlirContext context) {
393  MlirType t = mlirF16TypeGet(context->get());
394  return PyF16Type(context->getRef(), t);
395  },
396  nb::arg("context").none() = nb::none(), "Create a f16 type.");
397  }
398 };
399 
400 /// Floating Point Type subclass - TF32Type.
/// Binds `FloatTF32Type.get(context=None)`, a thin wrapper over
/// mlirTF32TypeGet; PyFloatType is passed as the base type.
/// NOTE(review): unlike its siblings, the Python class name ("FloatTF32Type")
/// does not match the C++ class name minus the `Py` prefix; it is
/// user-visible, so keep it stable.
401 class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
402 public:
403  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
404  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
406  static constexpr const char *pyClassName = "FloatTF32Type";
408 
409  static void bindDerived(ClassTy &c) {
410  c.def_static(
411  "get",
412  [](DefaultingPyMlirContext context) {
413  MlirType t = mlirTF32TypeGet(context->get());
414  return PyTF32Type(context->getRef(), t);
415  },
416  nb::arg("context").none() = nb::none(), "Create a tf32 type.");
417  }
418 };
419 
420 /// Floating Point Type subclass - F32Type.
/// Binds `F32Type.get(context=None)`, a thin wrapper over mlirF32TypeGet;
/// PyFloatType is passed as the base type.
421 class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
422 public:
423  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
424  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
426  static constexpr const char *pyClassName = "F32Type";
428 
429  static void bindDerived(ClassTy &c) {
430  c.def_static(
431  "get",
432  [](DefaultingPyMlirContext context) {
433  MlirType t = mlirF32TypeGet(context->get());
434  return PyF32Type(context->getRef(), t);
435  },
436  nb::arg("context").none() = nb::none(), "Create a f32 type.");
437  }
438 };
439 
440 /// Floating Point Type subclass - F64Type.
/// Binds `F64Type.get(context=None)`, a thin wrapper over mlirF64TypeGet;
/// PyFloatType is passed as the base type.
441 class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
442 public:
443  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
444  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
446  static constexpr const char *pyClassName = "F64Type";
448 
449  static void bindDerived(ClassTy &c) {
450  c.def_static(
451  "get",
452  [](DefaultingPyMlirContext context) {
453  MlirType t = mlirF64TypeGet(context->get());
454  return PyF64Type(context->getRef(), t);
455  },
456  nb::arg("context").none() = nb::none(), "Create a f64 type.");
457  }
458 };
459 
460 /// None Type subclass - NoneType.
/// Binds `NoneType.get(context=None)`, a thin wrapper over mlirNoneTypeGet.
461 class PyNoneType : public PyConcreteType<PyNoneType> {
462 public:
463  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
464  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
466  static constexpr const char *pyClassName = "NoneType";
468 
469  static void bindDerived(ClassTy &c) {
470  c.def_static(
471  "get",
472  [](DefaultingPyMlirContext context) {
473  MlirType t = mlirNoneTypeGet(context->get());
474  return PyNoneType(context->getRef(), t);
475  },
476  nb::arg("context").none() = nb::none(), "Create a none type.");
477  }
478 };
479 
480 /// Complex Type subclass - ComplexType.
/// Binds `ComplexType.get(element_type)` and a read-only `element_type`
/// property. `get` takes no context argument; the new type is bound to the
/// element type's context.
481 class PyComplexType : public PyConcreteType<PyComplexType> {
482 public:
483  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
484  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
486  static constexpr const char *pyClassName = "ComplexType";
488 
489  static void bindDerived(ClassTy &c) {
490  c.def_static(
491  "get",
492  [](PyType &elementType) {
493  // The element must be a floating point or integer scalar type.
// Element types rejected by mlirTypeIsAIntegerOrFloat raise ValueError
// (nb::value_error) whose message embeds the element type's repr().
494  if (mlirTypeIsAIntegerOrFloat(elementType)) {
495  MlirType t = mlirComplexTypeGet(elementType);
496  return PyComplexType(elementType.getContext(), t);
497  }
498  throw nb::value_error(
499  (Twine("invalid '") +
500  nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
501  "' and expected floating point or integer type.")
502  .str()
503  .c_str());
504  },
505  "Create a complex type");
506  c.def_prop_ro(
507  "element_type",
508  [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
509  "Returns element type.");
510  }
511 };
512 
513 } // namespace
514 
515 // Shaped Type Interface - ShapedType
// Registers the properties and methods shared by every shaped type binding in
// this file (vector, ranked/unranked tensor, ranked/unranked memref).
// NOTE(review): the enclosing function header is not visible in this listing;
// the index at the end of the page places PyShapedType's
// `bindDerived(ClassTy &c)` here.
// Accessors that only make sense with a rank call requireHasRank() first,
// which throws ValueError for unranked types.
517  c.def_prop_ro(
518  "element_type",
519  [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
520  "Returns the element type of the shaped type.");
521  c.def_prop_ro(
522  "has_rank",
523  [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
524  "Returns whether the given shaped type is ranked.");
525  c.def_prop_ro(
526  "rank",
527  [](PyShapedType &self) {
528  self.requireHasRank();
529  return mlirShapedTypeGetRank(self);
530  },
531  "Returns the rank of the given ranked shaped type.");
532  c.def_prop_ro(
533  "has_static_shape",
534  [](PyShapedType &self) -> bool {
535  return mlirShapedTypeHasStaticShape(self);
536  },
537  "Returns whether the given shaped type has a static shape.");
538  c.def(
539  "is_dynamic_dim",
540  [](PyShapedType &self, intptr_t dim) -> bool {
541  self.requireHasRank();
542  return mlirShapedTypeIsDynamicDim(self, dim);
543  },
544  nb::arg("dim"),
545  "Returns whether the dim-th dimension of the given shaped type is "
546  "dynamic.");
547  c.def(
548  "get_dim_size",
549  [](PyShapedType &self, intptr_t dim) {
550  self.requireHasRank();
551  return mlirShapedTypeGetDimSize(self, dim);
552  },
553  nb::arg("dim"),
554  "Returns the dim-th dimension of the given ranked shaped type.");
// Static helper: tests a raw dimension size against the dynamic-size sentinel
// (no type instance needed).
555  c.def_static(
556  "is_dynamic_size",
557  [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
558  nb::arg("dim_size"),
559  "Returns whether the given dimension size indicates a dynamic "
560  "dimension.");
// NOTE(review): the return statement of this lambda is not visible in this
// listing. The Python keyword is named "dim_size" even though the value is a
// stride/offset; renaming it would break keyword callers.
561  c.def(
562  "is_dynamic_stride_or_offset",
563  [](PyShapedType &self, int64_t val) -> bool {
564  self.requireHasRank();
566  },
567  nb::arg("dim_size"),
568  "Returns whether the given value is used as a placeholder for dynamic "
569  "strides and offsets in shaped types.");
// Collects every dimension size (dynamic ones included as the sentinel value
// reported by mlirShapedTypeGetDimSize) into a Python list of ints.
570  c.def_prop_ro(
571  "shape",
572  [](PyShapedType &self) {
573  self.requireHasRank();
574 
575  std::vector<int64_t> shape;
576  int64_t rank = mlirShapedTypeGetRank(self);
577  shape.reserve(rank);
578  for (int64_t i = 0; i < rank; ++i)
579  shape.push_back(mlirShapedTypeGetDimSize(self, i));
580  return shape;
581  },
582  "Returns the shape of the ranked shaped type as a list of integers.");
583  c.def_static(
584  "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
585  "Returns the value used to indicate dynamic dimensions in shaped "
586  "types.");
587  c.def_static(
588  "get_dynamic_stride_or_offset",
589  []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
590  "Returns the value used to indicate dynamic strides or offsets in "
591  "shaped types.");
592 }
593 
594 void mlir::PyShapedType::requireHasRank() {
595  if (!mlirShapedTypeHasRank(*this)) {
596  throw nb::value_error(
597  "calling this method requires that the type has a rank.");
598  }
599 }
600 
603 
604 namespace {
605 
606 /// Vector Type subclass - VectorType.
/// Binds `VectorType.get(shape, element_type, *, scalable=None,
/// scalable_dims=None, loc=None)` plus read-only `scalable` and
/// `scalable_dims` properties. PyShapedType is passed as the base type.
607 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
608 public:
609  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
610  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
612  static constexpr const char *pyClassName = "VectorType";
614 
615  static void bindDerived(ClassTy &c) {
616  c.def_static("get", &PyVectorType::get, nb::arg("shape"),
617  nb::arg("element_type"), nb::kw_only(),
618  nb::arg("scalable").none() = nb::none(),
619  nb::arg("scalable_dims").none() = nb::none(),
620  nb::arg("loc").none() = nb::none(), "Create a vector type")
621  .def_prop_ro(
622  "scalable",
623  [](MlirType self) { return mlirVectorTypeIsScalable(self); })
// `scalable_dims` yields one bool per dimension, queried with
// mlirVectorTypeIsDimScalable for every index in [0, rank).
624  .def_prop_ro("scalable_dims", [](MlirType self) {
625  std::vector<bool> scalableDims;
626  size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
627  scalableDims.reserve(rank);
628  for (size_t i = 0; i < rank; ++i)
629  scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
630  return scalableDims;
631  });
632  }
633 
634 private:
// Implementation of `VectorType.get`.
// - `scalable`: one truthy flag per dimension; its length must equal len(shape).
// - `scalable_dims`: indices of the scalable dimensions; each must be in
//   [0, len(shape)).
// Passing both, a length mismatch, or an out-of-range index raises ValueError.
// Diagnostics emitted by the checked constructors are captured and re-raised
// as MLIRError("Invalid type", ...) when the result is a null type.
635  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
636  std::optional<nb::list> scalable,
637  std::optional<std::vector<int64_t>> scalableDims,
638  DefaultingPyLocation loc) {
639  if (scalable && scalableDims) {
640  throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
641  "are mutually exclusive.");
642  }
643 
644  PyMlirContext::ErrorCapture errors(loc->getContext());
645  MlirType type;
646  if (scalable) {
647  if (scalable->size() != shape.size())
648  throw nb::value_error("Expected len(scalable) == len(shape).");
649 
650  SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
651  *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
652  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
653  scalableDimFlags.data(),
654  elementType);
655  } else if (scalableDims) {
// Expand the index list into a dense per-dimension flag vector.
656  SmallVector<bool> scalableDimFlags(shape.size(), false);
657  for (int64_t dim : *scalableDims) {
658  if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
659  throw nb::value_error("Scalable dimension index out of bounds.");
660  scalableDimFlags[dim] = true;
661  }
662  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
663  scalableDimFlags.data(),
664  elementType);
665  } else {
666  type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
667  elementType);
668  }
669  if (mlirTypeIsNull(type))
670  throw MLIRError("Invalid type", errors.take());
671  return PyVectorType(elementType.getContext(), type);
672  }
673 };
674 
675 /// Ranked Tensor Type subclass - RankedTensorType.
/// Binds `RankedTensorType.get(shape, element_type, encoding=None, loc=None)`
/// and a read-only `encoding` property that returns None when the type has no
/// encoding attribute.
676 class PyRankedTensorType
677  : public PyConcreteType<PyRankedTensorType, PyShapedType> {
678 public:
679  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
680  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
682  static constexpr const char *pyClassName = "RankedTensorType";
684 
685  static void bindDerived(ClassTy &c) {
686  c.def_static(
687  "get",
688  [](std::vector<int64_t> shape, PyType &elementType,
689  std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
// Diagnostics from the checked constructor are captured and re-raised as
// MLIRError when it returns a null type. A missing encoding is passed as a
// null attribute.
690  PyMlirContext::ErrorCapture errors(loc->getContext());
691  MlirType t = mlirRankedTensorTypeGetChecked(
692  loc, shape.size(), shape.data(), elementType,
693  encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
694  if (mlirTypeIsNull(t))
695  throw MLIRError("Invalid type", errors.take());
696  return PyRankedTensorType(elementType.getContext(), t);
697  },
698  nb::arg("shape"), nb::arg("element_type"),
699  nb::arg("encoding").none() = nb::none(),
700  nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
701  c.def_prop_ro("encoding",
702  [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
703  MlirAttribute encoding =
705  if (mlirAttributeIsNull(encoding))
706  return std::nullopt;
707  return encoding;
708  });
709  }
710 };
711 
712 /// Unranked Tensor Type subclass - UnrankedTensorType.
/// Binds `UnrankedTensorType.get(element_type, loc=None)`. Diagnostics from
/// the checked constructor are captured and re-raised as MLIRError when it
/// returns a null type.
713 class PyUnrankedTensorType
714  : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
715 public:
716  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
717  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
719  static constexpr const char *pyClassName = "UnrankedTensorType";
721 
722  static void bindDerived(ClassTy &c) {
723  c.def_static(
724  "get",
725  [](PyType &elementType, DefaultingPyLocation loc) {
726  PyMlirContext::ErrorCapture errors(loc->getContext());
727  MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
728  if (mlirTypeIsNull(t))
729  throw MLIRError("Invalid type", errors.take());
730  return PyUnrankedTensorType(elementType.getContext(), t);
731  },
732  nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
733  "Create a unranked tensor type");
734  }
735 };
736 
737 /// Ranked MemRef Type subclass - MemRefType.
/// Binds `MemRefType.get(shape, element_type, layout=None, memory_space=None,
/// loc=None)` together with `layout`, `affine_map` and `memory_space`
/// properties and a `get_strides_and_offset()` method.
738 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
739 public:
740  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
741  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
743  static constexpr const char *pyClassName = "MemRefType";
745 
746  static void bindDerived(ClassTy &c) {
747  c.def_static(
748  "get",
749  [](std::vector<int64_t> shape, PyType &elementType,
750  PyAttribute *layout, PyAttribute *memorySpace,
751  DefaultingPyLocation loc) {
// Omitted (None) layout / memory space are passed as null attributes.
// Verification diagnostics are re-raised as MLIRError on failure.
752  PyMlirContext::ErrorCapture errors(loc->getContext());
753  MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
754  MlirAttribute memSpaceAttr =
755  memorySpace ? *memorySpace : mlirAttributeGetNull();
756  MlirType t =
757  mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
758  shape.data(), layoutAttr, memSpaceAttr);
759  if (mlirTypeIsNull(t))
760  throw MLIRError("Invalid type", errors.take());
761  return PyMemRefType(elementType.getContext(), t);
762  },
763  nb::arg("shape"), nb::arg("element_type"),
764  nb::arg("layout").none() = nb::none(),
765  nb::arg("memory_space").none() = nb::none(),
766  nb::arg("loc").none() = nb::none(), "Create a memref type")
767  .def_prop_ro(
768  "layout",
769  [](PyMemRefType &self) -> MlirAttribute {
770  return mlirMemRefTypeGetLayout(self);
771  },
772  "The layout of the MemRef type.")
// Returns a (strides, offset) pair; `strides` is sized to the memref's rank.
// A failed extraction throws std::runtime_error (presumably surfaced to
// Python as RuntimeError by nanobind).
773  .def(
774  "get_strides_and_offset",
775  [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
776  std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
777  int64_t offset;
779  self, strides.data(), &offset)))
780  throw std::runtime_error(
781  "Failed to extract strides and offset from memref.");
782  return {strides, offset};
783  },
784  "The strides and offset of the MemRef type.")
785  .def_prop_ro(
786  "affine_map",
787  [](PyMemRefType &self) -> PyAffineMap {
788  MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
789  return PyAffineMap(self.getContext(), map);
790  },
791  "The layout of the MemRef type as an affine map.")
792  .def_prop_ro(
793  "memory_space",
794  [](PyMemRefType &self) -> std::optional<MlirAttribute> {
795  MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
796  if (mlirAttributeIsNull(a))
797  return std::nullopt;
798  return a;
799  },
800  "Returns the memory space of the given MemRef type.");
801  }
802 };
803 
804 /// Unranked MemRef Type subclass - UnrankedMemRefType.
/// Binds `UnrankedMemRefType.get(element_type, memory_space, loc=None)` and a
/// `memory_space` property that returns None when the type has no memory
/// space.
/// NOTE(review): `memory_space` accepts None but declares no default (unlike
/// MemRefType.get), so callers appear to be required to pass it explicitly.
805 class PyUnrankedMemRefType
806  : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
807 public:
808  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
809  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
811  static constexpr const char *pyClassName = "UnrankedMemRefType";
813 
814  static void bindDerived(ClassTy &c) {
815  c.def_static(
816  "get",
817  [](PyType &elementType, PyAttribute *memorySpace,
818  DefaultingPyLocation loc) {
819  PyMlirContext::ErrorCapture errors(loc->getContext());
// A None memory space stays as a value-initialized (null) attribute.
820  MlirAttribute memSpaceAttr = {};
821  if (memorySpace)
822  memSpaceAttr = *memorySpace;
823 
824  MlirType t =
825  mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
826  if (mlirTypeIsNull(t))
827  throw MLIRError("Invalid type", errors.take());
828  return PyUnrankedMemRefType(elementType.getContext(), t);
829  },
830  nb::arg("element_type"), nb::arg("memory_space").none(),
831  nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
832  .def_prop_ro(
833  "memory_space",
834  [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
835  MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
836  if (mlirAttributeIsNull(a))
837  return std::nullopt;
838  return a;
839  },
840  "Returns the memory space of the given Unranked MemRef type.");
841  }
842 };
843 
844 /// Tuple Type subclass - TupleType.
/// Binds `TupleType.get_tuple(elements, context=None)`, `get_type(pos)` and a
/// read-only `num_types` property.
845 class PyTupleType : public PyConcreteType<PyTupleType> {
846 public:
847  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
848  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
850  static constexpr const char *pyClassName = "TupleType";
852 
853  static void bindDerived(ClassTy &c) {
854  c.def_static(
855  "get_tuple",
856  [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
857  MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
858  elements.data());
859  return PyTupleType(context->getRef(), t);
860  },
861  nb::arg("elements"), nb::arg("context").none() = nb::none(),
862  "Create a tuple type");
// NOTE(review): `pos` is forwarded to mlirTupleTypeGetType without a bounds
// check against num_types; confirm how the C API handles out-of-range input.
863  c.def(
864  "get_type",
865  [](PyTupleType &self, intptr_t pos) {
866  return mlirTupleTypeGetType(self, pos);
867  },
868  nb::arg("pos"), "Returns the pos-th type in the tuple type.");
869  c.def_prop_ro(
870  "num_types",
871  [](PyTupleType &self) -> intptr_t {
872  return mlirTupleTypeGetNumTypes(self);
873  },
874  "Returns the number of types contained in a tuple.");
875  }
876 };
877 
878 /// Function type.
/// Binds `FunctionType.get(inputs, results, context=None)` and read-only
/// `inputs` / `results` properties, each returned as a freshly built Python
/// list of types.
879 class PyFunctionType : public PyConcreteType<PyFunctionType> {
880 public:
881  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
882  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
884  static constexpr const char *pyClassName = "FunctionType";
886 
887  static void bindDerived(ClassTy &c) {
888  c.def_static(
889  "get",
890  [](std::vector<MlirType> inputs, std::vector<MlirType> results,
891  DefaultingPyMlirContext context) {
892  MlirType t =
893  mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
894  results.size(), results.data());
895  return PyFunctionType(context->getRef(), t);
896  },
897  nb::arg("inputs"), nb::arg("results"),
898  nb::arg("context").none() = nb::none(),
899  "Gets a FunctionType from a list of input and result types");
900  c.def_prop_ro(
901  "inputs",
902  [](PyFunctionType &self) {
903  MlirType t = self;
904  nb::list types;
905  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
906  ++i) {
907  types.append(mlirFunctionTypeGetInput(t, i));
908  }
909  return types;
910  },
911  "Returns the list of input types in the FunctionType.");
912  c.def_prop_ro(
913  "results",
914  [](PyFunctionType &self) {
915  nb::list types;
916  for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
917  ++i) {
918  types.append(mlirFunctionTypeGetResult(self, i));
919  }
920  return types;
921  },
922  "Returns the list of result types in the FunctionType.");
923  }
924 };
925 
926 static MlirStringRef toMlirStringRef(const std::string &s) {
927  return mlirStringRefCreate(s.data(), s.size());
928 }
929 
930 /// Opaque Type subclass - OpaqueType.
/// Binds `OpaqueType.get(dialect_namespace, buffer, context=None)` and
/// read-only `dialect_namespace` / `data` properties returned as Python str.
931 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
932 public:
933  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
934  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
936  static constexpr const char *pyClassName = "OpaqueType";
938 
939  static void bindDerived(ClassTy &c) {
// The Python keyword for the type payload is "buffer", although the C++
// parameter is named `typeData`.
940  c.def_static(
941  "get",
942  [](std::string dialectNamespace, std::string typeData,
943  DefaultingPyMlirContext context) {
944  MlirType type = mlirOpaqueTypeGet(context->get(),
945  toMlirStringRef(dialectNamespace),
946  toMlirStringRef(typeData));
947  return PyOpaqueType(context->getRef(), type);
948  },
949  nb::arg("dialect_namespace"), nb::arg("buffer"),
950  nb::arg("context").none() = nb::none(),
951  "Create an unregistered (opaque) dialect type.");
// NOTE(review): the line that declares `stringRef` in this lambda is not
// visible in this listing.
952  c.def_prop_ro(
953  "dialect_namespace",
954  [](PyOpaqueType &self) {
956  return nb::str(stringRef.data, stringRef.length);
957  },
958  "Returns the dialect namespace for the Opaque type as a string.");
959  c.def_prop_ro(
960  "data",
961  [](PyOpaqueType &self) {
962  MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
963  return nb::str(stringRef.data, stringRef.length);
964  },
965  "Returns the data for the Opaque type as a string.");
966  }
967 };
968 
969 } // namespace
970 
/// Registers every builtin type binding defined in this file on module `m`.
/// PyFloatType is bound before the float subclasses that name it as their
/// base; presumably a base must be registered first (NOTE(review): confirm),
/// so keep that ordering when adding types.
/// NOTE(review): the source line between PyComplexType and PyVectorType is not
/// visible in this listing.
971 void mlir::python::populateIRTypes(nb::module_ &m) {
972  PyIntegerType::bind(m);
973  PyFloatType::bind(m);
974  PyIndexType::bind(m);
975  PyFloat4E2M1FNType::bind(m);
976  PyFloat6E2M3FNType::bind(m);
977  PyFloat6E3M2FNType::bind(m);
978  PyFloat8E4M3FNType::bind(m);
979  PyFloat8E5M2Type::bind(m);
980  PyFloat8E4M3Type::bind(m);
981  PyFloat8E4M3FNUZType::bind(m);
982  PyFloat8E4M3B11FNUZType::bind(m);
983  PyFloat8E5M2FNUZType::bind(m);
984  PyFloat8E3M4Type::bind(m);
985  PyFloat8E8M0FNUType::bind(m);
986  PyBF16Type::bind(m);
987  PyF16Type::bind(m);
988  PyTF32Type::bind(m);
989  PyF32Type::bind(m);
990  PyF64Type::bind(m);
991  PyNoneType::bind(m);
992  PyComplexType::bind(m);
994  PyVectorType::bind(m);
995  PyRankedTensorType::bind(m);
996  PyUnrankedTensorType::bind(m);
997  PyMemRefType::bind(m);
998  PyUnrankedMemRefType::bind(m);
999  PyTupleType::bind(m);
1000  PyFunctionType::bind(m);
1001  PyOpaqueType::bind(m);
1002 }
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:210
static MLIRContext * getContext(OpFoldResult val)
Shaped Type Interface - ShapedType.
Definition: IRTypes.h:17
static const IsAFunctionTy isaFunction
Definition: IRTypes.h:19
static void bindDerived(ClassTy &c)
Definition: IRTypes.cpp:516
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:310
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:517
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:291
ReferrentTy * get() const
Definition: NanobindUtils.h:55
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1003
CRTP base classes for Python types that subclass Type and should be castable from it (i....
Definition: IRModule.h:927
static void bind(nanobind::module_ &m)
Definition: IRModule.h:956
Wrapper around the generic MlirType.
Definition: IRModule.h:877
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void)
Returns the typeID of a Float8E8M0FNU type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void)
Returns the typeID of an Float16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type)
Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void)
Returns the typeID of an Float4E2M1FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void)
Returns the typeID of a TF32 type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type)
Checks whether the given type is an f8E8M0FNU type.
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void)
Returns the typeID of a None type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void)
Returns the typeID of a Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type)
Checks whether the given type is an f6E3M2FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx)
Creates an f8E3M4 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type)
Checks whether the given type is an f8E3M4 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void)
Returns the typeID of an UnrankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx)
Creates an f8E8M0FNU type in the given context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void)
Returns the typeID of a MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type)
Checks whether the given type is an f8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx)
Creates an f6E2M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void)
Returns the typeID of an Float64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void)
Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void)
Returns the typeID of a Tuple type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx)
Creates an f8E4M3 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void)
Returns the typeID of an Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type)
Checks whether the given type is an f6E2M3FN type.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides, int64_t *offset)
Returns the strides of the MemRef if the layout map is in strided form.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void)
Returns the typeID of an Float8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type)
Checks whether the given type is a floating-point type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void)
Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void)
Returns the typeID of an RankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim)
Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type)
Checks whether the given type is a TF32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void)
Returns the typeID of a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void)
Returns the typeID of an Integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx)
Creates an f6E3M2FN type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void)
Returns the typeID of an Opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void)
Returns the typeID of an Index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void)
Returns the typeID of an UnrankedMemRef type.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void)
Returns the typeID of an Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void)
Returns the typeID of an Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type)
Checks whether the given vector type is scalable, i.e., has at least one scalable dimension.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type)
Checks whether the given type is an f4E2M1FN type.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void)
Returns the typeID of a Function type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void)
Returns the value indicating a dynamic size in a shaped type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void)
Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx)
Creates a TF32 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx)
Creates an f4E2M1FN type in the given context.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1043
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:1008
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populateIRTypes(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1294
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:426