18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
31 using KeyTy = std::tuple<Type, unsigned, unsigned>;
39 return key ==
KeyTy(elementType, elementCount, stride);
43 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
44 stride(std::get<2>(key)) {}
52 assert(elementCount &&
"ArrayType needs at least one element");
53 return Base::get(elementType.
getContext(), elementType, elementCount,
59 assert(elementCount &&
"ArrayType needs at least one element");
60 return Base::get(elementType.
getContext(), elementType, elementCount, stride);
70 std::optional<StorageClass> storage) {
76 std::optional<StorageClass> storage) {
93 if (
auto vectorType = type.
dyn_cast<VectorType>())
101 switch (type.getNumElements()) {
111 return type.getRank() == 1 && type.getElementType().isa<
ScalarType>();
118 [](
auto type) {
return type.getElementType(); })
123 [](
Type) ->
Type { llvm_unreachable(
"invalid composite type"); });
127 if (
auto arrayType = dyn_cast<ArrayType>())
128 return arrayType.getNumElements();
129 if (
auto matrixType = dyn_cast<MatrixType>())
130 return matrixType.getNumColumns();
131 if (
auto structType = dyn_cast<StructType>())
132 return structType.getNumElements();
133 if (
auto vectorType = dyn_cast<VectorType>())
134 return vectorType.getNumElements();
135 if (isa<CooperativeMatrixNVType>()) {
137 "invalid to query number of elements of spirv::CooperativeMatrix type");
139 if (isa<JointMatrixINTELType>()) {
141 "invalid to query number of elements of spirv::JointMatrix type");
143 if (isa<RuntimeArrayType>()) {
145 "invalid to query number of elements of spirv::RuntimeArray type");
147 llvm_unreachable(
"invalid composite type");
157 std::optional<StorageClass> storage) {
161 [&](
auto type) { type.getExtensions(extensions, storage); })
162 .Case<VectorType>([&](VectorType type) {
164 extensions, storage);
166 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
171 std::optional<StorageClass> storage) {
175 [&](
auto type) { type.getCapabilities(capabilities, storage); })
176 .Case<VectorType>([&](VectorType type) {
178 if (vecSize == 8 || vecSize == 16) {
179 static const Capability caps[] = {Capability::Vector16};
181 capabilities.push_back(ref);
184 capabilities, storage);
186 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
190 if (
auto arrayType = dyn_cast<ArrayType>())
191 return arrayType.getSizeInBytes();
192 if (
auto structType = dyn_cast<StructType>())
193 return structType.getSizeInBytes();
194 if (
auto vectorType = dyn_cast<VectorType>()) {
195 std::optional<int64_t> elementSize =
199 return *elementSize * vectorType.getNumElements();
209 using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
218 return key ==
KeyTy(elementType, scope, rows, columns);
222 : elementType(std::get<0>(key)), rows(std::get<2>(key)),
223 columns(std::get<3>(key)), scope(std::get<1>(key)) {}
232 Scope scope,
unsigned rows,
234 return Base::get(elementType.
getContext(), elementType, scope, rows, columns);
251 std::optional<StorageClass> storage) {
253 static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
255 extensions.push_back(ref);
260 std::optional<StorageClass> storage) {
262 static const Capability caps[] = {Capability::CooperativeMatrixNV};
264 capabilities.push_back(ref);
272 using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
281 return key ==
KeyTy(elementType, rows, columns, matrixLayout, scope);
285 : elementType(std::get<0>(key)), rows(std::get<1>(key)),
286 columns(std::get<2>(key)), scope(std::get<4>(key)),
287 matrixLayout(std::get<3>(key)) {}
297 unsigned rows,
unsigned columns,
298 MatrixLayout matrixLayout) {
299 return Base::get(elementType.
getContext(), elementType, rows, columns,
300 matrixLayout, scope);
314 return getImpl()->matrixLayout;
319 std::optional<StorageClass> storage) {
321 static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
323 extensions.push_back(ref);
328 std::optional<StorageClass> storage) {
330 static const Capability caps[] = {Capability::JointMatrixINTEL};
332 capabilities.push_back(ref);
339 template <
typename T>
345 static_assert((1 << 3) > getMaxEnumValForDim(),
346 "Not enough bits to encode Dim value");
351 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
352 "Not enough bits to encode ImageDepthInfo value");
357 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
358 "Not enough bits to encode ImageArrayedInfo value");
363 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
364 "Not enough bits to encode ImageSamplingInfo value");
369 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
370 "Not enough bits to encode ImageSamplerUseInfo value");
375 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
376 "Not enough bits to encode ImageFormat value");
382 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
383 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
391 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
392 samplerUseInfo, format);
396 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
397 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
398 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
399 format(std::get<6>(key)) {}
412 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
414 return Base::get(std::get<0>(value).getContext(), value);
428 return getImpl()->samplingInfo;
432 return getImpl()->samplerUseInfo;
438 std::optional<StorageClass>) {
444 std::optional<StorageClass>) {
445 if (
auto dimCaps = spirv::getCapabilities(
getDim()))
446 capabilities.push_back(*dimCaps);
449 capabilities.push_back(*fmtCaps);
459 using KeyTy = std::pair<Type, StorageClass>;
468 return key ==
KeyTy(pointeeType, storageClass);
472 : pointeeType(key.first), storageClass(key.second) {}
479 return Base::get(pointeeType.
getContext(), pointeeType, storageClass);
485 return getImpl()->storageClass;
489 std::optional<StorageClass> storage) {
496 extensions.push_back(*scExts);
501 std::optional<StorageClass> storage) {
508 capabilities.push_back(*scCaps);
516 using KeyTy = std::pair<Type, unsigned>;
525 return key ==
KeyTy(elementType, stride);
529 : elementType(key.first), stride(key.second) {}
536 return Base::get(elementType.
getContext(), elementType, 0);
540 return Base::get(elementType.
getContext(), elementType, stride);
549 std::optional<StorageClass> storage) {
555 std::optional<StorageClass> storage) {
557 static const Capability caps[] = {Capability::Shader};
559 capabilities.push_back(ref);
572 if (
auto intType = type.
dyn_cast<IntegerType>()) {
579 return llvm::is_contained({16u, 32u, 64u}, type.
getWidth()) && !type.
isBF16();
583 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
587 std::optional<StorageClass> storage) {
595 case StorageClass::PushConstant:
596 case StorageClass::StorageBuffer:
597 case StorageClass::Uniform:
599 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
601 extensions.push_back(ref);
604 case StorageClass::Input:
605 case StorageClass::Output:
607 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
609 extensions.push_back(ref);
619 std::optional<StorageClass> storage) {
626 #define STORAGE_CASE(storage, cap8, cap16) \
627 case StorageClass::storage: { \
628 if (bitwidth == 8) { \
629 static const Capability caps[] = {Capability::cap8}; \
630 ArrayRef<Capability> ref(caps, std::size(caps)); \
631 capabilities.push_back(ref); \
634 if (bitwidth == 16) { \
635 static const Capability caps[] = {Capability::cap16}; \
636 ArrayRef<Capability> ref(caps, std::size(caps)); \
637 capabilities.push_back(ref); \
648 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
650 StorageBuffer16BitAccess);
651 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
653 case StorageClass::Input:
654 case StorageClass::Output: {
655 if (bitwidth == 16) {
656 static const Capability caps[] = {Capability::StorageInputOutput16};
658 capabilities.push_back(ref);
672 #define WIDTH_CASE(type, width) \
674 static const Capability caps[] = {Capability::type##width}; \
675 ArrayRef<Capability> ref(caps, std::size(caps)); \
676 capabilities.push_back(ref); \
679 if (
auto intType = dyn_cast<IntegerType>()) {
688 llvm_unreachable(
"invalid bitwidth to getCapabilities");
691 assert(isa<FloatType>());
698 llvm_unreachable(
"invalid bitwidth to getCapabilities");
724 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
728 if (
auto vectorType = type.
dyn_cast<VectorType>())
738 std::optional<StorageClass> storage) {
739 if (
auto scalarType = dyn_cast<ScalarType>()) {
740 scalarType.getExtensions(extensions, storage);
741 }
else if (
auto compositeType = dyn_cast<CompositeType>()) {
742 compositeType.getExtensions(extensions, storage);
743 }
else if (
auto imageType = dyn_cast<ImageType>()) {
744 imageType.getExtensions(extensions, storage);
745 }
else if (
auto sampledImageType = dyn_cast<SampledImageType>()) {
746 sampledImageType.getExtensions(extensions, storage);
747 }
else if (
auto matrixType = dyn_cast<MatrixType>()) {
748 matrixType.getExtensions(extensions, storage);
749 }
else if (
auto ptrType = dyn_cast<PointerType>()) {
750 ptrType.getExtensions(extensions, storage);
752 llvm_unreachable(
"invalid SPIR-V Type to getExtensions");
758 std::optional<StorageClass> storage) {
759 if (
auto scalarType = dyn_cast<ScalarType>()) {
760 scalarType.getCapabilities(capabilities, storage);
761 }
else if (
auto compositeType = dyn_cast<CompositeType>()) {
762 compositeType.getCapabilities(capabilities, storage);
763 }
else if (
auto imageType = dyn_cast<ImageType>()) {
764 imageType.getCapabilities(capabilities, storage);
765 }
else if (
auto sampledImageType = dyn_cast<SampledImageType>()) {
766 sampledImageType.getCapabilities(capabilities, storage);
767 }
else if (
auto matrixType = dyn_cast<MatrixType>()) {
768 matrixType.getCapabilities(capabilities, storage);
769 }
else if (
auto ptrType = dyn_cast<PointerType>()) {
770 ptrType.getCapabilities(capabilities, storage);
772 llvm_unreachable(
"invalid SPIR-V Type to getCapabilities");
777 if (
auto scalarType = dyn_cast<ScalarType>())
778 return scalarType.getSizeInBytes();
779 if (
auto compositeType = dyn_cast<CompositeType>())
780 return compositeType.getSizeInBytes();
804 return Base::get(imageType.
getContext(), imageType);
819 return emitError() <<
"expected image type";
826 std::optional<StorageClass> storage) {
832 std::optional<StorageClass> storage) {
859 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
860 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
861 identifier(identifier) {}
866 unsigned numMembers,
Type const *memberTypes,
869 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
870 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
871 memberDecorationsInfo(memberDecorationsInfo) {}
897 if (isIdentified()) {
899 return getIdentifier() == std::get<0>(key);
902 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
903 getMemberDecorationsInfo());
914 StringRef keyIdentifier = std::get<0>(key);
916 if (!keyIdentifier.empty()) {
917 StringRef identifier = allocator.
copyInto(keyIdentifier);
928 const Type *typesList =
nullptr;
929 if (!keyTypes.empty()) {
930 typesList = allocator.
copyInto(keyTypes).data();
934 if (!std::get<2>(key).empty()) {
936 assert(keyOffsetInfo.size() == keyTypes.size() &&
937 "size of offset information must be same as the size of number of "
939 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
943 unsigned numMemberDecorations = 0;
944 if (!std::get<3>(key).empty()) {
945 auto keyMemberDecorations = std::get<3>(key);
946 numMemberDecorations = keyMemberDecorations.size();
947 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
952 numMemberDecorations, memberDecorationList);
956 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
967 if (memberDecorationsInfo) {
969 numMemberDecorations);
993 if (memberTypesAndIsBodySet.getInt() &&
994 (getMemberTypes() != structMemberTypes ||
995 getOffsetInfo() != structOffsetInfo ||
996 getMemberDecorationsInfo() != structMemberDecorationInfo))
999 memberTypesAndIsBodySet.setInt(
true);
1000 numMembers = structMemberTypes.size();
1003 if (!structMemberTypes.empty())
1004 memberTypesAndIsBodySet.setPointer(
1005 allocator.
copyInto(structMemberTypes).data());
1007 if (!structOffsetInfo.empty()) {
1008 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1009 "size of offset information must be same as the size of number of "
1011 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
1014 if (!structMemberDecorationInfo.empty()) {
1015 numMemberDecorations = structMemberDecorationInfo.size();
1016 memberDecorationsInfo =
1017 allocator.
copyInto(structMemberDecorationInfo).data();
1035 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
1038 memberDecorations.begin(), memberDecorations.end());
1039 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1040 return Base::get(memberTypes.vec().front().getContext(),
1041 StringRef(), memberTypes, offsetInfo,
1046 StringRef identifier) {
1047 assert(!identifier.empty() &&
1048 "StructType identifier must be non-empty string");
1066 return newStructType;
1077 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1089 return getImpl()->offsetInfo[index];
1095 memberDecorations.clear();
1096 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1097 memberDecorations.append(implMemberDecorations.begin(),
1098 implMemberDecorations.end());
1105 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1106 decorationsInfo.clear();
1107 for (
const auto &memberDecoration : memberDecorations) {
1108 if (memberDecoration.memberIndex == index) {
1109 decorationsInfo.push_back(memberDecoration);
1111 if (memberDecoration.memberIndex > index) {
1122 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1126 std::optional<StorageClass> storage) {
1133 std::optional<StorageClass> storage) {
1140 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1150 : columnType(columnType), columnCount(columnCount) {}
1152 using KeyTy = std::tuple<Type, uint32_t>;
1163 return key ==
KeyTy(columnType, columnCount);
1171 return Base::get(columnType.
getContext(), columnType, columnCount);
1175 Type columnType, uint32_t columnCount) {
1181 Type columnType, uint32_t columnCount) {
1182 if (columnCount < 2 || columnCount > 4)
1183 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1186 return emitError() <<
"matrix columns must be vectors of floats";
1190 if (columnShape.size() != 1)
1191 return emitError() <<
"matrix columns must be 1D vectors";
1193 if (columnShape[0] < 2 || columnShape[0] > 4)
1194 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1201 if (
auto vectorType = columnType.
dyn_cast<VectorType>()) {
1202 if (vectorType.getElementType().isa<
FloatType>())
1225 std::optional<StorageClass> storage) {
1231 std::optional<StorageClass> storage) {
1233 static const Capability caps[] = {Capability::Matrix};
1235 capabilities.push_back(ref);
1245 void SPIRVDialect::registerTypes() {
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
unsigned getRows() const
return the number of rows of the matrix.
unsigned getColumns() const
return the number of columns of the matrix.
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Scope getScope() const
Return the scope of the cooperative matrix.
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Scope getScope() const
Return the scope of the joint matrix.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getColumns() const
return the number of columns of the matrix.
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
unsigned getRows() const
return the number of rows of the matrix.
MatrixLayout getMatrixLayout() const
return the layout of the matrix
Type getElementType() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
Range class for element types.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
ElementTypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
CooperativeMatrixTypeStorage(const KeyTy &key)
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
JointMatrixTypeStorage(const KeyTy &key)
MatrixLayout matrixLayout
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const