18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
40 return key ==
KeyTy(elementType, elementCount, stride);
44 : elementType(std::
get<0>(key)), elementCount(std::
get<1>(key)),
45 stride(std::
get<2>(key)) {}
53 assert(elementCount &&
"ArrayType needs at least one element");
60 assert(elementCount &&
"ArrayType needs at least one element");
71 std::optional<StorageClass> storage) {
72 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
77 std::optional<StorageClass> storage) {
79 .getCapabilities(capabilities, storage);
84 std::optional<int64_t> size = elementType.getSizeInBytes();
95 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
103 return type.getRank() == 1 &&
104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105 llvm::isa<ScalarType>(type.getElementType());
112 [](
auto type) {
return type.getElementType(); })
117 [](
Type) ->
Type { llvm_unreachable(
"invalid composite type"); });
121 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
122 return arrayType.getNumElements();
123 if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this))
124 return matrixType.getNumColumns();
125 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
126 return structType.getNumElements();
127 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this))
128 return vectorType.getNumElements();
129 if (llvm::isa<CooperativeMatrixType>(*
this)) {
131 "invalid to query number of elements of spirv Cooperative Matrix type");
133 if (llvm::isa<JointMatrixINTELType>(*
this)) {
135 "invalid to query number of elements of spirv::JointMatrix type");
137 if (llvm::isa<RuntimeArrayType>(*
this)) {
139 "invalid to query number of elements of spirv::RuntimeArray type");
141 llvm_unreachable(
"invalid composite type");
151 std::optional<StorageClass> storage) {
155 [&](
auto type) { type.getExtensions(extensions, storage); })
156 .Case<VectorType>([&](VectorType type) {
157 return llvm::cast<ScalarType>(type.getElementType())
158 .getExtensions(extensions, storage);
160 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
165 std::optional<StorageClass> storage) {
169 [&](
auto type) { type.getCapabilities(capabilities, storage); })
170 .Case<VectorType>([&](VectorType type) {
172 if (vecSize == 8 || vecSize == 16) {
173 static const Capability caps[] = {Capability::Vector16};
175 capabilities.push_back(ref);
177 return llvm::cast<ScalarType>(type.getElementType())
178 .getCapabilities(capabilities, storage);
180 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
184 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
185 return arrayType.getSizeInBytes();
186 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
187 return structType.getSizeInBytes();
188 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this)) {
189 std::optional<int64_t> elementSize =
190 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
193 return *elementSize * vectorType.getNumElements();
204 std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
213 return key ==
KeyTy(elementType, rows, columns, scope, use);
217 : elementType(std::
get<0>(key)), rows(std::
get<1>(key)),
218 columns(std::
get<2>(key)), scope(std::
get<3>(key)),
219 use(std::
get<4>(key)) {}
225 CooperativeMatrixUseKHR
use;
230 uint32_t columns, Scope scope,
231 CooperativeMatrixUseKHR use) {
254 std::optional<StorageClass> storage) {
255 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
256 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
257 extensions.push_back(exts);
262 std::optional<StorageClass> storage) {
264 .getCapabilities(capabilities, storage);
265 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
266 capabilities.push_back(caps);
274 using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
283 return key ==
KeyTy(elementType, rows, columns, matrixLayout, scope);
287 : elementType(std::
get<0>(key)), rows(std::
get<1>(key)),
288 columns(std::
get<2>(key)), scope(std::
get<4>(key)),
289 matrixLayout(std::
get<3>(key)) {}
299 unsigned rows,
unsigned columns,
300 MatrixLayout matrixLayout) {
302 matrixLayout, scope);
316 return getImpl()->matrixLayout;
321 std::optional<StorageClass> storage) {
322 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
323 static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
325 extensions.push_back(ref);
330 std::optional<StorageClass> storage) {
332 .getCapabilities(capabilities, storage);
333 static const Capability caps[] = {Capability::JointMatrixINTEL};
335 capabilities.push_back(ref);
342 template <
typename T>
348 static_assert((1 << 3) > getMaxEnumValForDim(),
349 "Not enough bits to encode Dim value");
354 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
355 "Not enough bits to encode ImageDepthInfo value");
360 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
361 "Not enough bits to encode ImageArrayedInfo value");
366 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
367 "Not enough bits to encode ImageSamplingInfo value");
372 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
373 "Not enough bits to encode ImageSamplerUseInfo value");
378 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
379 "Not enough bits to encode ImageFormat value");
385 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
386 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
394 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
395 samplerUseInfo, format);
399 : elementType(std::
get<0>(key)), dim(std::
get<1>(key)),
400 depthInfo(std::
get<2>(key)), arrayedInfo(std::
get<3>(key)),
401 samplingInfo(std::
get<4>(key)), samplerUseInfo(std::
get<5>(key)),
402 format(std::
get<6>(key)) {}
415 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
431 return getImpl()->samplingInfo;
435 return getImpl()->samplerUseInfo;
441 std::optional<StorageClass>) {
447 std::optional<StorageClass>) {
448 if (
auto dimCaps = spirv::getCapabilities(
getDim()))
449 capabilities.push_back(*dimCaps);
452 capabilities.push_back(*fmtCaps);
462 using KeyTy = std::pair<Type, StorageClass>;
471 return key ==
KeyTy(pointeeType, storageClass);
475 : pointeeType(key.first), storageClass(key.second) {}
488 return getImpl()->storageClass;
492 std::optional<StorageClass> storage) {
499 extensions.push_back(*scExts);
504 std::optional<StorageClass> storage) {
511 capabilities.push_back(*scCaps);
519 using KeyTy = std::pair<Type, unsigned>;
528 return key ==
KeyTy(elementType, stride);
532 : elementType(key.first), stride(key.second) {}
552 std::optional<StorageClass> storage) {
553 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
558 std::optional<StorageClass> storage) {
560 static const Capability caps[] = {Capability::Shader};
562 capabilities.push_back(ref);
565 .getCapabilities(capabilities, storage);
573 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
576 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
583 return llvm::is_contained({16u, 32u, 64u}, type.
getWidth()) && !type.
isBF16();
587 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
591 std::optional<StorageClass> storage) {
599 case StorageClass::PushConstant:
600 case StorageClass::StorageBuffer:
601 case StorageClass::Uniform:
603 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
605 extensions.push_back(ref);
608 case StorageClass::Input:
609 case StorageClass::Output:
611 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
613 extensions.push_back(ref);
623 std::optional<StorageClass> storage) {
630 #define STORAGE_CASE(storage, cap8, cap16) \
631 case StorageClass::storage: { \
632 if (bitwidth == 8) { \
633 static const Capability caps[] = {Capability::cap8}; \
634 ArrayRef<Capability> ref(caps, std::size(caps)); \
635 capabilities.push_back(ref); \
638 if (bitwidth == 16) { \
639 static const Capability caps[] = {Capability::cap16}; \
640 ArrayRef<Capability> ref(caps, std::size(caps)); \
641 capabilities.push_back(ref); \
652 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
654 StorageBuffer16BitAccess);
655 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
657 case StorageClass::Input:
658 case StorageClass::Output: {
659 if (bitwidth == 16) {
660 static const Capability caps[] = {Capability::StorageInputOutput16};
662 capabilities.push_back(ref);
676 #define WIDTH_CASE(type, width) \
678 static const Capability caps[] = {Capability::type##width}; \
679 ArrayRef<Capability> ref(caps, std::size(caps)); \
680 capabilities.push_back(ref); \
683 if (
auto intType = llvm::dyn_cast<IntegerType>(*
this)) {
692 llvm_unreachable(
"invalid bitwidth to getCapabilities");
695 assert(llvm::isa<FloatType>(*
this));
702 llvm_unreachable(
"invalid bitwidth to getCapabilities");
728 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
730 if (llvm::isa<ScalarType>(type))
732 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
742 std::optional<StorageClass> storage) {
743 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
744 scalarType.getExtensions(extensions, storage);
745 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
746 compositeType.getExtensions(extensions, storage);
747 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
748 imageType.getExtensions(extensions, storage);
749 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
750 sampledImageType.getExtensions(extensions, storage);
751 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
752 matrixType.getExtensions(extensions, storage);
753 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
754 ptrType.getExtensions(extensions, storage);
756 llvm_unreachable(
"invalid SPIR-V Type to getExtensions");
762 std::optional<StorageClass> storage) {
763 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
764 scalarType.getCapabilities(capabilities, storage);
765 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
766 compositeType.getCapabilities(capabilities, storage);
767 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
768 imageType.getCapabilities(capabilities, storage);
769 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
770 sampledImageType.getCapabilities(capabilities, storage);
771 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
772 matrixType.getCapabilities(capabilities, storage);
773 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
774 ptrType.getCapabilities(capabilities, storage);
776 llvm_unreachable(
"invalid SPIR-V Type to getCapabilities");
781 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this))
782 return scalarType.getSizeInBytes();
783 if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this))
784 return compositeType.getSizeInBytes();
822 if (!llvm::isa<ImageType>(imageType))
823 return emitError() <<
"expected image type";
830 std::optional<StorageClass> storage) {
831 llvm::cast<ImageType>(
getImageType()).getExtensions(extensions, storage);
836 std::optional<StorageClass> storage) {
837 llvm::cast<ImageType>(
getImageType()).getCapabilities(capabilities, storage);
863 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
864 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
865 identifier(identifier) {}
870 unsigned numMembers,
Type const *memberTypes,
873 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
874 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
875 memberDecorationsInfo(memberDecorationsInfo) {}
901 if (isIdentified()) {
903 return getIdentifier() == std::get<0>(key);
906 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
907 getMemberDecorationsInfo());
918 StringRef keyIdentifier = std::get<0>(key);
920 if (!keyIdentifier.empty()) {
921 StringRef identifier = allocator.
copyInto(keyIdentifier);
932 const Type *typesList =
nullptr;
933 if (!keyTypes.empty()) {
934 typesList = allocator.
copyInto(keyTypes).data();
938 if (!std::get<2>(key).empty()) {
940 assert(keyOffsetInfo.size() == keyTypes.size() &&
941 "size of offset information must be same as the size of number of "
943 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
947 unsigned numMemberDecorations = 0;
948 if (!std::get<3>(key).empty()) {
949 auto keyMemberDecorations = std::get<3>(key);
950 numMemberDecorations = keyMemberDecorations.size();
951 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
956 numMemberDecorations, memberDecorationList);
960 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
971 if (memberDecorationsInfo) {
973 numMemberDecorations);
997 if (memberTypesAndIsBodySet.getInt() &&
998 (getMemberTypes() != structMemberTypes ||
999 getOffsetInfo() != structOffsetInfo ||
1000 getMemberDecorationsInfo() != structMemberDecorationInfo))
1003 memberTypesAndIsBodySet.setInt(
true);
1004 numMembers = structMemberTypes.size();
1007 if (!structMemberTypes.empty())
1008 memberTypesAndIsBodySet.setPointer(
1009 allocator.
copyInto(structMemberTypes).data());
1011 if (!structOffsetInfo.empty()) {
1012 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1013 "size of offset information must be same as the size of number of "
1015 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
1018 if (!structMemberDecorationInfo.empty()) {
1019 numMemberDecorations = structMemberDecorationInfo.size();
1020 memberDecorationsInfo =
1021 allocator.
copyInto(structMemberDecorationInfo).data();
1039 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
1042 memberDecorations.begin(), memberDecorations.end());
1043 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1044 return Base::get(memberTypes.vec().front().getContext(),
1045 StringRef(), memberTypes, offsetInfo,
1050 StringRef identifier) {
1051 assert(!identifier.empty() &&
1052 "StructType identifier must be non-empty string");
1070 return newStructType;
1081 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1093 return getImpl()->offsetInfo[index];
1099 memberDecorations.clear();
1100 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1101 memberDecorations.append(implMemberDecorations.begin(),
1102 implMemberDecorations.end());
1109 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1110 decorationsInfo.clear();
1111 for (
const auto &memberDecoration : memberDecorations) {
1112 if (memberDecoration.memberIndex == index) {
1113 decorationsInfo.push_back(memberDecoration);
1115 if (memberDecoration.memberIndex > index) {
1126 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1130 std::optional<StorageClass> storage) {
1132 llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1137 std::optional<StorageClass> storage) {
1139 llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1144 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1154 : columnType(columnType), columnCount(columnCount) {}
1156 using KeyTy = std::tuple<Type, uint32_t>;
1167 return key ==
KeyTy(columnType, columnCount);
1179 Type columnType, uint32_t columnCount) {
1185 Type columnType, uint32_t columnCount) {
1186 if (columnCount < 2 || columnCount > 4)
1187 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1190 return emitError() <<
"matrix columns must be vectors of floats";
1194 if (columnShape.size() != 1)
1195 return emitError() <<
"matrix columns must be 1D vectors";
1197 if (columnShape[0] < 2 || columnShape[0] > 4)
1198 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1205 if (
auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1206 if (llvm::isa<FloatType>(vectorType.getElementType()))
1215 return llvm::cast<VectorType>(
getImpl()->columnType).getElementType();
1221 return llvm::cast<VectorType>(
getImpl()->columnType).getShape()[0];
1229 std::optional<StorageClass> storage) {
1230 llvm::cast<SPIRVType>(
getColumnType()).getExtensions(extensions, storage);
1235 std::optional<StorageClass> storage) {
1237 static const Capability caps[] = {Capability::Matrix};
1239 capabilities.push_back(ref);
1242 llvm::cast<SPIRVType>(
getColumnType()).getCapabilities(capabilities, storage);
1249 void SPIRVDialect::registerTypes() {
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Scope getScope() const
Return the scope of the joint matrix.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getColumns() const
return the number of columns of the matrix.
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
unsigned getRows() const
return the number of rows of the matrix.
MatrixLayout getMatrixLayout() const
return the layout of the matrix
Type getElementType() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::tuple< Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR > KeyTy
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
JointMatrixTypeStorage(const KeyTy &key)
MatrixLayout matrixLayout
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const