13#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
61 std::optional<StorageClass> storage = std::nullopt);
73 std::optional<StorageClass> storage = std::nullopt);
84 using SPIRVType::SPIRVType;
91 static bool isValid(IntegerType);
98 using SPIRVType::SPIRVType;
103 static bool isValid(VectorType);
118 detail::ArrayTypeStorage> {
122 static constexpr StringLiteral
name =
"spirv.array";
141 :
public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
145 static constexpr StringLiteral
name =
"spirv.image";
149 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
150 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
151 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
152 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
153 ImageFormat format = ImageFormat::Unknown) {
155 std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
156 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
157 elementType, dim, depth, arrayed, samplingInfo, samplerUse,
162 get(std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
163 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
177 :
public Type::TypeBase<PointerType, SPIRVType, detail::PointerTypeStorage,
178 VectorElementTypeInterface::Trait> {
182 static constexpr StringLiteral
name =
"spirv.pointer";
194 detail::RuntimeArrayTypeStorage> {
198 static constexpr StringLiteral
name =
"spirv.rtarray";
215 detail::SampledImageTypeStorage> {
219 static constexpr StringLiteral
name =
"spirv.sampled_image";
238 static constexpr StringLiteral
name =
"spirv.sampler";
263 detail::StructTypeStorage, TypeTrait::IsMutable> {
270 static constexpr StringLiteral
name =
"spirv.struct";
287 return lhs.memberIndex ==
rhs.memberIndex &&
288 lhs.decoration ==
rhs.decoration &&
289 lhs.decorationValue ==
rhs.decorationValue;
294 return std::tuple(
lhs.memberIndex, llvm::to_underlying(
lhs.decoration)) <
295 std::tuple(
rhs.memberIndex, llvm::to_underlying(
rhs.decoration));
311 return lhs.decoration ==
rhs.decoration &&
312 lhs.decorationValue ==
rhs.decorationValue;
317 return llvm::to_underlying(
lhs.decoration) <
318 llvm::to_underlying(
rhs.decoration);
328 ArrayRef<StructDecorationInfo> structDecorations = {});
337 static StructType
getIdentified(MLIRContext *context, StringRef identifier);
345 static StructType
getEmpty(MLIRContext *context, StringRef identifier =
"");
370 &memberDecorations)
const;
376 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo)
const;
381 &structDecorations)
const;
387 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
388 ArrayRef<MemberDecorationInfo> memberDecorations = {},
389 ArrayRef<StructDecorationInfo> structDecorations = {});
401 detail::CooperativeMatrixTypeStorage,
406 static constexpr StringLiteral
name =
"spirv.coopmatrix";
409 uint32_t columns, Scope scope,
410 CooperativeMatrixUseKHR use);
420 CooperativeMatrixUseKHR
getUse()
const;
422 operator ShapedType()
const {
return cast<ShapedType>(*
this); }
429 Type elementType)
const {
433 assert(
shape.value().size() == 2);
442 detail::MatrixTypeStorage, ShapedType::Trait> {
446 static constexpr StringLiteral
name =
"spirv.matrix";
451 Type columnType, uint32_t columnCount);
455 Type columnType, uint32_t columnCount);
474 operator ShapedType()
const {
return cast<ShapedType>(*
this); }
481 Type elementType)
const {
485 assert(
shape.value().size() == 2);
487 auto vectorType = cast<VectorType>(elementType);
488 Type newElementType =
489 vectorType.cloneWith({
shape.value()[0]}, vectorType.getElementType());
491 return get(newElementType,
shape.value()[1]);
498 detail::TensorArmTypeStorage, ShapedType::Trait> {
503 using ShapedTypeTraits::getDimSize;
504 using ShapedTypeTraits::getDynamicDimIndex;
505 using ShapedTypeTraits::getElementTypeBitWidth;
506 using ShapedTypeTraits::getNumDynamicDims;
507 using ShapedTypeTraits::getNumElements;
508 using ShapedTypeTraits::getRank;
509 using ShapedTypeTraits::hasStaticShape;
510 using ShapedTypeTraits::isDynamicDim;
512 static constexpr StringLiteral
name =
"spirv.arm.tensor";
518 Type elementType)
const;
527 operator ShapedType()
const {
return cast<ShapedType>(*
this); }
Attributes are known-constant values of operations.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
detail::StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > TypeBase
Utility class for implementing types.
StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > Base
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
static constexpr StringLiteral name
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
CooperativeMatrixType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
static constexpr StringLiteral name
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static constexpr StringLiteral name
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
MatrixType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
ArrayRef< int64_t > getShape() const
static constexpr StringLiteral name
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
static constexpr StringLiteral name
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
static constexpr StringLiteral name
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static constexpr StringLiteral name
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static constexpr StringLiteral name
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
static constexpr StringLiteral name
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static constexpr StringLiteral name
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
ShapedType::Trait< TensorArmType > ShapedTypeTraits
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
friend bool operator==(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
MemberDecorationInfo(uint32_t index, Decoration decoration, Attribute decorationValue)
Attribute decorationValue
friend bool operator<(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
StructDecorationInfo(Decoration decoration, Attribute decorationValue)
friend bool operator<(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
friend bool operator==(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
Attribute decorationValue
Type storage for SPIR-V structure types: