MLIR  20.0.0git
IRTypes.cpp
Go to the documentation of this file.
1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "IRModule.h"

#include "PybindUtils.h"

#include "mlir/Bindings/Python/IRTypes.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"

#include <optional>

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;

using llvm::SmallVector;
using llvm::Twine;
27 
28 namespace {
29 
30 /// Checks whether the given type is an integer or float type.
31 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
32  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
33  mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
34 }
35 
36 class PyIntegerType : public PyConcreteType<PyIntegerType> {
37 public:
38  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
39  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
41  static constexpr const char *pyClassName = "IntegerType";
43 
44  static void bindDerived(ClassTy &c) {
45  c.def_static(
46  "get_signless",
47  [](unsigned width, DefaultingPyMlirContext context) {
48  MlirType t = mlirIntegerTypeGet(context->get(), width);
49  return PyIntegerType(context->getRef(), t);
50  },
51  py::arg("width"), py::arg("context") = py::none(),
52  "Create a signless integer type");
53  c.def_static(
54  "get_signed",
55  [](unsigned width, DefaultingPyMlirContext context) {
56  MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
57  return PyIntegerType(context->getRef(), t);
58  },
59  py::arg("width"), py::arg("context") = py::none(),
60  "Create a signed integer type");
61  c.def_static(
62  "get_unsigned",
63  [](unsigned width, DefaultingPyMlirContext context) {
64  MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
65  return PyIntegerType(context->getRef(), t);
66  },
67  py::arg("width"), py::arg("context") = py::none(),
68  "Create an unsigned integer type");
69  c.def_property_readonly(
70  "width",
71  [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
72  "Returns the width of the integer type");
73  c.def_property_readonly(
74  "is_signless",
75  [](PyIntegerType &self) -> bool {
76  return mlirIntegerTypeIsSignless(self);
77  },
78  "Returns whether this is a signless integer");
79  c.def_property_readonly(
80  "is_signed",
81  [](PyIntegerType &self) -> bool {
82  return mlirIntegerTypeIsSigned(self);
83  },
84  "Returns whether this is a signed integer");
85  c.def_property_readonly(
86  "is_unsigned",
87  [](PyIntegerType &self) -> bool {
88  return mlirIntegerTypeIsUnsigned(self);
89  },
90  "Returns whether this is an unsigned integer");
91  }
92 };
93 
94 /// Index Type subclass - IndexType.
95 class PyIndexType : public PyConcreteType<PyIndexType> {
96 public:
97  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
98  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
100  static constexpr const char *pyClassName = "IndexType";
102 
103  static void bindDerived(ClassTy &c) {
104  c.def_static(
105  "get",
106  [](DefaultingPyMlirContext context) {
107  MlirType t = mlirIndexTypeGet(context->get());
108  return PyIndexType(context->getRef(), t);
109  },
110  py::arg("context") = py::none(), "Create a index type.");
111  }
112 };
113 
114 class PyFloatType : public PyConcreteType<PyFloatType> {
115 public:
116  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
117  static constexpr const char *pyClassName = "FloatType";
119 
120  static void bindDerived(ClassTy &c) {
121  c.def_property_readonly(
122  "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
123  "Returns the width of the floating-point type");
124  }
125 };
126 
127 /// Floating Point Type subclass - Float4E2M1FNType.
128 class PyFloat4E2M1FNType
129  : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
130 public:
131  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
132  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
134  static constexpr const char *pyClassName = "Float4E2M1FNType";
136 
137  static void bindDerived(ClassTy &c) {
138  c.def_static(
139  "get",
140  [](DefaultingPyMlirContext context) {
141  MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
142  return PyFloat4E2M1FNType(context->getRef(), t);
143  },
144  py::arg("context") = py::none(), "Create a float4_e2m1fn type.");
145  }
146 };
147 
148 /// Floating Point Type subclass - Float6E2M3FNType.
149 class PyFloat6E2M3FNType
150  : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
151 public:
152  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
153  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
155  static constexpr const char *pyClassName = "Float6E2M3FNType";
157 
158  static void bindDerived(ClassTy &c) {
159  c.def_static(
160  "get",
161  [](DefaultingPyMlirContext context) {
162  MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
163  return PyFloat6E2M3FNType(context->getRef(), t);
164  },
165  py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
166  }
167 };
168 
169 /// Floating Point Type subclass - Float6E3M2FNType.
170 class PyFloat6E3M2FNType
171  : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
172 public:
173  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
174  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
176  static constexpr const char *pyClassName = "Float6E3M2FNType";
178 
179  static void bindDerived(ClassTy &c) {
180  c.def_static(
181  "get",
182  [](DefaultingPyMlirContext context) {
183  MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
184  return PyFloat6E3M2FNType(context->getRef(), t);
185  },
186  py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
187  }
188 };
189 
190 /// Floating Point Type subclass - Float8E4M3FNType.
191 class PyFloat8E4M3FNType
192  : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
193 public:
194  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
195  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
197  static constexpr const char *pyClassName = "Float8E4M3FNType";
199 
200  static void bindDerived(ClassTy &c) {
201  c.def_static(
202  "get",
203  [](DefaultingPyMlirContext context) {
204  MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
205  return PyFloat8E4M3FNType(context->getRef(), t);
206  },
207  py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
208  }
209 };
210 
211 /// Floating Point Type subclass - Float8E5M2Type.
212 class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
213 public:
214  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
215  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
217  static constexpr const char *pyClassName = "Float8E5M2Type";
219 
220  static void bindDerived(ClassTy &c) {
221  c.def_static(
222  "get",
223  [](DefaultingPyMlirContext context) {
224  MlirType t = mlirFloat8E5M2TypeGet(context->get());
225  return PyFloat8E5M2Type(context->getRef(), t);
226  },
227  py::arg("context") = py::none(), "Create a float8_e5m2 type.");
228  }
229 };
230 
231 /// Floating Point Type subclass - Float8E4M3Type.
232 class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
233 public:
234  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
235  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
237  static constexpr const char *pyClassName = "Float8E4M3Type";
239 
240  static void bindDerived(ClassTy &c) {
241  c.def_static(
242  "get",
243  [](DefaultingPyMlirContext context) {
244  MlirType t = mlirFloat8E4M3TypeGet(context->get());
245  return PyFloat8E4M3Type(context->getRef(), t);
246  },
247  py::arg("context") = py::none(), "Create a float8_e4m3 type.");
248  }
249 };
250 
251 /// Floating Point Type subclass - Float8E4M3FNUZ.
252 class PyFloat8E4M3FNUZType
253  : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
254 public:
255  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
256  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
258  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
260 
261  static void bindDerived(ClassTy &c) {
262  c.def_static(
263  "get",
264  [](DefaultingPyMlirContext context) {
265  MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
266  return PyFloat8E4M3FNUZType(context->getRef(), t);
267  },
268  py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
269  }
270 };
271 
272 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
273 class PyFloat8E4M3B11FNUZType
274  : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
275 public:
276  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
277  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
279  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
281 
282  static void bindDerived(ClassTy &c) {
283  c.def_static(
284  "get",
285  [](DefaultingPyMlirContext context) {
286  MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
287  return PyFloat8E4M3B11FNUZType(context->getRef(), t);
288  },
289  py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
290  }
291 };
292 
293 /// Floating Point Type subclass - Float8E5M2FNUZ.
294 class PyFloat8E5M2FNUZType
295  : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
296 public:
297  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
298  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
300  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
302 
303  static void bindDerived(ClassTy &c) {
304  c.def_static(
305  "get",
306  [](DefaultingPyMlirContext context) {
307  MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
308  return PyFloat8E5M2FNUZType(context->getRef(), t);
309  },
310  py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
311  }
312 };
313 
314 /// Floating Point Type subclass - Float8E3M4Type.
315 class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
316 public:
317  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
318  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
320  static constexpr const char *pyClassName = "Float8E3M4Type";
322 
323  static void bindDerived(ClassTy &c) {
324  c.def_static(
325  "get",
326  [](DefaultingPyMlirContext context) {
327  MlirType t = mlirFloat8E3M4TypeGet(context->get());
328  return PyFloat8E3M4Type(context->getRef(), t);
329  },
330  py::arg("context") = py::none(), "Create a float8_e3m4 type.");
331  }
332 };
333 
334 /// Floating Point Type subclass - Float8E8M0FNUType.
335 class PyFloat8E8M0FNUType
336  : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
337 public:
338  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
339  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
341  static constexpr const char *pyClassName = "Float8E8M0FNUType";
343 
344  static void bindDerived(ClassTy &c) {
345  c.def_static(
346  "get",
347  [](DefaultingPyMlirContext context) {
348  MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
349  return PyFloat8E8M0FNUType(context->getRef(), t);
350  },
351  py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
352  }
353 };
354 
355 /// Floating Point Type subclass - BF16Type.
356 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
357 public:
358  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
359  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
361  static constexpr const char *pyClassName = "BF16Type";
363 
364  static void bindDerived(ClassTy &c) {
365  c.def_static(
366  "get",
367  [](DefaultingPyMlirContext context) {
368  MlirType t = mlirBF16TypeGet(context->get());
369  return PyBF16Type(context->getRef(), t);
370  },
371  py::arg("context") = py::none(), "Create a bf16 type.");
372  }
373 };
374 
375 /// Floating Point Type subclass - F16Type.
376 class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
377 public:
378  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
379  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
381  static constexpr const char *pyClassName = "F16Type";
383 
384  static void bindDerived(ClassTy &c) {
385  c.def_static(
386  "get",
387  [](DefaultingPyMlirContext context) {
388  MlirType t = mlirF16TypeGet(context->get());
389  return PyF16Type(context->getRef(), t);
390  },
391  py::arg("context") = py::none(), "Create a f16 type.");
392  }
393 };
394 
395 /// Floating Point Type subclass - TF32Type.
396 class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
397 public:
398  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
399  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
401  static constexpr const char *pyClassName = "FloatTF32Type";
403 
404  static void bindDerived(ClassTy &c) {
405  c.def_static(
406  "get",
407  [](DefaultingPyMlirContext context) {
408  MlirType t = mlirTF32TypeGet(context->get());
409  return PyTF32Type(context->getRef(), t);
410  },
411  py::arg("context") = py::none(), "Create a tf32 type.");
412  }
413 };
414 
415 /// Floating Point Type subclass - F32Type.
416 class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
417 public:
418  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
419  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
421  static constexpr const char *pyClassName = "F32Type";
423 
424  static void bindDerived(ClassTy &c) {
425  c.def_static(
426  "get",
427  [](DefaultingPyMlirContext context) {
428  MlirType t = mlirF32TypeGet(context->get());
429  return PyF32Type(context->getRef(), t);
430  },
431  py::arg("context") = py::none(), "Create a f32 type.");
432  }
433 };
434 
435 /// Floating Point Type subclass - F64Type.
436 class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
437 public:
438  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
439  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
441  static constexpr const char *pyClassName = "F64Type";
443 
444  static void bindDerived(ClassTy &c) {
445  c.def_static(
446  "get",
447  [](DefaultingPyMlirContext context) {
448  MlirType t = mlirF64TypeGet(context->get());
449  return PyF64Type(context->getRef(), t);
450  },
451  py::arg("context") = py::none(), "Create a f64 type.");
452  }
453 };
454 
455 /// None Type subclass - NoneType.
456 class PyNoneType : public PyConcreteType<PyNoneType> {
457 public:
458  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
459  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
461  static constexpr const char *pyClassName = "NoneType";
463 
464  static void bindDerived(ClassTy &c) {
465  c.def_static(
466  "get",
467  [](DefaultingPyMlirContext context) {
468  MlirType t = mlirNoneTypeGet(context->get());
469  return PyNoneType(context->getRef(), t);
470  },
471  py::arg("context") = py::none(), "Create a none type.");
472  }
473 };
474 
475 /// Complex Type subclass - ComplexType.
476 class PyComplexType : public PyConcreteType<PyComplexType> {
477 public:
478  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
479  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
481  static constexpr const char *pyClassName = "ComplexType";
483 
484  static void bindDerived(ClassTy &c) {
485  c.def_static(
486  "get",
487  [](PyType &elementType) {
488  // The element must be a floating point or integer scalar type.
489  if (mlirTypeIsAIntegerOrFloat(elementType)) {
490  MlirType t = mlirComplexTypeGet(elementType);
491  return PyComplexType(elementType.getContext(), t);
492  }
493  throw py::value_error(
494  (Twine("invalid '") +
495  py::repr(py::cast(elementType)).cast<std::string>() +
496  "' and expected floating point or integer type.")
497  .str());
498  },
499  "Create a complex type");
500  c.def_property_readonly(
501  "element_type",
502  [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
503  "Returns element type.");
504  }
505 };
506 
507 } // namespace
508 
509 // Shaped Type Interface - ShapedType
511  c.def_property_readonly(
512  "element_type",
513  [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
514  "Returns the element type of the shaped type.");
515  c.def_property_readonly(
516  "has_rank",
517  [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
518  "Returns whether the given shaped type is ranked.");
519  c.def_property_readonly(
520  "rank",
521  [](PyShapedType &self) {
522  self.requireHasRank();
523  return mlirShapedTypeGetRank(self);
524  },
525  "Returns the rank of the given ranked shaped type.");
526  c.def_property_readonly(
527  "has_static_shape",
528  [](PyShapedType &self) -> bool {
529  return mlirShapedTypeHasStaticShape(self);
530  },
531  "Returns whether the given shaped type has a static shape.");
532  c.def(
533  "is_dynamic_dim",
534  [](PyShapedType &self, intptr_t dim) -> bool {
535  self.requireHasRank();
536  return mlirShapedTypeIsDynamicDim(self, dim);
537  },
538  py::arg("dim"),
539  "Returns whether the dim-th dimension of the given shaped type is "
540  "dynamic.");
541  c.def(
542  "get_dim_size",
543  [](PyShapedType &self, intptr_t dim) {
544  self.requireHasRank();
545  return mlirShapedTypeGetDimSize(self, dim);
546  },
547  py::arg("dim"),
548  "Returns the dim-th dimension of the given ranked shaped type.");
549  c.def_static(
550  "is_dynamic_size",
551  [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
552  py::arg("dim_size"),
553  "Returns whether the given dimension size indicates a dynamic "
554  "dimension.");
555  c.def(
556  "is_dynamic_stride_or_offset",
557  [](PyShapedType &self, int64_t val) -> bool {
558  self.requireHasRank();
560  },
561  py::arg("dim_size"),
562  "Returns whether the given value is used as a placeholder for dynamic "
563  "strides and offsets in shaped types.");
564  c.def_property_readonly(
565  "shape",
566  [](PyShapedType &self) {
567  self.requireHasRank();
568 
569  std::vector<int64_t> shape;
570  int64_t rank = mlirShapedTypeGetRank(self);
571  shape.reserve(rank);
572  for (int64_t i = 0; i < rank; ++i)
573  shape.push_back(mlirShapedTypeGetDimSize(self, i));
574  return shape;
575  },
576  "Returns the shape of the ranked shaped type as a list of integers.");
577  c.def_static(
578  "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
579  "Returns the value used to indicate dynamic dimensions in shaped "
580  "types.");
581  c.def_static(
582  "get_dynamic_stride_or_offset",
583  []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
584  "Returns the value used to indicate dynamic strides or offsets in "
585  "shaped types.");
586 }
587 
588 void mlir::PyShapedType::requireHasRank() {
589  if (!mlirShapedTypeHasRank(*this)) {
590  throw py::value_error(
591  "calling this method requires that the type has a rank.");
592  }
593 }
594 
597 
598 namespace {
599 
600 /// Vector Type subclass - VectorType.
601 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
602 public:
603  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
604  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
606  static constexpr const char *pyClassName = "VectorType";
608 
609  static void bindDerived(ClassTy &c) {
610  c.def_static("get", &PyVectorType::get, py::arg("shape"),
611  py::arg("element_type"), py::kw_only(),
612  py::arg("scalable") = py::none(),
613  py::arg("scalable_dims") = py::none(),
614  py::arg("loc") = py::none(), "Create a vector type")
615  .def_property_readonly(
616  "scalable",
617  [](MlirType self) { return mlirVectorTypeIsScalable(self); })
618  .def_property_readonly("scalable_dims", [](MlirType self) {
619  std::vector<bool> scalableDims;
620  size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
621  scalableDims.reserve(rank);
622  for (size_t i = 0; i < rank; ++i)
623  scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
624  return scalableDims;
625  });
626  }
627 
628 private:
629  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
630  std::optional<py::list> scalable,
631  std::optional<std::vector<int64_t>> scalableDims,
632  DefaultingPyLocation loc) {
633  if (scalable && scalableDims) {
634  throw py::value_error("'scalable' and 'scalable_dims' kwargs "
635  "are mutually exclusive.");
636  }
637 
638  PyMlirContext::ErrorCapture errors(loc->getContext());
639  MlirType type;
640  if (scalable) {
641  if (scalable->size() != shape.size())
642  throw py::value_error("Expected len(scalable) == len(shape).");
643 
644  SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
645  *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
646  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
647  scalableDimFlags.data(),
648  elementType);
649  } else if (scalableDims) {
650  SmallVector<bool> scalableDimFlags(shape.size(), false);
651  for (int64_t dim : *scalableDims) {
652  if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
653  throw py::value_error("Scalable dimension index out of bounds.");
654  scalableDimFlags[dim] = true;
655  }
656  type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
657  scalableDimFlags.data(),
658  elementType);
659  } else {
660  type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
661  elementType);
662  }
663  if (mlirTypeIsNull(type))
664  throw MLIRError("Invalid type", errors.take());
665  return PyVectorType(elementType.getContext(), type);
666  }
667 };
668 
669 /// Ranked Tensor Type subclass - RankedTensorType.
670 class PyRankedTensorType
671  : public PyConcreteType<PyRankedTensorType, PyShapedType> {
672 public:
673  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
674  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
676  static constexpr const char *pyClassName = "RankedTensorType";
678 
679  static void bindDerived(ClassTy &c) {
680  c.def_static(
681  "get",
682  [](std::vector<int64_t> shape, PyType &elementType,
683  std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
684  PyMlirContext::ErrorCapture errors(loc->getContext());
685  MlirType t = mlirRankedTensorTypeGetChecked(
686  loc, shape.size(), shape.data(), elementType,
687  encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
688  if (mlirTypeIsNull(t))
689  throw MLIRError("Invalid type", errors.take());
690  return PyRankedTensorType(elementType.getContext(), t);
691  },
692  py::arg("shape"), py::arg("element_type"),
693  py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
694  "Create a ranked tensor type");
695  c.def_property_readonly(
696  "encoding",
697  [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
698  MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
699  if (mlirAttributeIsNull(encoding))
700  return std::nullopt;
701  return encoding;
702  });
703  }
704 };
705 
706 /// Unranked Tensor Type subclass - UnrankedTensorType.
707 class PyUnrankedTensorType
708  : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
709 public:
710  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
711  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
713  static constexpr const char *pyClassName = "UnrankedTensorType";
715 
716  static void bindDerived(ClassTy &c) {
717  c.def_static(
718  "get",
719  [](PyType &elementType, DefaultingPyLocation loc) {
720  PyMlirContext::ErrorCapture errors(loc->getContext());
721  MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
722  if (mlirTypeIsNull(t))
723  throw MLIRError("Invalid type", errors.take());
724  return PyUnrankedTensorType(elementType.getContext(), t);
725  },
726  py::arg("element_type"), py::arg("loc") = py::none(),
727  "Create a unranked tensor type");
728  }
729 };
730 
731 /// Ranked MemRef Type subclass - MemRefType.
732 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
733 public:
734  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
735  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
737  static constexpr const char *pyClassName = "MemRefType";
739 
740  static void bindDerived(ClassTy &c) {
741  c.def_static(
742  "get",
743  [](std::vector<int64_t> shape, PyType &elementType,
744  PyAttribute *layout, PyAttribute *memorySpace,
745  DefaultingPyLocation loc) {
746  PyMlirContext::ErrorCapture errors(loc->getContext());
747  MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
748  MlirAttribute memSpaceAttr =
749  memorySpace ? *memorySpace : mlirAttributeGetNull();
750  MlirType t =
751  mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
752  shape.data(), layoutAttr, memSpaceAttr);
753  if (mlirTypeIsNull(t))
754  throw MLIRError("Invalid type", errors.take());
755  return PyMemRefType(elementType.getContext(), t);
756  },
757  py::arg("shape"), py::arg("element_type"),
758  py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
759  py::arg("loc") = py::none(), "Create a memref type")
760  .def_property_readonly(
761  "layout",
762  [](PyMemRefType &self) -> MlirAttribute {
763  return mlirMemRefTypeGetLayout(self);
764  },
765  "The layout of the MemRef type.")
766  .def(
767  "get_strides_and_offset",
768  [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
769  std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
770  int64_t offset;
772  self, strides.data(), &offset)))
773  throw std::runtime_error(
774  "Failed to extract strides and offset from memref.");
775  return {strides, offset};
776  },
777  "The strides and offset of the MemRef type.")
778  .def_property_readonly(
779  "affine_map",
780  [](PyMemRefType &self) -> PyAffineMap {
781  MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
782  return PyAffineMap(self.getContext(), map);
783  },
784  "The layout of the MemRef type as an affine map.")
785  .def_property_readonly(
786  "memory_space",
787  [](PyMemRefType &self) -> std::optional<MlirAttribute> {
788  MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
789  if (mlirAttributeIsNull(a))
790  return std::nullopt;
791  return a;
792  },
793  "Returns the memory space of the given MemRef type.");
794  }
795 };
796 
797 /// Unranked MemRef Type subclass - UnrankedMemRefType.
798 class PyUnrankedMemRefType
799  : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
800 public:
801  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
802  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
804  static constexpr const char *pyClassName = "UnrankedMemRefType";
806 
807  static void bindDerived(ClassTy &c) {
808  c.def_static(
809  "get",
810  [](PyType &elementType, PyAttribute *memorySpace,
811  DefaultingPyLocation loc) {
812  PyMlirContext::ErrorCapture errors(loc->getContext());
813  MlirAttribute memSpaceAttr = {};
814  if (memorySpace)
815  memSpaceAttr = *memorySpace;
816 
817  MlirType t =
818  mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
819  if (mlirTypeIsNull(t))
820  throw MLIRError("Invalid type", errors.take());
821  return PyUnrankedMemRefType(elementType.getContext(), t);
822  },
823  py::arg("element_type"), py::arg("memory_space"),
824  py::arg("loc") = py::none(), "Create a unranked memref type")
825  .def_property_readonly(
826  "memory_space",
827  [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
828  MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
829  if (mlirAttributeIsNull(a))
830  return std::nullopt;
831  return a;
832  },
833  "Returns the memory space of the given Unranked MemRef type.");
834  }
835 };
836 
837 /// Tuple Type subclass - TupleType.
838 class PyTupleType : public PyConcreteType<PyTupleType> {
839 public:
840  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
841  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
843  static constexpr const char *pyClassName = "TupleType";
845 
846  static void bindDerived(ClassTy &c) {
847  c.def_static(
848  "get_tuple",
849  [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
850  MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
851  elements.data());
852  return PyTupleType(context->getRef(), t);
853  },
854  py::arg("elements"), py::arg("context") = py::none(),
855  "Create a tuple type");
856  c.def(
857  "get_type",
858  [](PyTupleType &self, intptr_t pos) {
859  return mlirTupleTypeGetType(self, pos);
860  },
861  py::arg("pos"), "Returns the pos-th type in the tuple type.");
862  c.def_property_readonly(
863  "num_types",
864  [](PyTupleType &self) -> intptr_t {
865  return mlirTupleTypeGetNumTypes(self);
866  },
867  "Returns the number of types contained in a tuple.");
868  }
869 };
870 
871 /// Function type.
872 class PyFunctionType : public PyConcreteType<PyFunctionType> {
873 public:
874  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
875  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
877  static constexpr const char *pyClassName = "FunctionType";
879 
880  static void bindDerived(ClassTy &c) {
881  c.def_static(
882  "get",
883  [](std::vector<MlirType> inputs, std::vector<MlirType> results,
884  DefaultingPyMlirContext context) {
885  MlirType t =
886  mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
887  results.size(), results.data());
888  return PyFunctionType(context->getRef(), t);
889  },
890  py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
891  "Gets a FunctionType from a list of input and result types");
892  c.def_property_readonly(
893  "inputs",
894  [](PyFunctionType &self) {
895  MlirType t = self;
896  py::list types;
897  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
898  ++i) {
899  types.append(mlirFunctionTypeGetInput(t, i));
900  }
901  return types;
902  },
903  "Returns the list of input types in the FunctionType.");
904  c.def_property_readonly(
905  "results",
906  [](PyFunctionType &self) {
907  py::list types;
908  for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
909  ++i) {
910  types.append(mlirFunctionTypeGetResult(self, i));
911  }
912  return types;
913  },
914  "Returns the list of result types in the FunctionType.");
915  }
916 };
917 
918 static MlirStringRef toMlirStringRef(const std::string &s) {
919  return mlirStringRefCreate(s.data(), s.size());
920 }
921 
922 /// Opaque Type subclass - OpaqueType.
923 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
924 public:
925  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
926  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
928  static constexpr const char *pyClassName = "OpaqueType";
930 
931  static void bindDerived(ClassTy &c) {
932  c.def_static(
933  "get",
934  [](std::string dialectNamespace, std::string typeData,
935  DefaultingPyMlirContext context) {
936  MlirType type = mlirOpaqueTypeGet(context->get(),
937  toMlirStringRef(dialectNamespace),
938  toMlirStringRef(typeData));
939  return PyOpaqueType(context->getRef(), type);
940  },
941  py::arg("dialect_namespace"), py::arg("buffer"),
942  py::arg("context") = py::none(),
943  "Create an unregistered (opaque) dialect type.");
944  c.def_property_readonly(
945  "dialect_namespace",
946  [](PyOpaqueType &self) {
948  return py::str(stringRef.data, stringRef.length);
949  },
950  "Returns the dialect namespace for the Opaque type as a string.");
951  c.def_property_readonly(
952  "data",
953  [](PyOpaqueType &self) {
954  MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
955  return py::str(stringRef.data, stringRef.length);
956  },
957  "Returns the data for the Opaque type as a string.");
958  }
959 };
960 
961 } // namespace
962 
963 void mlir::python::populateIRTypes(py::module &m) {
964  PyIntegerType::bind(m);
965  PyFloatType::bind(m);
966  PyIndexType::bind(m);
967  PyFloat4E2M1FNType::bind(m);
968  PyFloat6E2M3FNType::bind(m);
969  PyFloat6E3M2FNType::bind(m);
970  PyFloat8E4M3FNType::bind(m);
971  PyFloat8E5M2Type::bind(m);
972  PyFloat8E4M3Type::bind(m);
973  PyFloat8E4M3FNUZType::bind(m);
974  PyFloat8E4M3B11FNUZType::bind(m);
975  PyFloat8E5M2FNUZType::bind(m);
976  PyFloat8E3M4Type::bind(m);
977  PyFloat8E8M0FNUType::bind(m);
978  PyBF16Type::bind(m);
979  PyF16Type::bind(m);
980  PyTF32Type::bind(m);
981  PyF32Type::bind(m);
982  PyF64Type::bind(m);
983  PyNoneType::bind(m);
984  PyComplexType::bind(m);
986  PyVectorType::bind(m);
987  PyRankedTensorType::bind(m);
988  PyUnrankedTensorType::bind(m);
989  PyMemRefType::bind(m);
990  PyUnrankedMemRefType::bind(m);
991  PyTupleType::bind(m);
992  PyFunctionType::bind(m);
993  PyOpaqueType::bind(m);
994 }
static MlirStringRef toMlirStringRef(const std::string &s)
Definition: IRCore.cpp:205
static MLIRContext * getContext(OpFoldResult val)
Shaped Type Interface - ShapedType.
Definition: IRTypes.h:17
static const IsAFunctionTy isaFunction
Definition: IRTypes.h:19
static void bindDerived(ClassTy &c)
Definition: IRTypes.cpp:510
PyMlirContextRef & getContext()
Accesses the context reference.
Definition: IRModule.h:311
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:518
Used in function arguments when None should resolve to the current context manager set instance.
Definition: IRModule.h:292
ReferrentTy * get() const
Definition: PybindUtils.h:47
Wrapper around the generic MlirAttribute.
Definition: IRModule.h:1002
CRTP base classes for Python types that subclass Type and should be castable from it (i....
Definition: IRModule.h:930
static void bind(pybind11::module &m)
Definition: IRModule.h:957
Wrapper around the generic MlirType.
Definition: IRModule.h:880
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void)
Returns an empty attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void)
Returns the typeID of an Float8E8M0FNU type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void)
Returns the typeID of an Float16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type)
Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void)
Returns the typeID of an Float4E2M1FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void)
Returns the typeID of a TF32 type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type)
Checks whether the given type is an f8E8M0FNU type.
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void)
Returns the typeID of an None type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void)
Returns the typeID of an Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type)
Checks whether the given type is an f6E3M2FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx)
Creates an f8E3M4 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type)
Checks whether the given type is an f8E3M4 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void)
Returns the typeID of an UnrankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx)
Creates an f8E8M0FNU type in the given context.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void)
Returns the typeID of an MemRef type.
MLIR_CAPI_EXPORTED unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx)
Creates a f64 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type)
Checks whether the given type is an f8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, const bool *scalable, MlirType elementType)
Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx)
Creates an f6E2M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void)
Returns the typeID of an Float64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void)
Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void)
Returns the typeID of an Tuple type.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx)
Creates an f8E4M3 type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void)
Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void)
Returns the typeID of an Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type)
Checks whether the given type is an f6E2M3FN type.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides, int64_t *offset)
Returns the strides of the MemRef if the layout map is in strided form.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void)
Returns the typeID of an Float8E3M4 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type)
Checks whether the given type is a floating-point type.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void)
Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void)
Returns the typeID of an RankedTensor type.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim)
Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type)
Checks whether the given type is an TF32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void)
Returns the typeID of an Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void)
Returns the typeID of an Integer type.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx)
Creates an f6E3M2FN type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void)
Returns the typeID of an Opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void)
Returns the typeID of an Index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void)
Returns the typeID of an UnrankedMemRef type.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void)
Returns the typeID of an Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void)
Returns the typeID of an Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type)
Checks whether the given vector type is scalable, i.e., has at least one scalable dimension.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type)
Checks whether the given type is an f4E2M1FN type.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
MLIR_CAPI_EXPORTED MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void)
Returns the value indicating a dynamic stride or offset in a shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void)
Returns the typeID of an Function type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void)
Returns the value indicating a dynamic size in a shaped type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void)
Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void)
Returns the typeID of an Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx)
Creates a TF32 type in the given context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx)
Creates an f4E2M1FN type in the given context.
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1034
static bool mlirTypeIsNull(MlirType type)
Checks whether a type is null.
Definition: IR.h:999
static MlirStringRef mlirStringRefCreate(const char *str, size_t length)
Constructs a string reference from the pointer and length.
Definition: Support.h:82
static bool mlirLogicalResultIsFailure(MlirLogicalResult res)
Checks if the given logical result represents a failure.
Definition: Support.h:127
void populateIRTypes(pybind11::module &m)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
const char * data
Pointer to the first symbol.
Definition: Support.h:74
size_t length
Length of the fragment.
Definition: Support.h:75
Custom exception that allows access to error diagnostic information.
Definition: IRModule.h:1284
RAII object that captures any error diagnostics emitted to the provided context.
Definition: IRModule.h:427