MLIR 22.0.0git
IRTypes.cpp
Go to the documentation of this file.
1//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
10#include "IRModule.h"
12// clang-format on
13
14#include <optional>
15
16#include "IRModule.h"
17#include "NanobindUtils.h"
19#include "mlir-c/BuiltinTypes.h"
20#include "mlir-c/Support.h"
21
22namespace nb = nanobind;
23using namespace mlir;
24using namespace mlir::python;
25
27using llvm::Twine;
28
29namespace {
30
31/// Checks whether the given type is an integer or float type.
32static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35}
36
/// Python binding for the builtin integer type, exposed as "IntegerType".
/// Provides static factories for signless, signed and unsigned integers of a
/// given bit width, plus read-only `width` and signedness properties.
37class PyIntegerType : public PyConcreteType<PyIntegerType> {
38public:
39 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
42 static constexpr const char *pyClassName = "IntegerType";
44
45 static void bindDerived(ClassTy &c) {
 // Factories: each takes a bit width and an optional context (None by
 // default, resolved through DefaultingPyMlirContext) and wraps the C API
 // result in a PyIntegerType bound to that context.
46 c.def_static(
47 "get_signless",
48 [](unsigned width, DefaultingPyMlirContext context) {
49 MlirType t = mlirIntegerTypeGet(context->get(), width);
50 return PyIntegerType(context->getRef(), t);
51 },
52 nb::arg("width"), nb::arg("context") = nb::none(),
53 "Create a signless integer type");
54 c.def_static(
55 "get_signed",
56 [](unsigned width, DefaultingPyMlirContext context) {
57 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58 return PyIntegerType(context->getRef(), t);
59 },
60 nb::arg("width"), nb::arg("context") = nb::none(),
61 "Create a signed integer type");
62 c.def_static(
63 "get_unsigned",
64 [](unsigned width, DefaultingPyMlirContext context) {
65 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66 return PyIntegerType(context->getRef(), t);
67 },
68 nb::arg("width"), nb::arg("context") = nb::none(),
69 "Create an unsigned integer type");
 // Read-only properties forwarding to the C API width/signedness queries.
70 c.def_prop_ro(
71 "width",
72 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73 "Returns the width of the integer type");
74 c.def_prop_ro(
75 "is_signless",
76 [](PyIntegerType &self) -> bool {
77 return mlirIntegerTypeIsSignless(self);
78 },
79 "Returns whether this is a signless integer");
80 c.def_prop_ro(
81 "is_signed",
82 [](PyIntegerType &self) -> bool {
83 return mlirIntegerTypeIsSigned(self);
84 },
85 "Returns whether this is a signed integer");
86 c.def_prop_ro(
87 "is_unsigned",
88 [](PyIntegerType &self) -> bool {
89 return mlirIntegerTypeIsUnsigned(self);
90 },
91 "Returns whether this is an unsigned integer");
92 }
93};
94
95/// Index Type subclass - IndexType.
/// Exposed to Python as "IndexType"; identity is tested with mlirTypeIsAIndex.
96class PyIndexType : public PyConcreteType<PyIndexType> {
97public:
98 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
101 static constexpr const char *pyClassName = "IndexType";
103
104 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the index type with
 // mlirIndexTypeGet in the resolved context.
105 c.def_static(
106 "get",
107 [](DefaultingPyMlirContext context) {
108 MlirType t = mlirIndexTypeGet(context->get());
109 return PyIndexType(context->getRef(), t);
110 },
111 nb::arg("context") = nb::none(), "Create a index type.");
112 }
113};
114
/// Common Python binding for builtin floating-point types, exposed as
/// "FloatType". Matches any type for which mlirTypeIsAFloat holds; the
/// concrete float bindings use it as their base
/// (PyConcreteType<..., PyFloatType>).
115class PyFloatType : public PyConcreteType<PyFloatType> {
116public:
117 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118 static constexpr const char *pyClassName = "FloatType";
120
121 static void bindDerived(ClassTy &c) {
 // `width`: bit width of the float type, via mlirFloatTypeGetWidth.
122 c.def_prop_ro(
123 "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124 "Returns the width of the floating-point type");
125 }
126};
127
128/// Floating Point Type subclass - Float4E2M1FNType.
/// Exposed to Python as "Float4E2M1FNType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat4E2M1FN.
129class PyFloat4E2M1FNType
130 : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131public:
132 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
135 static constexpr const char *pyClassName = "Float4E2M1FNType";
137
138 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat4E2M1FNTypeGet in the resolved context.
139 c.def_static(
140 "get",
141 [](DefaultingPyMlirContext context) {
142 MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143 return PyFloat4E2M1FNType(context->getRef(), t);
144 },
145 nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
146 }
147};
148
149/// Floating Point Type subclass - Float6E2M3FNType.
/// Exposed to Python as "Float6E2M3FNType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat6E2M3FN.
150class PyFloat6E2M3FNType
151 : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152public:
153 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
156 static constexpr const char *pyClassName = "Float6E2M3FNType";
158
159 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat6E2M3FNTypeGet in the resolved context.
160 c.def_static(
161 "get",
162 [](DefaultingPyMlirContext context) {
163 MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164 return PyFloat6E2M3FNType(context->getRef(), t);
165 },
166 nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
167 }
168};
169
170/// Floating Point Type subclass - Float6E3M2FNType.
/// Exposed to Python as "Float6E3M2FNType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat6E3M2FN.
171class PyFloat6E3M2FNType
172 : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173public:
174 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
177 static constexpr const char *pyClassName = "Float6E3M2FNType";
179
180 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat6E3M2FNTypeGet in the resolved context.
181 c.def_static(
182 "get",
183 [](DefaultingPyMlirContext context) {
184 MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185 return PyFloat6E3M2FNType(context->getRef(), t);
186 },
187 nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
188 }
189};
190
191/// Floating Point Type subclass - Float8E4M3FNType.
/// Exposed to Python as "Float8E4M3FNType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E4M3FN.
192class PyFloat8E4M3FNType
193 : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194public:
195 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
198 static constexpr const char *pyClassName = "Float8E4M3FNType";
200
201 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E4M3FNTypeGet in the resolved context.
202 c.def_static(
203 "get",
204 [](DefaultingPyMlirContext context) {
205 MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206 return PyFloat8E4M3FNType(context->getRef(), t);
207 },
208 nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
209 }
210};
211
212/// Floating Point Type subclass - Float8E5M2Type.
/// Exposed to Python as "Float8E5M2Type", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E5M2.
213class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214public:
215 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
218 static constexpr const char *pyClassName = "Float8E5M2Type";
220
221 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E5M2TypeGet in the resolved context.
222 c.def_static(
223 "get",
224 [](DefaultingPyMlirContext context) {
225 MlirType t = mlirFloat8E5M2TypeGet(context->get());
226 return PyFloat8E5M2Type(context->getRef(), t);
227 },
228 nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
229 }
230};
231
232/// Floating Point Type subclass - Float8E4M3Type.
/// Exposed to Python as "Float8E4M3Type", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E4M3.
233class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234public:
235 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
238 static constexpr const char *pyClassName = "Float8E4M3Type";
240
241 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E4M3TypeGet in the resolved context.
242 c.def_static(
243 "get",
244 [](DefaultingPyMlirContext context) {
245 MlirType t = mlirFloat8E4M3TypeGet(context->get());
246 return PyFloat8E4M3Type(context->getRef(), t);
247 },
248 nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
249 }
250};
251
252/// Floating Point Type subclass - Float8E4M3FNUZ.
/// Exposed to Python as "Float8E4M3FNUZType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E4M3FNUZ.
253class PyFloat8E4M3FNUZType
254 : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255public:
256 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
259 static constexpr const char *pyClassName = "Float8E4M3FNUZType";
261
262 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E4M3FNUZTypeGet in the resolved context.
263 c.def_static(
264 "get",
265 [](DefaultingPyMlirContext context) {
266 MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267 return PyFloat8E4M3FNUZType(context->getRef(), t);
268 },
269 nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
270 }
271};
272
273/// Floating Point Type subclass - Float8E4M3B11FNUZ.
/// Exposed to Python as "Float8E4M3B11FNUZType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E4M3B11FNUZ.
274class PyFloat8E4M3B11FNUZType
275 : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
276public:
277 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
278 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
280 static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
282
283 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E4M3B11FNUZTypeGet in the resolved context.
284 c.def_static(
285 "get",
286 [](DefaultingPyMlirContext context) {
287 MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
288 return PyFloat8E4M3B11FNUZType(context->getRef(), t);
289 },
290 nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
291 }
292};
293
294/// Floating Point Type subclass - Float8E5M2FNUZ.
/// Exposed to Python as "Float8E5M2FNUZType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E5M2FNUZ.
295class PyFloat8E5M2FNUZType
296 : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
297public:
298 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
299 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
301 static constexpr const char *pyClassName = "Float8E5M2FNUZType";
303
304 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E5M2FNUZTypeGet in the resolved context.
305 c.def_static(
306 "get",
307 [](DefaultingPyMlirContext context) {
308 MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
309 return PyFloat8E5M2FNUZType(context->getRef(), t);
310 },
311 nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
312 }
313};
314
315/// Floating Point Type subclass - Float8E3M4Type.
/// Exposed to Python as "Float8E3M4Type", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E3M4.
316class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
317public:
318 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
319 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
321 static constexpr const char *pyClassName = "Float8E3M4Type";
323
324 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E3M4TypeGet in the resolved context.
325 c.def_static(
326 "get",
327 [](DefaultingPyMlirContext context) {
328 MlirType t = mlirFloat8E3M4TypeGet(context->get());
329 return PyFloat8E3M4Type(context->getRef(), t);
330 },
331 nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
332 }
333};
334
335/// Floating Point Type subclass - Float8E8M0FNUType.
/// Exposed to Python as "Float8E8M0FNUType", a subclass of the FloatType
/// binding. Identity is tested with mlirTypeIsAFloat8E8M0FNU.
336class PyFloat8E8M0FNUType
337 : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
338public:
339 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
340 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
342 static constexpr const char *pyClassName = "Float8E8M0FNUType";
344
345 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirFloat8E8M0FNUTypeGet in the resolved context.
346 c.def_static(
347 "get",
348 [](DefaultingPyMlirContext context) {
349 MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
350 return PyFloat8E8M0FNUType(context->getRef(), t);
351 },
352 nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
353 }
354};
355
356/// Floating Point Type subclass - BF16Type.
/// Exposed to Python as "BF16Type", a subclass of the FloatType binding.
/// Identity is tested with mlirTypeIsABF16.
357class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
358public:
359 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
360 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
362 static constexpr const char *pyClassName = "BF16Type";
364
365 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirBF16TypeGet in the resolved context.
366 c.def_static(
367 "get",
368 [](DefaultingPyMlirContext context) {
369 MlirType t = mlirBF16TypeGet(context->get());
370 return PyBF16Type(context->getRef(), t);
371 },
372 nb::arg("context") = nb::none(), "Create a bf16 type.");
373 }
374};
375
376/// Floating Point Type subclass - F16Type.
/// Exposed to Python as "F16Type", a subclass of the FloatType binding.
/// Identity is tested with mlirTypeIsAF16.
377class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
378public:
379 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
380 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
382 static constexpr const char *pyClassName = "F16Type";
384
385 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirF16TypeGet in the resolved context.
386 c.def_static(
387 "get",
388 [](DefaultingPyMlirContext context) {
389 MlirType t = mlirF16TypeGet(context->get());
390 return PyF16Type(context->getRef(), t);
391 },
392 nb::arg("context") = nb::none(), "Create a f16 type.");
393 }
394};
395
396/// Floating Point Type subclass - TF32Type.
/// Subclass of the FloatType binding; identity is tested with mlirTypeIsATF32.
/// Note: the Python-visible class name is "FloatTF32Type", not "TF32Type".
397class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
398public:
399 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
400 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
402 static constexpr const char *pyClassName = "FloatTF32Type";
404
405 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirTF32TypeGet in the resolved context.
406 c.def_static(
407 "get",
408 [](DefaultingPyMlirContext context) {
409 MlirType t = mlirTF32TypeGet(context->get());
410 return PyTF32Type(context->getRef(), t);
411 },
412 nb::arg("context") = nb::none(), "Create a tf32 type.");
413 }
414};
415
416/// Floating Point Type subclass - F32Type.
/// Exposed to Python as "F32Type", a subclass of the FloatType binding.
/// Identity is tested with mlirTypeIsAF32.
417class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
418public:
419 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
420 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
422 static constexpr const char *pyClassName = "F32Type";
424
425 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirF32TypeGet in the resolved context.
426 c.def_static(
427 "get",
428 [](DefaultingPyMlirContext context) {
429 MlirType t = mlirF32TypeGet(context->get());
430 return PyF32Type(context->getRef(), t);
431 },
432 nb::arg("context") = nb::none(), "Create a f32 type.");
433 }
434};
435
436/// Floating Point Type subclass - F64Type.
/// Exposed to Python as "F64Type", a subclass of the FloatType binding.
/// Identity is tested with mlirTypeIsAF64.
437class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
438public:
439 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
440 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
442 static constexpr const char *pyClassName = "F64Type";
444
445 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirF64TypeGet in the resolved context.
446 c.def_static(
447 "get",
448 [](DefaultingPyMlirContext context) {
449 MlirType t = mlirF64TypeGet(context->get());
450 return PyF64Type(context->getRef(), t);
451 },
452 nb::arg("context") = nb::none(), "Create a f64 type.");
453 }
454};
455
456/// None Type subclass - NoneType.
/// Python binding for the builtin `none` type, exposed as "NoneType".
/// Identity is tested with mlirTypeIsANone.
457class PyNoneType : public PyConcreteType<PyNoneType> {
458public:
459 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
460 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
462 static constexpr const char *pyClassName = "NoneType";
464
465 static void bindDerived(ClassTy &c) {
 // Static factory: `get(context=None)` builds the type with
 // mlirNoneTypeGet in the resolved context.
466 c.def_static(
467 "get",
468 [](DefaultingPyMlirContext context) {
469 MlirType t = mlirNoneTypeGet(context->get());
470 return PyNoneType(context->getRef(), t);
471 },
472 nb::arg("context") = nb::none(), "Create a none type.");
473 }
474};
475
476/// Complex Type subclass - ComplexType.
/// Python binding exposed as "ComplexType". `get(element_type)` builds a
/// complex type over a scalar element type, and the `element_type` property
/// returns that element, down-cast to its concrete Python type binding.
477class PyComplexType : public PyConcreteType<PyComplexType> {
478public:
479 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
480 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
482 static constexpr const char *pyClassName = "ComplexType";
484
485 static void bindDerived(ClassTy &c) {
 // `get` takes no context argument: the result lives in the element type's
 // context. The element type is validated with mlirTypeIsAIntegerOrFloat;
 // rejected types raise a Python ValueError that includes their repr.
486 c.def_static(
487 "get",
488 [](PyType &elementType) {
489 // The element must be a floating point or integer scalar type.
490 if (mlirTypeIsAIntegerOrFloat(elementType)) {
491 MlirType t = mlirComplexTypeGet(elementType);
492 return PyComplexType(elementType.getContext(), t);
493 }
494 throw nb::value_error(
495 (Twine("invalid '") +
496 nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
497 "' and expected floating point or integer type.")
498 .str()
499 .c_str());
500 },
501 "Create a complex type");
502 c.def_prop_ro(
503 "element_type",
504 [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
505 return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
506 .maybeDownCast();
507 },
508 "Returns element type.");
509 }
510};
511
512} // namespace
513
514// Shaped Type Interface - ShapedType
// Python bindings shared by all shaped types (vectors, tensors, memrefs).
// Accessors that only make sense for ranked types call requireHasRank()
// first, so on an unranked type they raise ValueError instead of querying
// the C API.
516 c.def_prop_ro(
517 "element_type",
518 [](PyShapedType &self) -> nb::typed<nb::object, PyType> {
520 .maybeDownCast();
521 },
522 "Returns the element type of the shaped type.");
523 c.def_prop_ro(
524 "has_rank",
525 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
526 "Returns whether the given shaped type is ranked.");
527 c.def_prop_ro(
528 "rank",
529 [](PyShapedType &self) {
530 self.requireHasRank();
531 return mlirShapedTypeGetRank(self);
532 },
533 "Returns the rank of the given ranked shaped type.");
534 c.def_prop_ro(
535 "has_static_shape",
536 [](PyShapedType &self) -> bool {
537 return mlirShapedTypeHasStaticShape(self);
538 },
539 "Returns whether the given shaped type has a static shape.");
 // Per-dimension queries; `dim` is forwarded to the C API without a bounds
 // check here.
540 c.def(
541 "is_dynamic_dim",
542 [](PyShapedType &self, intptr_t dim) -> bool {
543 self.requireHasRank();
544 return mlirShapedTypeIsDynamicDim(self, dim);
545 },
546 nb::arg("dim"),
547 "Returns whether the dim-th dimension of the given shaped type is "
548 "dynamic.");
549 c.def(
550 "is_static_dim",
551 [](PyShapedType &self, intptr_t dim) -> bool {
552 self.requireHasRank();
553 return mlirShapedTypeIsStaticDim(self, dim);
554 },
555 nb::arg("dim"),
556 "Returns whether the dim-th dimension of the given shaped type is "
557 "static.");
558 c.def(
559 "get_dim_size",
560 [](PyShapedType &self, intptr_t dim) {
561 self.requireHasRank();
562 return mlirShapedTypeGetDimSize(self, dim);
563 },
564 nb::arg("dim"),
565 "Returns the dim-th dimension of the given ranked shaped type.");
 // Static predicates on raw size values; no type instance is involved.
566 c.def_static(
567 "is_dynamic_size",
568 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
569 nb::arg("dim_size"),
570 "Returns whether the given dimension size indicates a dynamic "
571 "dimension.");
572 c.def_static(
573 "is_static_size",
574 [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
575 nb::arg("dim_size"),
576 "Returns whether the given dimension size indicates a static "
577 "dimension.");
 // Stride/offset sentinel checks. NOTE(review): the keyword argument is
 // named "dim_size" although the value is a stride or offset.
578 c.def(
579 "is_dynamic_stride_or_offset",
580 [](PyShapedType &self, int64_t val) -> bool {
581 self.requireHasRank();
583 },
584 nb::arg("dim_size"),
585 "Returns whether the given value is used as a placeholder for dynamic "
586 "strides and offsets in shaped types.");
587 c.def(
588 "is_static_stride_or_offset",
589 [](PyShapedType &self, int64_t val) -> bool {
590 self.requireHasRank();
592 },
593 nb::arg("dim_size"),
594 "Returns whether the given shaped type stride or offset value is "
595 "statically-sized.");
 // Collects every dimension size, in order, into a list; dynamic dimensions
 // appear as whatever sentinel mlirShapedTypeGetDimSize reports.
596 c.def_prop_ro(
597 "shape",
598 [](PyShapedType &self) {
599 self.requireHasRank();
600
601 std::vector<int64_t> shape;
602 int64_t rank = mlirShapedTypeGetRank(self);
603 shape.reserve(rank);
604 for (int64_t i = 0; i < rank; ++i)
605 shape.push_back(mlirShapedTypeGetDimSize(self, i));
606 return shape;
607 },
608 "Returns the shape of the ranked shaped type as a list of integers.");
609 c.def_static(
610 "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
611 "Returns the value used to indicate dynamic dimensions in shaped "
612 "types.");
613 c.def_static(
614 "get_dynamic_stride_or_offset",
616 "Returns the value used to indicate dynamic strides or offsets in "
617 "shaped types.");
618}
619
620void mlir::PyShapedType::requireHasRank() {
621 if (!mlirShapedTypeHasRank(*this)) {
622 throw nb::value_error(
623 "calling this method requires that the type has a rank.");
624 }
625}
626
629
630namespace {
631
632/// Vector Type subclass - VectorType.
/// Python binding exposed as "VectorType" (a ShapedType subclass).
/// `get` builds the type through the checked C API, using a location for
/// diagnostics. `get_unchecked` builds it through the unchecked API in a
/// context. Both accept at most one of `scalable` (one bool per dimension)
/// or `scalable_dims` (indices of scalable dimensions).
633class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
634public:
635 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
636 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
638 static constexpr const char *pyClassName = "VectorType";
640
641 static void bindDerived(ClassTy &c) {
 // `scalable`, `scalable_dims` and `loc`/`context` are keyword-only
 // (nb::kw_only()).
642 c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
643 nb::arg("element_type"), nb::kw_only(),
644 nb::arg("scalable") = nb::none(),
645 nb::arg("scalable_dims") = nb::none(),
646 nb::arg("loc") = nb::none(), "Create a vector type")
647 .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
648 nb::arg("element_type"), nb::kw_only(),
649 nb::arg("scalable") = nb::none(),
650 nb::arg("scalable_dims") = nb::none(),
651 nb::arg("context") = nb::none(), "Create a vector type")
652 .def_prop_ro(
653 "scalable",
654 [](MlirType self) { return mlirVectorTypeIsScalable(self); })
 // Per-dimension scalability flags, one entry per dimension up to the rank.
655 .def_prop_ro("scalable_dims", [](MlirType self) {
656 std::vector<bool> scalableDims;
657 size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
658 scalableDims.reserve(rank);
659 for (size_t i = 0; i < rank; ++i)
660 scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
661 return scalableDims;
662 });
663 }
664
665private:
 /// Implementation of Python `get`: builds the vector type with the *Checked
 /// C API so verification failures surface as MLIRError, carrying the
 /// diagnostics collected by ErrorCapture. Raises ValueError on conflicting
 /// or malformed scalability arguments.
666 static PyVectorType
667 getChecked(std::vector<int64_t> shape, PyType &elementType,
668 std::optional<nb::list> scalable,
669 std::optional<std::vector<int64_t>> scalableDims,
670 DefaultingPyLocation loc) {
671 if (scalable && scalableDims) {
672 throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
673 "are mutually exclusive.");
674 }
675
676 PyMlirContext::ErrorCapture errors(loc->getContext());
677 MlirType type;
678 if (scalable) {
679 if (scalable->size() != shape.size())
680 throw nb::value_error("Expected len(scalable) == len(shape).");
681
 // SmallVector<bool> (unlike std::vector<bool>) stores real bools
 // contiguously, so .data() can be handed to the C API.
682 SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
683 *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
684 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
685 scalableDimFlags.data(),
686 elementType);
687 } else if (scalableDims) {
 // Turn the index list into one flag per dimension. A negative index
 // also wraps to a huge size_t, so the first comparison already rejects
 // it; `dim < 0` makes that explicit.
688 SmallVector<bool> scalableDimFlags(shape.size(), false);
689 for (int64_t dim : *scalableDims) {
690 if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
691 throw nb::value_error("Scalable dimension index out of bounds.");
692 scalableDimFlags[dim] = true;
693 }
694 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
695 scalableDimFlags.data(),
696 elementType);
697 } else {
698 type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
699 elementType);
700 }
701 if (mlirTypeIsNull(type))
702 throw MLIRError("Invalid type", errors.take());
 // The result is bound to the element type's context.
703 return PyVectorType(elementType.getContext(), type);
704 }
705
 /// Implementation of Python `get_unchecked`: same argument handling as
 /// getChecked, but calls the unchecked C API in the given context. A null
 /// result is still reported as MLIRError.
706 static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
707 std::optional<nb::list> scalable,
708 std::optional<std::vector<int64_t>> scalableDims,
709 DefaultingPyMlirContext context) {
710 if (scalable && scalableDims) {
711 throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
712 "are mutually exclusive.");
713 }
714
715 PyMlirContext::ErrorCapture errors(context->getRef());
716 MlirType type;
717 if (scalable) {
718 if (scalable->size() != shape.size())
719 throw nb::value_error("Expected len(scalable) == len(shape).");
720
721 SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
722 *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
723 type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
724 scalableDimFlags.data(), elementType);
725 } else if (scalableDims) {
726 SmallVector<bool> scalableDimFlags(shape.size(), false);
727 for (int64_t dim : *scalableDims) {
728 if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
729 throw nb::value_error("Scalable dimension index out of bounds.");
730 scalableDimFlags[dim] = true;
731 }
732 type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
733 scalableDimFlags.data(), elementType);
734 } else {
735 type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
736 }
737 if (mlirTypeIsNull(type))
738 throw MLIRError("Invalid type", errors.take());
739 return PyVectorType(elementType.getContext(), type);
740 }
741};
742
743/// Ranked Tensor Type subclass - RankedTensorType.
/// Python binding exposed as "RankedTensorType" (a ShapedType subclass).
/// Offers a checked `get` (location-based) and `get_unchecked`
/// (context-based) factory, each with an optional encoding attribute, plus an
/// `encoding` property.
744class PyRankedTensorType
745 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
746public:
747 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
748 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
750 static constexpr const char *pyClassName = "RankedTensorType";
752
753 static void bindDerived(ClassTy &c) {
 // Checked factory: diagnostics are collected by ErrorCapture and raised as
 // MLIRError when construction yields a null type. A missing encoding is
 // passed to the C API as a null attribute.
754 c.def_static(
755 "get",
756 [](std::vector<int64_t> shape, PyType &elementType,
757 std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
758 PyMlirContext::ErrorCapture errors(loc->getContext());
760 loc, shape.size(), shape.data(), elementType,
761 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
762 if (mlirTypeIsNull(t))
763 throw MLIRError("Invalid type", errors.take());
764 return PyRankedTensorType(elementType.getContext(), t);
765 },
766 nb::arg("shape"), nb::arg("element_type"),
767 nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
768 "Create a ranked tensor type");
 // Unchecked factory: same arguments, but a context instead of a location
 // and the non-verifying mlirRankedTensorTypeGet.
769 c.def_static(
770 "get_unchecked",
771 [](std::vector<int64_t> shape, PyType &elementType,
772 std::optional<PyAttribute> &encodingAttr,
773 DefaultingPyMlirContext context) {
774 PyMlirContext::ErrorCapture errors(context->getRef());
775 MlirType t = mlirRankedTensorTypeGet(
776 shape.size(), shape.data(), elementType,
777 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
778 if (mlirTypeIsNull(t))
779 throw MLIRError("Invalid type", errors.take());
780 return PyRankedTensorType(elementType.getContext(), t);
781 },
782 nb::arg("shape"), nb::arg("element_type"),
783 nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
784 "Create a ranked tensor type");
 // `encoding`: None when the type has no encoding, otherwise the attribute
 // down-cast to its concrete binding.
785 c.def_prop_ro(
786 "encoding",
787 [](PyRankedTensorType &self)
788 -> std::optional<nb::typed<nb::object, PyAttribute>> {
789 MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
790 if (mlirAttributeIsNull(encoding))
791 return std::nullopt;
792 return PyAttribute(self.getContext(), encoding).maybeDownCast();
793 });
794 }
795};
796
797/// Unranked Tensor Type subclass - UnrankedTensorType.
/// Python binding exposed as "UnrankedTensorType" (a ShapedType subclass),
/// with checked (`get`, location-based) and unchecked (`get_unchecked`,
/// context-based) factories taking only an element type.
798class PyUnrankedTensorType
799 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
800public:
801 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
802 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
804 static constexpr const char *pyClassName = "UnrankedTensorType";
806
807 static void bindDerived(ClassTy &c) {
 // Both factories raise MLIRError, with the captured diagnostics, when the
 // C API returns a null type.
808 c.def_static(
809 "get",
810 [](PyType &elementType, DefaultingPyLocation loc) {
811 PyMlirContext::ErrorCapture errors(loc->getContext());
812 MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
813 if (mlirTypeIsNull(t))
814 throw MLIRError("Invalid type", errors.take());
815 return PyUnrankedTensorType(elementType.getContext(), t);
816 },
817 nb::arg("element_type"), nb::arg("loc") = nb::none(),
818 "Create a unranked tensor type");
819 c.def_static(
820 "get_unchecked",
821 [](PyType &elementType, DefaultingPyMlirContext context) {
822 PyMlirContext::ErrorCapture errors(context->getRef());
823 MlirType t = mlirUnrankedTensorTypeGet(elementType);
824 if (mlirTypeIsNull(t))
825 throw MLIRError("Invalid type", errors.take());
826 return PyUnrankedTensorType(elementType.getContext(), t);
827 },
828 nb::arg("element_type"), nb::arg("context") = nb::none(),
829 "Create a unranked tensor type");
830 }
831};
832
833/// Ranked MemRef Type subclass - MemRefType.
/// Python binding exposed as "MemRefType" (a ShapedType subclass).
/// Provides checked/unchecked factories with optional layout and memory-space
/// attributes, plus accessors for the layout, strides/offset, the layout as
/// an affine map, and the memory space.
834class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
835public:
836 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
837 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
839 static constexpr const char *pyClassName = "MemRefType";
841
842 static void bindDerived(ClassTy &c) {
 // Checked factory. Absent (None) layout / memory_space arguments arrive as
 // null pointers and are passed on as null attributes.
843 c.def_static(
844 "get",
845 [](std::vector<int64_t> shape, PyType &elementType,
846 PyAttribute *layout, PyAttribute *memorySpace,
847 DefaultingPyLocation loc) {
848 PyMlirContext::ErrorCapture errors(loc->getContext());
849 MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
850 MlirAttribute memSpaceAttr =
851 memorySpace ? *memorySpace : mlirAttributeGetNull();
852 MlirType t =
853 mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
854 shape.data(), layoutAttr, memSpaceAttr);
855 if (mlirTypeIsNull(t))
856 throw MLIRError("Invalid type", errors.take());
857 return PyMemRefType(elementType.getContext(), t);
858 },
859 nb::arg("shape"), nb::arg("element_type"),
860 nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
861 nb::arg("loc") = nb::none(), "Create a memref type")
 // Unchecked factory: same arguments with a context instead of a location.
862 .def_static(
863 "get_unchecked",
864 [](std::vector<int64_t> shape, PyType &elementType,
865 PyAttribute *layout, PyAttribute *memorySpace,
866 DefaultingPyMlirContext context) {
867 PyMlirContext::ErrorCapture errors(context->getRef());
868 MlirAttribute layoutAttr =
869 layout ? *layout : mlirAttributeGetNull();
870 MlirAttribute memSpaceAttr =
871 memorySpace ? *memorySpace : mlirAttributeGetNull();
872 MlirType t =
873 mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
874 layoutAttr, memSpaceAttr);
875 if (mlirTypeIsNull(t))
876 throw MLIRError("Invalid type", errors.take());
877 return PyMemRefType(elementType.getContext(), t);
878 },
879 nb::arg("shape"), nb::arg("element_type"),
880 nb::arg("layout") = nb::none(),
881 nb::arg("memory_space") = nb::none(),
882 nb::arg("context") = nb::none(), "Create a memref type")
883 .def_prop_ro(
884 "layout",
885 [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
886 return PyAttribute(self.getContext(),
888 .maybeDownCast();
889 },
890 "The layout of the MemRef type.")
 // Returns (strides, offset). The strides buffer is sized to the rank and
 // `offset` is an out-parameter filled by the extraction call; failure is
 // reported as std::runtime_error.
891 .def(
892 "get_strides_and_offset",
893 [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
894 std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
895 int64_t offset;
897 self, strides.data(), &offset)))
898 throw std::runtime_error(
899 "Failed to extract strides and offset from memref.");
900 return {strides, offset};
901 },
902 "The strides and offset of the MemRef type.")
903 .def_prop_ro(
904 "affine_map",
905 [](PyMemRefType &self) -> PyAffineMap {
906 MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
907 return PyAffineMap(self.getContext(), map);
908 },
909 "The layout of the MemRef type as an affine map.")
 // `memory_space`: None when unset, otherwise the down-cast attribute.
910 .def_prop_ro(
911 "memory_space",
912 [](PyMemRefType &self)
913 -> std::optional<nb::typed<nb::object, PyAttribute>> {
914 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
915 if (mlirAttributeIsNull(a))
916 return std::nullopt;
917 return PyAttribute(self.getContext(), a).maybeDownCast();
918 },
919 "Returns the memory space of the given MemRef type.");
920 }
921};
922
923/// Unranked MemRef Type subclass - UnrankedMemRefType.
/// Python binding exposed as "UnrankedMemRefType" (a ShapedType subclass),
/// with checked/unchecked factories and a `memory_space` property.
924class PyUnrankedMemRefType
925 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
926public:
927 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
928 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
930 static constexpr const char *pyClassName = "UnrankedMemRefType";
932
933 static void bindDerived(ClassTy &c) {
 // NOTE(review): `memory_space` is declared with .none() but no default
 // value, so callers must pass it explicitly (None is accepted). The ranked
 // MemRefType factories default it to None; confirm this asymmetry is
 // intended.
934 c.def_static(
935 "get",
936 [](PyType &elementType, PyAttribute *memorySpace,
937 DefaultingPyLocation loc) {
938 PyMlirContext::ErrorCapture errors(loc->getContext());
 // Value-initialized when no memory space is given; presumably the
 // same as a null attribute (mlirAttributeGetNull) — confirm.
939 MlirAttribute memSpaceAttr = {};
940 if (memorySpace)
941 memSpaceAttr = *memorySpace;
942
943 MlirType t =
944 mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
945 if (mlirTypeIsNull(t))
946 throw MLIRError("Invalid type", errors.take());
947 return PyUnrankedMemRefType(elementType.getContext(), t);
948 },
949 nb::arg("element_type"), nb::arg("memory_space").none(),
950 nb::arg("loc") = nb::none(), "Create a unranked memref type")
951 .def_static(
952 "get_unchecked",
953 [](PyType &elementType, PyAttribute *memorySpace,
954 DefaultingPyMlirContext context) {
955 PyMlirContext::ErrorCapture errors(context->getRef());
956 MlirAttribute memSpaceAttr = {};
957 if (memorySpace)
958 memSpaceAttr = *memorySpace;
959
960 MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
961 if (mlirTypeIsNull(t))
962 throw MLIRError("Invalid type", errors.take());
963 return PyUnrankedMemRefType(elementType.getContext(), t);
964 },
965 nb::arg("element_type"), nb::arg("memory_space").none(),
966 nb::arg("context") = nb::none(), "Create a unranked memref type")
 // `memory_space`: None when unset, otherwise the down-cast attribute.
967 .def_prop_ro(
968 "memory_space",
969 [](PyUnrankedMemRefType &self)
970 -> std::optional<nb::typed<nb::object, PyAttribute>> {
971 MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
972 if (mlirAttributeIsNull(a))
973 return std::nullopt;
974 return PyAttribute(self.getContext(), a).maybeDownCast();
975 },
976 "Returns the memory space of the given Unranked MemRef type.");
977 }
978};
979
980/// Tuple Type subclass - TupleType.
981class PyTupleType : public PyConcreteType<PyTupleType> {
982public:
983 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
984 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
986 static constexpr const char *pyClassName = "TupleType";
988
989 static void bindDerived(ClassTy &c) {
990 c.def_static(
991 "get_tuple",
992 [](const std::vector<PyType> &elements,
993 DefaultingPyMlirContext context) {
994 std::vector<MlirType> mlirElements;
995 mlirElements.reserve(elements.size());
996 for (const auto &element : elements)
997 mlirElements.push_back(element.get());
998 MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
999 mlirElements.data());
1000 return PyTupleType(context->getRef(), t);
1001 },
1002 nb::arg("elements"), nb::arg("context") = nb::none(),
1003 "Create a tuple type");
1004 c.def_static(
1005 "get_tuple",
1006 [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
1007 MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
1008 elements.data());
1009 return PyTupleType(context->getRef(), t);
1010 },
1011 nb::arg("elements"), nb::arg("context") = nb::none(),
1012 // clang-format off
1013 nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
1014 // clang-format on
1015 "Create a tuple type");
1016 c.def(
1017 "get_type",
1018 [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
1019 return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
1020 .maybeDownCast();
1021 },
1022 nb::arg("pos"), "Returns the pos-th type in the tuple type.");
1023 c.def_prop_ro(
1024 "num_types",
1025 [](PyTupleType &self) -> intptr_t {
1026 return mlirTupleTypeGetNumTypes(self);
1027 },
1028 "Returns the number of types contained in a tuple.");
1029 }
1030};
1031
1032/// Function type.
1033class PyFunctionType : public PyConcreteType<PyFunctionType> {
1034public:
1035 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
1036 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1038 static constexpr const char *pyClassName = "FunctionType";
1040
1041 static void bindDerived(ClassTy &c) {
1042 c.def_static(
1043 "get",
1044 [](std::vector<PyType> inputs, std::vector<PyType> results,
1045 DefaultingPyMlirContext context) {
1046 std::vector<MlirType> mlirInputs;
1047 mlirInputs.reserve(inputs.size());
1048 for (const auto &input : inputs)
1049 mlirInputs.push_back(input.get());
1050 std::vector<MlirType> mlirResults;
1051 mlirResults.reserve(results.size());
1052 for (const auto &result : results)
1053 mlirResults.push_back(result.get());
1054
1055 MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
1056 mlirInputs.data(), results.size(),
1057 mlirResults.data());
1058 return PyFunctionType(context->getRef(), t);
1059 },
1060 nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
1061 "Gets a FunctionType from a list of input and result types");
1062 c.def_static(
1063 "get",
1064 [](std::vector<MlirType> inputs, std::vector<MlirType> results,
1065 DefaultingPyMlirContext context) {
1066 MlirType t =
1067 mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
1068 results.size(), results.data());
1069 return PyFunctionType(context->getRef(), t);
1070 },
1071 nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
1072 // clang-format off
1073 nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
1074 // clang-format on
1075 "Gets a FunctionType from a list of input and result types");
1076 c.def_prop_ro(
1077 "inputs",
1078 [](PyFunctionType &self) {
1079 MlirType t = self;
1080 nb::list types;
1081 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
1082 ++i) {
1083 types.append(mlirFunctionTypeGetInput(t, i));
1084 }
1085 return types;
1086 },
1087 "Returns the list of input types in the FunctionType.");
1088 c.def_prop_ro(
1089 "results",
1090 [](PyFunctionType &self) {
1091 nb::list types;
1092 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
1093 ++i) {
1094 types.append(mlirFunctionTypeGetResult(self, i));
1095 }
1096 return types;
1097 },
1098 "Returns the list of result types in the FunctionType.");
1099 }
1100};
1101
1102static MlirStringRef toMlirStringRef(const std::string &s) {
1103 return mlirStringRefCreate(s.data(), s.size());
1104}
1105
1106/// Opaque Type subclass - OpaqueType.
1107class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
1108public:
1109 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
1110 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1112 static constexpr const char *pyClassName = "OpaqueType";
1114
1115 static void bindDerived(ClassTy &c) {
1116 c.def_static(
1117 "get",
1118 [](const std::string &dialectNamespace, const std::string &typeData,
1119 DefaultingPyMlirContext context) {
1120 MlirType type = mlirOpaqueTypeGet(context->get(),
1121 toMlirStringRef(dialectNamespace),
1122 toMlirStringRef(typeData));
1123 return PyOpaqueType(context->getRef(), type);
1124 },
1125 nb::arg("dialect_namespace"), nb::arg("buffer"),
1126 nb::arg("context") = nb::none(),
1127 "Create an unregistered (opaque) dialect type.");
1128 c.def_prop_ro(
1129 "dialect_namespace",
1130 [](PyOpaqueType &self) {
1132 return nb::str(stringRef.data, stringRef.length);
1133 },
1134 "Returns the dialect namespace for the Opaque type as a string.");
1135 c.def_prop_ro(
1136 "data",
1137 [](PyOpaqueType &self) {
1138 MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
1139 return nb::str(stringRef.data, stringRef.length);
1140 },
1141 "Returns the data for the Opaque type as a string.");
1142 }
1143};
1144
1145} // namespace
1146
1147void mlir::python::populateIRTypes(nb::module_ &m) {
1148 PyIntegerType::bind(m);
1149 PyFloatType::bind(m);
1150 PyIndexType::bind(m);
1151 PyFloat4E2M1FNType::bind(m);
1152 PyFloat6E2M3FNType::bind(m);
1153 PyFloat6E3M2FNType::bind(m);
1154 PyFloat8E4M3FNType::bind(m);
1155 PyFloat8E5M2Type::bind(m);
1156 PyFloat8E4M3Type::bind(m);
1157 PyFloat8E4M3FNUZType::bind(m);
1158 PyFloat8E4M3B11FNUZType::bind(m);
1159 PyFloat8E5M2FNUZType::bind(m);
1160 PyFloat8E3M4Type::bind(m);
1161 PyFloat8E8M0FNUType::bind(m);
1162 PyBF16Type::bind(m);
1163 PyF16Type::bind(m);
1164 PyTF32Type::bind(m);
1165 PyF32Type::bind(m);
1166 PyF64Type::bind(m);
1167 PyNoneType::bind(m);
1168 PyComplexType::bind(m);
1170 PyVectorType::bind(m);
1171 PyRankedTensorType::bind(m);
1172 PyUnrankedTensorType::bind(m);
1173 PyMemRefType::bind(m);
1174 PyUnrankedMemRefType::bind(m);
1175 PyTupleType::bind(m);
1176 PyFunctionType::bind(m);
1177 PyOpaqueType::bind(m);
1178}
static MlirStringRef toMlirStringRef(const std::string &s)
Definition IRCore.cpp:77
Shaped Type Interface - ShapedType.
Definition IRTypes.h:17
static const IsAFunctionTy isaFunction
Definition IRTypes.h:19
static void bindDerived(ClassTy &c)
Definition IRTypes.cpp:515
PyMlirContextRef & getContext()
Accesses the context reference.
Definition IRModule.h:292
ReferrentTy * get() const
nanobind::class_< PyShapedType, PyType > ClassTy
Definition IRModule.h:935
static void bind(nanobind::module_ &m)
Definition IRModule.h:959
PyType(PyMlirContextRef contextRef, MlirType type)
Definition IRModule.h:880
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void)
Returns the typeID of an Float8E8M0FNU type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void)
Returns the typeID of an Float16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type)
Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void)
Returns the typeID of an Float4E2M1FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void)
Returns the typeID of a TF32 type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type)
Checks whether the given type is an f8E8M0FNU type.
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void)
Returns the typeID of an None type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void)
Returns the typeID of an Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type)
Checks whether the given type is an f6E3M2FN type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is static.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx)
Creates an f8E3M4 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type)
Checks whether the given type is an f8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val)
Checks whether the given dimension value of a stride or an offset is statically-sized.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void)
Returns the typeID of an UnrankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx)
Creates an f8E8M0FNU type in the given context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void)
Returns the typeID of an MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type)
Checks whether the given type is an f8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx)
Creates an f6E2M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Creates a scalable vector type with the shape identified by its rank and dimensions.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void)
Returns the typeID of an Float64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void)
Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void)
Returns the typeID of an Tuple type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx)
Creates an f8E4M3 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void)
Returns the typeID of an Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type)
Checks whether the given type is an f6E2M3FN type.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides, int64_t *offset)
Returns the strides of the MemRef if the layout map is in strided form.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void)
Returns the typeID of an Float8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType)
Creates a vector type of the shape identified by its rank and dimensions, with the given element type...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type)
Checks whether the given type is a floating-point type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void)
Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void)
Returns the typeID of an RankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim)
Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type)
Checks whether the given type is a TF32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void)
Returns the typeID of an Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void)
Returns the typeID of an Integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx)
Creates an f6E3M2FN type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void)
Returns the typeID of an Opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void)
Returns the typeID of an Index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void)
Returns the typeID of an UnrankedMemRef type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace)
Creates an Unranked MemRef type with the given element type and in the given memory space.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void)
Returns the typeID of an Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void)
Returns the typeID of an Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size)
Checks whether the given shaped type dimension value is statically-sized.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps,...
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType)
Creates an unranked tensor type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type)
Checks whether the given vector type is scalable, i.e., has at least one scalable dimension.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type)
Checks whether the given type is an f4E2M1FN type.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void)
Returns the typeID of an Function type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void)
Returns the value indicating a dynamic size in a shaped type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void)
Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx)
Creates a TF32 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx)
Creates an f4E2M1FN type in the given context.
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition IR.h:1152
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition Support.h:82
struct MlirStringRef MlirStringRef
Definition Support.h:77
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition Support.h:127
void populateIRTypes(nanobind::module_ &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
const char * data
Pointer to the first symbol.
Definition Support.h:74
size_t length
Length of the fragment.
Definition Support.h:75