13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
30 struct ArrayTypeStorage;
31 struct CooperativeMatrixTypeStorage;
32 struct TensorArmTypeStorage;
33 struct ImageTypeStorage;
34 struct MatrixTypeStorage;
35 struct PointerTypeStorage;
36 struct RuntimeArrayTypeStorage;
37 struct SampledImageTypeStorage;
38 struct StructTypeStorage;
60 std::optional<StorageClass> storage = std::nullopt);
72 std::optional<StorageClass> storage = std::nullopt);
83 using SPIRVType::SPIRVType;
90 static bool isValid(IntegerType);
97 using SPIRVType::SPIRVType;
102 static bool isValid(VectorType);
117 detail::ArrayTypeStorage> {
121 static constexpr StringLiteral
name =
"spirv.array";
140 :
public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
144 static constexpr StringLiteral
name =
"spirv.image";
148 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
149 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
150 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
151 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
152 ImageFormat format = ImageFormat::Unknown) {
154 std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
155 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
156 elementType, dim, depth, arrayed, samplingInfo, samplerUse,
161 get(std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
162 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
176 detail::PointerTypeStorage> {
180 static constexpr StringLiteral
name =
"spirv.pointer";
192 detail::RuntimeArrayTypeStorage> {
196 static constexpr StringLiteral
name =
"spirv.rtarray";
213 detail::SampledImageTypeStorage> {
217 static constexpr StringLiteral
name =
"spirv.sampled_image";
251 detail::StructTypeStorage, TypeTrait::IsMutable> {
258 static constexpr StringLiteral
name =
"spirv.struct";
316 ArrayRef<StructDecorationInfo> structDecorations = {});
325 static StructType
getIdentified(MLIRContext *context, StringRef identifier);
333 static StructType
getEmpty(MLIRContext *context, StringRef identifier =
"");
358 &memberDecorations)
const;
364 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo)
const;
369 &structDecorations)
const;
375 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
376 ArrayRef<MemberDecorationInfo> memberDecorations = {},
377 ArrayRef<StructDecorationInfo> structDecorations = {});
381 hash_value(
const StructType::MemberDecorationInfo &memberDecorationInfo);
384 hash_value(
const StructType::StructDecorationInfo &structDecorationInfo);
389 detail::CooperativeMatrixTypeStorage,
394 static constexpr StringLiteral
name =
"spirv.coopmatrix";
397 uint32_t columns, Scope scope,
398 CooperativeMatrixUseKHR use);
408 CooperativeMatrixUseKHR
getUse()
const;
410 operator ShapedType()
const {
return llvm::cast<ShapedType>(*
this); }
417 Type elementType)
const {
421 assert(shape.value().size() == 2);
422 return get(elementType, shape.value()[0], shape.value()[1],
getScope(),
429 detail::MatrixTypeStorage> {
433 static constexpr StringLiteral
name =
"spirv.matrix";
438 Type columnType, uint32_t columnCount);
442 Type columnType, uint32_t columnCount);
465 detail::TensorArmTypeStorage, ShapedType::Trait> {
470 using ShapedTypeTraits::getDimSize;
471 using ShapedTypeTraits::getDynamicDimIndex;
472 using ShapedTypeTraits::getElementTypeBitWidth;
473 using ShapedTypeTraits::getNumDynamicDims;
475 using ShapedTypeTraits::getRank;
476 using ShapedTypeTraits::hasStaticShape;
477 using ShapedTypeTraits::isDynamicDim;
479 static constexpr StringLiteral
name =
"spirv.arm.tensor";
485 Type elementType)
const;
494 operator ShapedType()
const {
return llvm::cast<ShapedType>(*
this); }
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
Attributes are known-constant values of operations.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
static constexpr StringLiteral name
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
static constexpr StringLiteral name
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
CooperativeMatrixType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
static constexpr StringLiteral name
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
static constexpr StringLiteral name
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
static constexpr StringLiteral name
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
static constexpr StringLiteral name
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static constexpr StringLiteral name
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
static SampledImageType get(Type imageType)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
static constexpr StringLiteral name
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static constexpr StringLiteral name
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
ShapedType::Trait< TensorArmType > ShapedTypeTraits
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
friend bool operator==(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
MemberDecorationInfo(uint32_t index, Decoration decoration, Attribute decorationValue)
Attribute decorationValue
friend bool operator<(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
StructDecorationInfo(Decoration decoration, Attribute decorationValue)
friend bool operator<(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
friend bool operator==(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
Attribute decorationValue