26 return unwrap(type).isa<IntegerType>();
30 return wrap(IntegerType::get(
unwrap(ctx), bitwidth));
38 return wrap(IntegerType::get(
unwrap(ctx), bitwidth, IntegerType::Unsigned));
42 return unwrap(type).cast<IntegerType>().getWidth();
46 return unwrap(type).cast<IntegerType>().isSignless();
50 return unwrap(type).cast<IntegerType>().isSigned();
54 return unwrap(type).cast<IntegerType>().isUnsigned();
72 return unwrap(type).isFloat8E5M2();
80 return unwrap(type).isFloat8E4M3FN();
88 return unwrap(type).isFloat8E5M2FNUZ();
96 return unwrap(type).isFloat8E4M3FNUZ();
104 return unwrap(type).isFloat8E4M3B11FNUZ();
150 return unwrap(type).isa<ComplexType>();
154 return wrap(ComplexType::get(
unwrap(elementType)));
172 return unwrap(type).cast<ShapedType>().hasRank();
176 return unwrap(type).cast<ShapedType>().getRank();
180 return unwrap(type).cast<ShapedType>().hasStaticShape();
184 return unwrap(type).cast<ShapedType>().isDynamicDim(
185 static_cast<unsigned>(dim));
189 return unwrap(type).cast<ShapedType>().getDimSize(
static_cast<unsigned>(dim));
195 return ShapedType::isDynamic(size);
199 return ShapedType::isDynamic(val);
203 return ShapedType::kDynamic;
213 MlirType elementType) {
219 const int64_t *shape, MlirType elementType) {
220 return wrap(VectorType::getChecked(
232 return unwrap(type).isa<RankedTensorType>();
236 return unwrap(type).isa<UnrankedTensorType>();
240 MlirType elementType, MlirAttribute encoding) {
242 RankedTensorType::get(
llvm::ArrayRef(shape,
static_cast<size_t>(rank)),
247 const int64_t *shape,
248 MlirType elementType,
249 MlirAttribute encoding) {
250 return wrap(RankedTensorType::getChecked(
256 return wrap(
unwrap(type).cast<RankedTensorType>().getEncoding());
260 return wrap(UnrankedTensorType::get(
unwrap(elementType)));
264 MlirType elementType) {
265 return wrap(UnrankedTensorType::getChecked(
unwrap(loc),
unwrap(elementType)));
275 const int64_t *shape, MlirAttribute layout,
276 MlirAttribute memorySpace) {
277 return wrap(MemRefType::get(
280 ? MemRefLayoutAttrInterface()
281 :
unwrap(layout).cast<MemRefLayoutAttrInterface>(),
286 intptr_t rank,
const int64_t *shape,
287 MlirAttribute layout,
288 MlirAttribute memorySpace) {
289 return wrap(MemRefType::getChecked(
293 ? MemRefLayoutAttrInterface()
294 :
unwrap(layout).cast<MemRefLayoutAttrInterface>(),
299 const int64_t *shape,
300 MlirAttribute memorySpace) {
302 unwrap(elementType), MemRefLayoutAttrInterface(),
307 MlirType elementType, intptr_t rank,
308 const int64_t *shape,
309 MlirAttribute memorySpace) {
310 return wrap(MemRefType::getChecked(
312 unwrap(elementType), MemRefLayoutAttrInterface(),
unwrap(memorySpace)));
316 return wrap(
unwrap(type).cast<MemRefType>().getLayout());
320 return wrap(
unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
324 return wrap(
unwrap(type).cast<MemRefType>().getMemorySpace());
332 MlirAttribute memorySpace) {
334 UnrankedMemRefType::get(
unwrap(elementType),
unwrap(memorySpace)));
338 MlirType elementType,
339 MlirAttribute memorySpace) {
340 return wrap(UnrankedMemRefType::getChecked(
unwrap(loc),
unwrap(elementType),
345 return wrap(
unwrap(type).cast<UnrankedMemRefType>().getMemorySpace());
355 MlirType
const *elements) {
358 return wrap(TupleType::get(
unwrap(ctx), typeRef));
362 return unwrap(type).cast<TupleType>().size();
366 return wrap(
unwrap(type).cast<TupleType>().getType(
static_cast<size_t>(pos)));
374 return unwrap(type).isa<FunctionType>();
378 MlirType
const *inputs, intptr_t numResults,
379 MlirType
const *results) {
382 (void)
unwrapList(numInputs, inputs, inputsList);
383 (void)
unwrapList(numResults, results, resultsList);
384 return wrap(FunctionType::get(
unwrap(ctx), inputsList, resultsList));
388 return unwrap(type).cast<FunctionType>().getNumInputs();
392 return unwrap(type).cast<FunctionType>().getNumResults();
396 assert(pos >= 0 &&
"pos in array must be positive");
398 unwrap(type).cast<FunctionType>().getInput(
static_cast<unsigned>(pos)));
402 assert(pos >= 0 &&
"pos in array must be positive");
404 unwrap(type).cast<FunctionType>().getResult(
static_cast<unsigned>(pos)));
416 OpaqueType::get(StringAttr::get(
unwrap(ctx),
unwrap(dialectNamespace)),
421 return wrap(
unwrap(type).cast<OpaqueType>().getDialectNamespace().strref());
425 return wrap(
unwrap(type).cast<OpaqueType>().getTypeData());
bool mlirTypeIsAF16(MlirType type)
Checks whether the given type is an f16 type.
bool mlirTypeIsAF64(MlirType type)
Checks whether the given type is an f64 type.
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type)
Checks whether the given type is an f8E4M3FNUZ type.
bool mlirIntegerTypeIsUnsigned(MlirType type)
Checks whether the given integer type is unsigned.
MlirType mlirF32TypeGet(MlirContext ctx)
Creates an f32 type in the given context.
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType)
Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth)
Creates a signless integer type of the given bitwidth in the context.
intptr_t mlirFunctionTypeGetNumResults(MlirType type)
Returns the number of result types.
unsigned mlirIntegerTypeGetWidth(MlirType type)
Returns the bitwidth of an integer type.
bool mlirTypeIsAUnrankedTensor(MlirType type)
Checks whether the given type is an unranked tensor type.
int64_t mlirShapedTypeGetDynamicStrideOrOffset()
Returns the value indicating a dynamic stride or offset in a shaped type.
MlirType mlirF64TypeGet(MlirContext ctx)
Creates an f64 type in the given context.
MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type)
Returns the memory space of the given Unranked MemRef type.
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type)
Checks whether the given type is an f8E4M3B11FNUZ type.
bool mlirTypeIsAFloat8E5M2(MlirType type)
Checks whether the given type is an f8E5M2 type.
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth)
Creates a signed integer type of the given bitwidth in the context.
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim)
Returns the dim-th dimension of the given ranked shaped type.
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on illegal arguments,...
MlirType mlirUnrankedTensorTypeGet(MlirType elementType)
Creates an unranked tensor type with the given element type in the same context as the element type.
MlirStringRef mlirOpaqueTypeGetData(MlirType type)
Returns the raw data as a string reference.
bool mlirTypeIsAF32(MlirType type)
Checks whether the given type is an f32 type.
bool mlirTypeIsAFunction(MlirType type)
Checks whether the given type is a function type.
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type)
Returns the affine map of the given MemRef type.
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute memorySpace)
Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
bool mlirTypeIsAMemRef(MlirType type)
Checks whether the given type is a MemRef type.
MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos)
Returns the pos-th result type.
MlirType mlirF16TypeGet(MlirContext ctx)
Creates an f16 type in the given context.
bool mlirTypeIsAComplex(MlirType type)
Checks whether the given type is a Complex type.
MlirType mlirNoneTypeGet(MlirContext ctx)
Creates a None type in the given context.
bool mlirShapedTypeHasRank(MlirType type)
Checks whether the given shaped type is ranked.
bool mlirTypeIsAShaped(MlirType type)
Checks whether the given type is a Shaped type.
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType)
Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
bool mlirIntegerTypeIsSignless(MlirType type)
Checks whether the given integer type is signless.
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in th...
bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim)
Checks whether the dim-th dimension of the given shaped type is dynamic.
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3B11FNUZ type in the given context.
MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, MlirType const *inputs, intptr_t numResults, MlirType const *results)
Creates a function type, mapping a list of input types to result types.
bool mlirTypeIsAUnrankedMemRef(MlirType type)
Checks whether the given type is an UnrankedMemRef type.
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth)
Creates an unsigned integer type of the given bitwidth in the context.
MlirType mlirShapedTypeGetElementType(MlirType type)
Returns the element type of the shaped type.
int64_t mlirShapedTypeGetDynamicSize()
Returns the value indicating a dynamic size in a shaped type.
bool mlirTypeIsAVector(MlirType type)
Checks whether the given type is a Vector type.
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx)
Creates an f8E4M3FNUZ type in the given context.
MlirAttribute mlirMemRefTypeGetLayout(MlirType type)
Returns the layout of the given MemRef type.
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx)
Creates an f8E5M2FNUZ type in the given context.
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)
Checks whether the given value is used as a placeholder for dynamic strides and offsets in shaped typ...
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace)
Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps,...
bool mlirTypeIsABF16(MlirType type)
Checks whether the given type is a bf16 type.
intptr_t mlirTupleTypeGetNumTypes(MlirType type)
Returns the number of types contained in a tuple.
MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type)
Returns the namespace of the dialect with which the given opaque type is associated.
MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos)
Returns the pos-th input type.
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace)
Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping MlirType on illegal arguments,...
bool mlirTypeIsAInteger(MlirType type)
Checks whether the given type is an integer type.
MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type)
Gets the 'encoding' attribute from the ranked tensor type, returning a null attribute if none.
intptr_t mlirFunctionTypeGetNumInputs(MlirType type)
Returns the number of input types.
bool mlirTypeIsATuple(MlirType type)
Checks whether the given type is a tuple type.
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements)
Creates a tuple type that consists of the given list of elemental types.
MlirType mlirComplexTypeGet(MlirType elementType)
Creates a complex type with the given element type in the same context as the element type.
int64_t mlirShapedTypeGetRank(MlirType type)
Returns the rank of the given ranked shaped type.
bool mlirTypeIsAFloat8E4M3FN(MlirType type)
Checks whether the given type is an f8E4M3FN type.
bool mlirTypeIsAOpaque(MlirType type)
Checks whether the given type is an opaque type.
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos)
Returns the pos-th type in the tuple type.
bool mlirTypeIsATensor(MlirType type)
Checks whether the given type is a Tensor type.
bool mlirIntegerTypeIsSigned(MlirType type)
Checks whether the given integer type is signed.
bool mlirShapedTypeHasStaticShape(MlirType type)
Checks whether the given shaped type has a static shape.
MlirType mlirBF16TypeGet(MlirContext ctx)
Creates a bf16 type in the given context.
MlirType mlirComplexTypeGetElementType(MlirType type)
Returns the element type of the given complex type.
bool mlirTypeIsAIndex(MlirType type)
Checks whether the given type is an index type.
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type)
Checks whether the given type is an f8E5M2FNUZ type.
MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData)
Creates an opaque type in the given context associated with the dialect identified by its namespace.
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type)
Returns the memory space of the given MemRef type.
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding)
Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on illegal arguments,...
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType)
Creates a vector type of the shape identified by its rank and dimensions, with the given element type...
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace)
Creates a MemRef type with the given rank, shape, memory space and element type in the same context a...
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx)
Creates an f8E4M3FN type in the given context.
MlirType mlirIndexTypeGet(MlirContext ctx)
Creates an index type in the given context.
bool mlirTypeIsARankedTensor(MlirType type)
Checks whether the given type is a ranked tensor type.
bool mlirShapedTypeIsDynamicSize(int64_t size)
Checks whether the given value is used as a placeholder for dynamic sizes in shaped types.
bool mlirTypeIsANone(MlirType type)
Checks whether the given type is a None type.
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx)
Creates an f8E5M2 type in the given context.
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace)
Creates an Unranked MemRef type with the given element type and in the given memory space.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static llvm::ArrayRef< CppTy > unwrapList(size_t size, CTy *first, llvm::SmallVectorImpl< CppTy > &storage)
static FloatType getF64(MLIRContext *ctx)
static FloatType getFloat8E5M2(MLIRContext *ctx)
static FloatType getFloat8E4M3FN(MLIRContext *ctx)
static FloatType getF16(MLIRContext *ctx)
static FloatType getBF16(MLIRContext *ctx)
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx)
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx)
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx)
static FloatType getF32(MLIRContext *ctx)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
This header declares functions that assist transformations in the MemRef dialect.
A pointer to a sized fragment of a string, not necessarily null-terminated.