13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
30 struct ArrayTypeStorage;
31 struct CooperativeMatrixTypeStorage;
32 struct ImageTypeStorage;
33 struct MatrixTypeStorage;
34 struct PointerTypeStorage;
35 struct RuntimeArrayTypeStorage;
36 struct SampledImageTypeStorage;
37 struct StructTypeStorage;
59 std::optional<StorageClass> storage = std::nullopt);
71 std::optional<StorageClass> storage = std::nullopt);
82 using SPIRVType::SPIRVType;
89 static bool isValid(IntegerType);
92 std::optional<StorageClass> storage = std::nullopt);
94 std::optional<StorageClass> storage = std::nullopt);
102 using SPIRVType::SPIRVType;
107 static bool isValid(VectorType);
120 std::optional<StorageClass> storage = std::nullopt);
122 std::optional<StorageClass> storage = std::nullopt);
129 detail::ArrayTypeStorage> {
133 static constexpr StringLiteral
name =
"spirv.array";
150 std::optional<StorageClass> storage = std::nullopt);
152 std::optional<StorageClass> storage = std::nullopt);
161 :
public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
165 static constexpr StringLiteral
name =
"spirv.image";
169 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
170 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
171 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
172 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
173 ImageFormat format = ImageFormat::Unknown) {
175 std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
176 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
177 elementType, dim, depth, arrayed, samplingInfo, samplerUse,
182 get(std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
183 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
195 std::optional<StorageClass> storage = std::nullopt);
197 std::optional<StorageClass> storage = std::nullopt);
202 detail::PointerTypeStorage> {
206 static constexpr StringLiteral
name =
"spirv.pointer";
215 std::optional<StorageClass> storage = std::nullopt);
217 std::optional<StorageClass> storage = std::nullopt);
223 detail::RuntimeArrayTypeStorage> {
227 static constexpr StringLiteral
name =
"spirv.rtarray";
241 std::optional<StorageClass> storage = std::nullopt);
243 std::optional<StorageClass> storage = std::nullopt);
249 detail::SampledImageTypeStorage> {
253 static constexpr StringLiteral
name =
"spirv.sampled_image";
267 std::optional<spirv::StorageClass> storage = std::nullopt);
270 std::optional<spirv::StorageClass> storage = std::nullopt);
293 detail::StructTypeStorage, TypeTrait::IsMutable> {
300 static constexpr StringLiteral
name =
"spirv.struct";
340 static StructType
getIdentified(MLIRContext *context, StringRef identifier);
348 static StructType
getEmpty(MLIRContext *context, StringRef identifier =
"");
370 &memberDecorations)
const;
376 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo)
const;
382 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
383 ArrayRef<MemberDecorationInfo> memberDecorations = {});
386 std::optional<StorageClass> storage = std::nullopt);
388 std::optional<StorageClass> storage = std::nullopt);
392 hash_value(
const StructType::MemberDecorationInfo &memberDecorationInfo);
397 detail::CooperativeMatrixTypeStorage> {
401 static constexpr StringLiteral
name =
"spirv.coopmatrix";
404 uint32_t columns, Scope scope,
405 CooperativeMatrixUseKHR use);
415 CooperativeMatrixUseKHR
getUse()
const;
418 std::optional<StorageClass> storage = std::nullopt);
420 std::optional<StorageClass> storage = std::nullopt);
425 detail::MatrixTypeStorage> {
429 static constexpr StringLiteral
name =
"spirv.matrix";
434 Type columnType, uint32_t columnCount);
438 Type columnType, uint32_t columnCount);
458 std::optional<StorageClass> storage = std::nullopt);
460 std::optional<StorageClass> storage = std::nullopt);
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static constexpr StringLiteral name
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
static constexpr StringLiteral name
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static constexpr StringLiteral name
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static constexpr StringLiteral name
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
static constexpr StringLiteral name
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
static constexpr StringLiteral name
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static constexpr StringLiteral name
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static constexpr StringLiteral name
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool operator==(const MemberDecorationInfo &other) const
MemberDecorationInfo(uint32_t index, uint32_t hasValue, Decoration decoration, uint32_t decorationValue)
bool operator<(const MemberDecorationInfo &other) const