13#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
61 std::optional<StorageClass> storage = std::nullopt);
73 std::optional<StorageClass> storage = std::nullopt);
84 using SPIRVType::SPIRVType;
91 static bool isValid(IntegerType);
98 using SPIRVType::SPIRVType;
103 static bool isValid(VectorType);
118 detail::ArrayTypeStorage> {
122 static constexpr StringLiteral
name =
"spirv.array";
141 :
public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
145 static constexpr StringLiteral
name =
"spirv.image";
149 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
150 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
151 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
152 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
153 ImageFormat format = ImageFormat::Unknown) {
155 std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
156 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
157 elementType, dim, depth, arrayed, samplingInfo, samplerUse,
162 get(std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
163 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
177 :
public Type::TypeBase<PointerType, SPIRVType, detail::PointerTypeStorage,
178 VectorElementTypeInterface::Trait> {
182 static constexpr StringLiteral
name =
"spirv.pointer";
194 detail::RuntimeArrayTypeStorage> {
198 static constexpr StringLiteral
name =
"spirv.rtarray";
215 detail::SampledImageTypeStorage> {
219 static constexpr StringLiteral
name =
"spirv.sampled_image";
238 static constexpr StringLiteral
name =
"spirv.sampler";
245 :
public Type::TypeBase<NamedBarrierType, SPIRVType, TypeStorage> {
249 static constexpr StringLiteral
name =
"spirv.named_barrier";
274 detail::StructTypeStorage, TypeTrait::IsMutable> {
281 static constexpr StringLiteral
name =
"spirv.struct";
298 return lhs.memberIndex ==
rhs.memberIndex &&
299 lhs.decoration ==
rhs.decoration &&
300 lhs.decorationValue ==
rhs.decorationValue;
305 return std::tuple(
lhs.memberIndex, llvm::to_underlying(
lhs.decoration)) <
306 std::tuple(
rhs.memberIndex, llvm::to_underlying(
rhs.decoration));
322 return lhs.decoration ==
rhs.decoration &&
323 lhs.decorationValue ==
rhs.decorationValue;
328 return llvm::to_underlying(
lhs.decoration) <
329 llvm::to_underlying(
rhs.decoration);
339 ArrayRef<StructDecorationInfo> structDecorations = {});
348 static StructType
getIdentified(MLIRContext *context, StringRef identifier);
356 static StructType
getEmpty(MLIRContext *context, StringRef identifier =
"");
381 &memberDecorations)
const;
387 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo)
const;
392 &structDecorations)
const;
398 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
399 ArrayRef<MemberDecorationInfo> memberDecorations = {},
400 ArrayRef<StructDecorationInfo> structDecorations = {});
412 detail::CooperativeMatrixTypeStorage,
417 static constexpr StringLiteral
name =
"spirv.coopmatrix";
420 uint32_t columns, Scope scope,
421 CooperativeMatrixUseKHR use);
431 CooperativeMatrixUseKHR
getUse()
const;
433 operator ShapedType()
const {
return cast<ShapedType>(*
this); }
440 Type elementType)
const {
444 assert(
shape.value().size() == 2);
453 detail::MatrixTypeStorage, ShapedType::Trait> {
457 static constexpr StringLiteral
name =
"spirv.matrix";
462 Type columnType, uint32_t columnCount);
466 Type columnType, uint32_t columnCount);
485 operator ShapedType()
const {
return cast<ShapedType>(*
this); }
492 Type elementType)
const {
496 assert(
shape.value().size() == 2);
498 auto vectorType = cast<VectorType>(elementType);
499 Type newElementType =
500 vectorType.cloneWith({
shape.value()[0]}, vectorType.getElementType());
502 return get(newElementType,
shape.value()[1]);
509 detail::TensorArmTypeStorage, ShapedType::Trait> {
514 using ShapedTypeTraits::getDimSize;
515 using ShapedTypeTraits::getDynamicDimIndex;
516 using ShapedTypeTraits::getElementTypeBitWidth;
517 using ShapedTypeTraits::getNumDynamicDims;
518 using ShapedTypeTraits::getNumElements;
519 using ShapedTypeTraits::getRank;
520 using ShapedTypeTraits::hasStaticShape;
521 using ShapedTypeTraits::isDynamicDim;
523 static constexpr StringLiteral
name =
"spirv.arm.tensor";
529 Type elementType)
const;
538 operator ShapedType()
const {
return cast<ShapedType>(*
this); }
Attributes are known-constant values of operations.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
detail::StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > TypeBase
Utility class for implementing types.
StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > Base
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
static constexpr StringLiteral name
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
CooperativeMatrixType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
static constexpr StringLiteral name
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static constexpr StringLiteral name
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
MatrixType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
ArrayRef< int64_t > getShape() const
static constexpr StringLiteral name
unsigned getNumRows() const
Returns the number of rows.
static NamedBarrierType get(MLIRContext *context)
static constexpr StringLiteral name
Type getPointeeType() const
static constexpr StringLiteral name
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
static constexpr StringLiteral name
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static constexpr StringLiteral name
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static constexpr StringLiteral name
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
static constexpr StringLiteral name
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static constexpr StringLiteral name
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
ShapedType::Trait< TensorArmType > ShapedTypeTraits
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
friend bool operator==(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
MemberDecorationInfo(uint32_t index, Decoration decoration, Attribute decorationValue)
Attribute decorationValue
friend bool operator<(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
StructDecorationInfo(Decoration decoration, Attribute decorationValue)
friend bool operator<(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
friend bool operator==(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
Attribute decorationValue
Type storage for SPIR-V structure types: