13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
30 struct ArrayTypeStorage;
31 struct CooperativeMatrixTypeStorage;
32 struct TensorArmTypeStorage;
33 struct ImageTypeStorage;
34 struct MatrixTypeStorage;
35 struct PointerTypeStorage;
36 struct RuntimeArrayTypeStorage;
37 struct SampledImageTypeStorage;
38 struct StructTypeStorage;
60 std::optional<StorageClass> storage = std::nullopt);
72 std::optional<StorageClass> storage = std::nullopt);
83 using SPIRVType::SPIRVType;
90 static bool isValid(IntegerType);
93 std::optional<StorageClass> storage = std::nullopt);
95 std::optional<StorageClass> storage = std::nullopt);
104 using SPIRVType::SPIRVType;
109 static bool isValid(VectorType);
122 std::optional<StorageClass> storage = std::nullopt);
124 std::optional<StorageClass> storage = std::nullopt);
131 detail::ArrayTypeStorage> {
135 static constexpr StringLiteral
name =
"spirv.array";
152 std::optional<StorageClass> storage = std::nullopt);
154 std::optional<StorageClass> storage = std::nullopt);
163 :
public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
167 static constexpr StringLiteral
name =
"spirv.image";
171 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
172 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
173 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
174 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
175 ImageFormat format = ImageFormat::Unknown) {
177 std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
178 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
179 elementType, dim, depth, arrayed, samplingInfo, samplerUse,
184 get(std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
185 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
197 std::optional<StorageClass> storage = std::nullopt);
199 std::optional<StorageClass> storage = std::nullopt);
204 detail::PointerTypeStorage> {
208 static constexpr StringLiteral
name =
"spirv.pointer";
217 std::optional<StorageClass> storage = std::nullopt);
219 std::optional<StorageClass> storage = std::nullopt);
225 detail::RuntimeArrayTypeStorage> {
229 static constexpr StringLiteral
name =
"spirv.rtarray";
243 std::optional<StorageClass> storage = std::nullopt);
245 std::optional<StorageClass> storage = std::nullopt);
251 detail::SampledImageTypeStorage> {
255 static constexpr StringLiteral
name =
"spirv.sampled_image";
269 std::optional<spirv::StorageClass> storage = std::nullopt);
272 std::optional<spirv::StorageClass> storage = std::nullopt);
295 detail::StructTypeStorage, TypeTrait::IsMutable> {
302 static constexpr StringLiteral
name =
"spirv.struct";
360 ArrayRef<StructDecorationInfo> structDecorations = {});
369 static StructType
getIdentified(MLIRContext *context, StringRef identifier);
377 static StructType
getEmpty(MLIRContext *context, StringRef identifier =
"");
402 &memberDecorations)
const;
408 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo)
const;
413 &structDecorations)
const;
419 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
420 ArrayRef<MemberDecorationInfo> memberDecorations = {},
421 ArrayRef<StructDecorationInfo> structDecorations = {});
424 std::optional<StorageClass> storage = std::nullopt);
426 std::optional<StorageClass> storage = std::nullopt);
430 hash_value(
const StructType::MemberDecorationInfo &memberDecorationInfo);
433 hash_value(
const StructType::StructDecorationInfo &structDecorationInfo);
438 detail::CooperativeMatrixTypeStorage,
443 static constexpr StringLiteral
name =
"spirv.coopmatrix";
446 uint32_t columns, Scope scope,
447 CooperativeMatrixUseKHR use);
457 CooperativeMatrixUseKHR
getUse()
const;
460 std::optional<StorageClass> storage = std::nullopt);
462 std::optional<StorageClass> storage = std::nullopt);
464 operator ShapedType()
const {
return llvm::cast<ShapedType>(*
this); }
471 Type elementType)
const {
475 assert(shape.value().size() == 2);
476 return get(elementType, shape.value()[0], shape.value()[1],
getScope(),
483 detail::MatrixTypeStorage> {
487 static constexpr StringLiteral
name =
"spirv.matrix";
492 Type columnType, uint32_t columnCount);
496 Type columnType, uint32_t columnCount);
516 std::optional<StorageClass> storage = std::nullopt);
518 std::optional<StorageClass> storage = std::nullopt);
524 detail::TensorArmTypeStorage, ShapedType::Trait> {
529 using ShapedTypeTraits::getDimSize;
530 using ShapedTypeTraits::getDynamicDimIndex;
531 using ShapedTypeTraits::getElementTypeBitWidth;
532 using ShapedTypeTraits::getNumDynamicDims;
534 using ShapedTypeTraits::getRank;
535 using ShapedTypeTraits::hasStaticShape;
536 using ShapedTypeTraits::isDynamicDim;
538 static constexpr StringLiteral
name =
"spirv.arm.tensor";
544 Type elementType)
const;
553 operator ShapedType()
const {
return llvm::cast<ShapedType>(*
this); }
556 std::optional<StorageClass> storage = std::nullopt);
558 std::optional<StorageClass> storage = std::nullopt);
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
Attributes are known-constant values of operations.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static constexpr StringLiteral name
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
static constexpr StringLiteral name
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
CooperativeMatrixType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
static constexpr StringLiteral name
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., the single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static constexpr StringLiteral name
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
static constexpr StringLiteral name
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
static constexpr StringLiteral name
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static constexpr StringLiteral name
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given floating-point type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
static constexpr StringLiteral name
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static constexpr StringLiteral name
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
ShapedType::Trait< TensorArmType > ShapedTypeTraits
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
@ Type
An inlay hint for a type annotation.
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
friend bool operator==(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
MemberDecorationInfo(uint32_t index, Decoration decoration, Attribute decorationValue)
Attribute decorationValue
friend bool operator<(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
StructDecorationInfo(Decoration decoration, Attribute decorationValue)
friend bool operator<(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
friend bool operator==(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
Attribute decorationValue