18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
40 return key ==
KeyTy(elementType, elementCount, stride);
44 : elementType(std::
get<0>(key)), elementCount(std::
get<1>(key)),
45 stride(std::
get<2>(key)) {}
53 assert(elementCount &&
"ArrayType needs at least one element");
60 assert(elementCount &&
"ArrayType needs at least one element");
71 std::optional<StorageClass> storage) {
72 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
77 std::optional<StorageClass> storage) {
79 .getCapabilities(capabilities, storage);
84 std::optional<int64_t> size = elementType.getSizeInBytes();
95 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
103 return type.getRank() == 1 &&
104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105 llvm::isa<ScalarType>(type.getElementType());
111 [](
auto type) {
return type.getElementType(); })
116 [](
Type) ->
Type { llvm_unreachable(
"invalid composite type"); });
120 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
121 return arrayType.getNumElements();
122 if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this))
123 return matrixType.getNumColumns();
124 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
125 return structType.getNumElements();
126 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this))
127 return vectorType.getNumElements();
128 if (llvm::isa<CooperativeMatrixType>(*
this)) {
130 "invalid to query number of elements of spirv Cooperative Matrix type");
132 if (llvm::isa<RuntimeArrayType>(*
this)) {
134 "invalid to query number of elements of spirv::RuntimeArray type");
136 llvm_unreachable(
"invalid composite type");
140 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*
this);
145 std::optional<StorageClass> storage) {
149 [&](
auto type) { type.getExtensions(extensions, storage); })
150 .Case<VectorType>([&](VectorType type) {
151 return llvm::cast<ScalarType>(type.getElementType())
152 .getExtensions(extensions, storage);
154 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
159 std::optional<StorageClass> storage) {
163 [&](
auto type) { type.getCapabilities(capabilities, storage); })
164 .Case<VectorType>([&](VectorType type) {
166 if (vecSize == 8 || vecSize == 16) {
167 static const Capability caps[] = {Capability::Vector16};
169 capabilities.push_back(ref);
171 return llvm::cast<ScalarType>(type.getElementType())
172 .getCapabilities(capabilities, storage);
174 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
178 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
179 return arrayType.getSizeInBytes();
180 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
181 return structType.getSizeInBytes();
182 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this)) {
183 std::optional<int64_t> elementSize =
184 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
187 return *elementSize * vectorType.getNumElements();
211 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
220 return key ==
KeyTy(elementType, shape[0], shape[1], scope, use);
224 : elementType(std::
get<0>(key)),
225 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
226 use(std::get<4>(key)) {}
232 CooperativeMatrixUseKHR
use;
237 uint32_t columns, Scope scope,
238 CooperativeMatrixUseKHR use) {
248 assert(
getImpl()->shape[0] != ShapedType::kDynamic);
249 return static_cast<uint32_t
>(
getImpl()->shape[0]);
253 assert(
getImpl()->shape[1] != ShapedType::kDynamic);
254 return static_cast<uint32_t
>(
getImpl()->shape[1]);
269 std::optional<StorageClass> storage) {
270 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
271 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
272 extensions.push_back(exts);
277 std::optional<StorageClass> storage) {
279 .getCapabilities(capabilities, storage);
280 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
281 capabilities.push_back(caps);
288 template <
typename T>
294 static_assert((1 << 3) > getMaxEnumValForDim(),
295 "Not enough bits to encode Dim value");
300 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
301 "Not enough bits to encode ImageDepthInfo value");
306 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
307 "Not enough bits to encode ImageArrayedInfo value");
312 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
313 "Not enough bits to encode ImageSamplingInfo value");
318 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
319 "Not enough bits to encode ImageSamplerUseInfo value");
324 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
325 "Not enough bits to encode ImageFormat value");
331 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
332 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
340 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
341 samplerUseInfo, format);
345 : elementType(std::
get<0>(key)), dim(std::
get<1>(key)),
346 depthInfo(std::
get<2>(key)), arrayedInfo(std::
get<3>(key)),
347 samplingInfo(std::
get<4>(key)), samplerUseInfo(std::
get<5>(key)),
348 format(std::
get<6>(key)) {}
361 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
377 return getImpl()->samplingInfo;
381 return getImpl()->samplerUseInfo;
387 std::optional<StorageClass>) {
393 std::optional<StorageClass>) {
394 if (
auto dimCaps = spirv::getCapabilities(
getDim()))
395 capabilities.push_back(*dimCaps);
398 capabilities.push_back(*fmtCaps);
408 using KeyTy = std::pair<Type, StorageClass>;
417 return key ==
KeyTy(pointeeType, storageClass);
421 : pointeeType(key.first), storageClass(key.second) {}
434 return getImpl()->storageClass;
438 std::optional<StorageClass> storage) {
445 extensions.push_back(*scExts);
450 std::optional<StorageClass> storage) {
457 capabilities.push_back(*scCaps);
465 using KeyTy = std::pair<Type, unsigned>;
474 return key ==
KeyTy(elementType, stride);
478 : elementType(key.first), stride(key.second) {}
498 std::optional<StorageClass> storage) {
499 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
504 std::optional<StorageClass> storage) {
506 static const Capability caps[] = {Capability::Shader};
508 capabilities.push_back(ref);
511 .getCapabilities(capabilities, storage);
519 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
522 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
529 return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
533 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
537 std::optional<StorageClass> storage) {
545 case StorageClass::PushConstant:
546 case StorageClass::StorageBuffer:
547 case StorageClass::Uniform:
549 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
551 extensions.push_back(ref);
554 case StorageClass::Input:
555 case StorageClass::Output:
557 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
559 extensions.push_back(ref);
569 std::optional<StorageClass> storage) {
576 #define STORAGE_CASE(storage, cap8, cap16) \
577 case StorageClass::storage: { \
578 if (bitwidth == 8) { \
579 static const Capability caps[] = {Capability::cap8}; \
580 ArrayRef<Capability> ref(caps, std::size(caps)); \
581 capabilities.push_back(ref); \
584 if (bitwidth == 16) { \
585 static const Capability caps[] = {Capability::cap16}; \
586 ArrayRef<Capability> ref(caps, std::size(caps)); \
587 capabilities.push_back(ref); \
598 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
600 StorageBuffer16BitAccess);
603 case StorageClass::Input:
604 case StorageClass::Output: {
605 if (bitwidth == 16) {
606 static const Capability caps[] = {Capability::StorageInputOutput16};
608 capabilities.push_back(ref);
622 #define WIDTH_CASE(type, width) \
624 static const Capability caps[] = {Capability::type##width}; \
625 ArrayRef<Capability> ref(caps, std::size(caps)); \
626 capabilities.push_back(ref); \
629 if (
auto intType = llvm::dyn_cast<IntegerType>(*
this)) {
638 llvm_unreachable(
"invalid bitwidth to getCapabilities");
641 assert(llvm::isa<FloatType>(*
this));
648 llvm_unreachable(
"invalid bitwidth to getCapabilities");
674 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
676 if (llvm::isa<ScalarType>(type))
678 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
688 std::optional<StorageClass> storage) {
689 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
690 scalarType.getExtensions(extensions, storage);
691 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
692 compositeType.getExtensions(extensions, storage);
693 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
694 imageType.getExtensions(extensions, storage);
695 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
696 sampledImageType.getExtensions(extensions, storage);
697 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
698 matrixType.getExtensions(extensions, storage);
699 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
700 ptrType.getExtensions(extensions, storage);
702 llvm_unreachable(
"invalid SPIR-V Type to getExtensions");
708 std::optional<StorageClass> storage) {
709 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
710 scalarType.getCapabilities(capabilities, storage);
711 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
712 compositeType.getCapabilities(capabilities, storage);
713 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
714 imageType.getCapabilities(capabilities, storage);
715 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
716 sampledImageType.getCapabilities(capabilities, storage);
717 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
718 matrixType.getCapabilities(capabilities, storage);
719 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
720 ptrType.getCapabilities(capabilities, storage);
722 llvm_unreachable(
"invalid SPIR-V Type to getCapabilities");
727 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this))
728 return scalarType.getSizeInBytes();
729 if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this))
730 return compositeType.getSizeInBytes();
768 if (!llvm::isa<ImageType>(imageType))
769 return emitError() <<
"expected image type";
776 std::optional<StorageClass> storage) {
777 llvm::cast<ImageType>(
getImageType()).getExtensions(extensions, storage);
782 std::optional<StorageClass> storage) {
783 llvm::cast<ImageType>(
getImageType()).getCapabilities(capabilities, storage);
809 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
810 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
811 identifier(identifier) {}
816 unsigned numMembers,
Type const *memberTypes,
819 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
820 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
821 memberDecorationsInfo(memberDecorationsInfo) {}
847 if (isIdentified()) {
849 return getIdentifier() == std::get<0>(key);
852 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
853 getMemberDecorationsInfo());
864 StringRef keyIdentifier = std::get<0>(key);
866 if (!keyIdentifier.empty()) {
867 StringRef identifier = allocator.
copyInto(keyIdentifier);
878 const Type *typesList =
nullptr;
879 if (!keyTypes.empty()) {
880 typesList = allocator.
copyInto(keyTypes).data();
884 if (!std::get<2>(key).empty()) {
886 assert(keyOffsetInfo.size() == keyTypes.size() &&
887 "size of offset information must be same as the size of number of "
889 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
893 unsigned numMemberDecorations = 0;
894 if (!std::get<3>(key).empty()) {
895 auto keyMemberDecorations = std::get<3>(key);
896 numMemberDecorations = keyMemberDecorations.size();
897 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
902 numMemberDecorations, memberDecorationList);
906 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
917 if (memberDecorationsInfo) {
919 numMemberDecorations);
943 if (memberTypesAndIsBodySet.getInt() &&
944 (getMemberTypes() != structMemberTypes ||
945 getOffsetInfo() != structOffsetInfo ||
946 getMemberDecorationsInfo() != structMemberDecorationInfo))
949 memberTypesAndIsBodySet.setInt(
true);
950 numMembers = structMemberTypes.size();
953 if (!structMemberTypes.empty())
954 memberTypesAndIsBodySet.setPointer(
955 allocator.
copyInto(structMemberTypes).data());
957 if (!structOffsetInfo.empty()) {
958 assert(structOffsetInfo.size() == structMemberTypes.size() &&
959 "size of offset information must be same as the size of number of "
961 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
964 if (!structMemberDecorationInfo.empty()) {
965 numMemberDecorations = structMemberDecorationInfo.size();
966 memberDecorationsInfo =
967 allocator.
copyInto(structMemberDecorationInfo).data();
985 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
989 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
990 return Base::get(memberTypes.vec().front().getContext(),
991 StringRef(), memberTypes, offsetInfo,
996 StringRef identifier) {
997 assert(!identifier.empty() &&
998 "StructType identifier must be non-empty string");
1016 return newStructType;
1027 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1039 return getImpl()->offsetInfo[index];
1045 memberDecorations.clear();
1046 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1047 memberDecorations.append(implMemberDecorations.begin(),
1048 implMemberDecorations.end());
1055 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1056 decorationsInfo.clear();
1057 for (
const auto &memberDecoration : memberDecorations) {
1058 if (memberDecoration.memberIndex == index) {
1059 decorationsInfo.push_back(memberDecoration);
1061 if (memberDecoration.memberIndex > index) {
1072 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1076 std::optional<StorageClass> storage) {
1078 llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1083 std::optional<StorageClass> storage) {
1085 llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1090 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1100 : columnType(columnType), columnCount(columnCount) {}
1102 using KeyTy = std::tuple<Type, uint32_t>;
1113 return key ==
KeyTy(columnType, columnCount);
1125 Type columnType, uint32_t columnCount) {
1132 Type columnType, uint32_t columnCount) {
1133 if (columnCount < 2 || columnCount > 4)
1134 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1137 return emitError() <<
"matrix columns must be vectors of floats";
1141 if (columnShape.size() != 1)
1142 return emitError() <<
"matrix columns must be 1D vectors";
1144 if (columnShape[0] < 2 || columnShape[0] > 4)
1145 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1152 if (
auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1153 if (llvm::isa<FloatType>(vectorType.getElementType()))
1162 return llvm::cast<VectorType>(
getImpl()->columnType).getElementType();
1168 return llvm::cast<VectorType>(
getImpl()->columnType).getShape()[0];
1176 std::optional<StorageClass> storage) {
1177 llvm::cast<SPIRVType>(
getColumnType()).getExtensions(extensions, storage);
1182 std::optional<StorageClass> storage) {
1184 static const Capability caps[] = {Capability::Matrix};
1186 capabilities.push_back(ref);
1189 llvm::cast<SPIRVType>(
getColumnType()).getCapabilities(capabilities, storage);
1196 void SPIRVDialect::registerTypes() {
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::array< int64_t, 2 > shape
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const