18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
40 return key ==
KeyTy(elementType, elementCount, stride);
44 : elementType(std::
get<0>(key)), elementCount(std::
get<1>(key)),
45 stride(std::
get<2>(key)) {}
53 assert(elementCount &&
"ArrayType needs at least one element");
60 assert(elementCount &&
"ArrayType needs at least one element");
71 std::optional<StorageClass> storage) {
72 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
77 std::optional<StorageClass> storage) {
79 .getCapabilities(capabilities, storage);
84 std::optional<int64_t> size = elementType.getSizeInBytes();
95 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
103 return type.getRank() == 1 &&
104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105 llvm::isa<ScalarType>(type.getElementType());
111 [](
auto type) {
return type.getElementType(); })
116 [](
Type) ->
Type { llvm_unreachable(
"invalid composite type"); });
120 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
121 return arrayType.getNumElements();
122 if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this))
123 return matrixType.getNumColumns();
124 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
125 return structType.getNumElements();
126 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this))
127 return vectorType.getNumElements();
128 if (llvm::isa<CooperativeMatrixType>(*
this)) {
130 "invalid to query number of elements of spirv Cooperative Matrix type");
132 if (llvm::isa<RuntimeArrayType>(*
this)) {
134 "invalid to query number of elements of spirv::RuntimeArray type");
136 llvm_unreachable(
"invalid composite type");
140 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*
this);
145 std::optional<StorageClass> storage) {
149 [&](
auto type) { type.getExtensions(extensions, storage); })
150 .Case<VectorType>([&](VectorType type) {
151 return llvm::cast<ScalarType>(type.getElementType())
152 .getExtensions(extensions, storage);
154 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
159 std::optional<StorageClass> storage) {
163 [&](
auto type) { type.getCapabilities(capabilities, storage); })
164 .Case<VectorType>([&](VectorType type) {
166 if (vecSize == 8 || vecSize == 16) {
167 static const Capability caps[] = {Capability::Vector16};
169 capabilities.push_back(ref);
171 return llvm::cast<ScalarType>(type.getElementType())
172 .getCapabilities(capabilities, storage);
174 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
178 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
179 return arrayType.getSizeInBytes();
180 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
181 return structType.getSizeInBytes();
182 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this)) {
183 std::optional<int64_t> elementSize =
184 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
187 return *elementSize * vectorType.getNumElements();
198 std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
207 return key ==
KeyTy(elementType,
rows, columns, scope, use);
211 : elementType(std::
get<0>(key)),
rows(std::
get<1>(key)),
212 columns(std::
get<2>(key)), scope(std::
get<3>(key)),
213 use(std::
get<4>(key)) {}
219 CooperativeMatrixUseKHR
use;
224 uint32_t columns, Scope scope,
225 CooperativeMatrixUseKHR use) {
248 std::optional<StorageClass> storage) {
249 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
250 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
251 extensions.push_back(exts);
256 std::optional<StorageClass> storage) {
258 .getCapabilities(capabilities, storage);
259 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
260 capabilities.push_back(caps);
267 template <
typename T>
273 static_assert((1 << 3) > getMaxEnumValForDim(),
274 "Not enough bits to encode Dim value");
279 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
280 "Not enough bits to encode ImageDepthInfo value");
285 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
286 "Not enough bits to encode ImageArrayedInfo value");
291 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
292 "Not enough bits to encode ImageSamplingInfo value");
297 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
298 "Not enough bits to encode ImageSamplerUseInfo value");
303 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
304 "Not enough bits to encode ImageFormat value");
310 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
311 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
319 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
320 samplerUseInfo, format);
324 : elementType(std::
get<0>(key)), dim(std::
get<1>(key)),
325 depthInfo(std::
get<2>(key)), arrayedInfo(std::
get<3>(key)),
326 samplingInfo(std::
get<4>(key)), samplerUseInfo(std::
get<5>(key)),
327 format(std::
get<6>(key)) {}
340 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
356 return getImpl()->samplingInfo;
360 return getImpl()->samplerUseInfo;
366 std::optional<StorageClass>) {
372 std::optional<StorageClass>) {
373 if (
auto dimCaps = spirv::getCapabilities(
getDim()))
374 capabilities.push_back(*dimCaps);
377 capabilities.push_back(*fmtCaps);
387 using KeyTy = std::pair<Type, StorageClass>;
396 return key ==
KeyTy(pointeeType, storageClass);
400 : pointeeType(key.first), storageClass(key.second) {}
413 return getImpl()->storageClass;
417 std::optional<StorageClass> storage) {
424 extensions.push_back(*scExts);
429 std::optional<StorageClass> storage) {
436 capabilities.push_back(*scCaps);
444 using KeyTy = std::pair<Type, unsigned>;
453 return key ==
KeyTy(elementType, stride);
457 : elementType(key.first), stride(key.second) {}
477 std::optional<StorageClass> storage) {
478 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
483 std::optional<StorageClass> storage) {
485 static const Capability caps[] = {Capability::Shader};
487 capabilities.push_back(ref);
490 .getCapabilities(capabilities, storage);
498 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
501 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
508 return llvm::is_contained({16u, 32u, 64u}, type.
getWidth()) && !type.
isBF16();
512 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
516 std::optional<StorageClass> storage) {
524 case StorageClass::PushConstant:
525 case StorageClass::StorageBuffer:
526 case StorageClass::Uniform:
528 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
530 extensions.push_back(ref);
533 case StorageClass::Input:
534 case StorageClass::Output:
536 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
538 extensions.push_back(ref);
548 std::optional<StorageClass> storage) {
555 #define STORAGE_CASE(storage, cap8, cap16) \
556 case StorageClass::storage: { \
557 if (bitwidth == 8) { \
558 static const Capability caps[] = {Capability::cap8}; \
559 ArrayRef<Capability> ref(caps, std::size(caps)); \
560 capabilities.push_back(ref); \
563 if (bitwidth == 16) { \
564 static const Capability caps[] = {Capability::cap16}; \
565 ArrayRef<Capability> ref(caps, std::size(caps)); \
566 capabilities.push_back(ref); \
577 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
579 StorageBuffer16BitAccess);
580 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
582 case StorageClass::Input:
583 case StorageClass::Output: {
584 if (bitwidth == 16) {
585 static const Capability caps[] = {Capability::StorageInputOutput16};
587 capabilities.push_back(ref);
601 #define WIDTH_CASE(type, width) \
603 static const Capability caps[] = {Capability::type##width}; \
604 ArrayRef<Capability> ref(caps, std::size(caps)); \
605 capabilities.push_back(ref); \
608 if (
auto intType = llvm::dyn_cast<IntegerType>(*
this)) {
617 llvm_unreachable(
"invalid bitwidth to getCapabilities");
620 assert(llvm::isa<FloatType>(*
this));
627 llvm_unreachable(
"invalid bitwidth to getCapabilities");
653 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
655 if (llvm::isa<ScalarType>(type))
657 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
667 std::optional<StorageClass> storage) {
668 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
669 scalarType.getExtensions(extensions, storage);
670 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
671 compositeType.getExtensions(extensions, storage);
672 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
673 imageType.getExtensions(extensions, storage);
674 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
675 sampledImageType.getExtensions(extensions, storage);
676 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
677 matrixType.getExtensions(extensions, storage);
678 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
679 ptrType.getExtensions(extensions, storage);
681 llvm_unreachable(
"invalid SPIR-V Type to getExtensions");
687 std::optional<StorageClass> storage) {
688 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
689 scalarType.getCapabilities(capabilities, storage);
690 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
691 compositeType.getCapabilities(capabilities, storage);
692 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
693 imageType.getCapabilities(capabilities, storage);
694 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
695 sampledImageType.getCapabilities(capabilities, storage);
696 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
697 matrixType.getCapabilities(capabilities, storage);
698 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
699 ptrType.getCapabilities(capabilities, storage);
701 llvm_unreachable(
"invalid SPIR-V Type to getCapabilities");
706 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this))
707 return scalarType.getSizeInBytes();
708 if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this))
709 return compositeType.getSizeInBytes();
747 if (!llvm::isa<ImageType>(imageType))
748 return emitError() <<
"expected image type";
755 std::optional<StorageClass> storage) {
756 llvm::cast<ImageType>(
getImageType()).getExtensions(extensions, storage);
761 std::optional<StorageClass> storage) {
762 llvm::cast<ImageType>(
getImageType()).getCapabilities(capabilities, storage);
788 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
789 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
790 identifier(identifier) {}
795 unsigned numMembers,
Type const *memberTypes,
798 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
799 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
800 memberDecorationsInfo(memberDecorationsInfo) {}
826 if (isIdentified()) {
828 return getIdentifier() == std::get<0>(key);
831 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
832 getMemberDecorationsInfo());
843 StringRef keyIdentifier = std::get<0>(key);
845 if (!keyIdentifier.empty()) {
846 StringRef identifier = allocator.
copyInto(keyIdentifier);
857 const Type *typesList =
nullptr;
858 if (!keyTypes.empty()) {
859 typesList = allocator.
copyInto(keyTypes).data();
863 if (!std::get<2>(key).empty()) {
865 assert(keyOffsetInfo.size() == keyTypes.size() &&
866 "size of offset information must be same as the size of number of "
868 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
872 unsigned numMemberDecorations = 0;
873 if (!std::get<3>(key).empty()) {
874 auto keyMemberDecorations = std::get<3>(key);
875 numMemberDecorations = keyMemberDecorations.size();
876 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
881 numMemberDecorations, memberDecorationList);
885 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
896 if (memberDecorationsInfo) {
898 numMemberDecorations);
922 if (memberTypesAndIsBodySet.getInt() &&
923 (getMemberTypes() != structMemberTypes ||
924 getOffsetInfo() != structOffsetInfo ||
925 getMemberDecorationsInfo() != structMemberDecorationInfo))
928 memberTypesAndIsBodySet.setInt(
true);
929 numMembers = structMemberTypes.size();
932 if (!structMemberTypes.empty())
933 memberTypesAndIsBodySet.setPointer(
934 allocator.
copyInto(structMemberTypes).data());
936 if (!structOffsetInfo.empty()) {
937 assert(structOffsetInfo.size() == structMemberTypes.size() &&
938 "size of offset information must be same as the size of number of "
940 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
943 if (!structMemberDecorationInfo.empty()) {
944 numMemberDecorations = structMemberDecorationInfo.size();
945 memberDecorationsInfo =
946 allocator.
copyInto(structMemberDecorationInfo).data();
964 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
968 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
969 return Base::get(memberTypes.vec().front().getContext(),
970 StringRef(), memberTypes, offsetInfo,
975 StringRef identifier) {
976 assert(!identifier.empty() &&
977 "StructType identifier must be non-empty string");
995 return newStructType;
1006 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1018 return getImpl()->offsetInfo[index];
1024 memberDecorations.clear();
1025 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1026 memberDecorations.append(implMemberDecorations.begin(),
1027 implMemberDecorations.end());
1034 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1035 decorationsInfo.clear();
1036 for (
const auto &memberDecoration : memberDecorations) {
1037 if (memberDecoration.memberIndex == index) {
1038 decorationsInfo.push_back(memberDecoration);
1040 if (memberDecoration.memberIndex > index) {
1051 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1055 std::optional<StorageClass> storage) {
1057 llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1062 std::optional<StorageClass> storage) {
1064 llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1069 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1079 : columnType(columnType), columnCount(columnCount) {}
1081 using KeyTy = std::tuple<Type, uint32_t>;
1092 return key ==
KeyTy(columnType, columnCount);
1104 Type columnType, uint32_t columnCount) {
1111 Type columnType, uint32_t columnCount) {
1112 if (columnCount < 2 || columnCount > 4)
1113 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1116 return emitError() <<
"matrix columns must be vectors of floats";
1120 if (columnShape.size() != 1)
1121 return emitError() <<
"matrix columns must be 1D vectors";
1123 if (columnShape[0] < 2 || columnShape[0] > 4)
1124 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1131 if (
auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1132 if (llvm::isa<FloatType>(vectorType.getElementType()))
1141 return llvm::cast<VectorType>(
getImpl()->columnType).getElementType();
1147 return llvm::cast<VectorType>(
getImpl()->columnType).getShape()[0];
1155 std::optional<StorageClass> storage) {
1156 llvm::cast<SPIRVType>(
getColumnType()).getExtensions(extensions, storage);
1161 std::optional<StorageClass> storage) {
1163 static const Capability caps[] = {Capability::Matrix};
1165 capabilities.push_back(ref);
1168 llvm::cast<SPIRVType>(
getColumnType()).getCapabilities(capabilities, storage);
1175 void SPIRVDialect::registerTypes() {
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::tuple< Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR > KeyTy
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const