18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
40 return key ==
KeyTy(elementType, elementCount, stride);
44 : elementType(std::
get<0>(key)), elementCount(std::
get<1>(key)),
45 stride(std::
get<2>(key)) {}
53 assert(elementCount &&
"ArrayType needs at least one element");
60 assert(elementCount &&
"ArrayType needs at least one element");
71 std::optional<StorageClass> storage) {
72 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
77 std::optional<StorageClass> storage) {
79 .getCapabilities(capabilities, storage);
84 std::optional<int64_t> size = elementType.getSizeInBytes();
95 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
104 return type.getRank() == 1 &&
105 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
106 llvm::isa<ScalarType>(type.getElementType());
113 [](
auto type) {
return type.getElementType(); })
118 [](
Type) ->
Type { llvm_unreachable(
"invalid composite type"); });
122 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
123 return arrayType.getNumElements();
124 if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this))
125 return matrixType.getNumColumns();
126 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
127 return structType.getNumElements();
128 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this))
129 return vectorType.getNumElements();
130 if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*
this)) {
132 "invalid to query number of elements of spirv Cooperative Matrix type");
134 if (llvm::isa<JointMatrixINTELType>(*
this)) {
136 "invalid to query number of elements of spirv::JointMatrix type");
138 if (llvm::isa<RuntimeArrayType>(*
this)) {
140 "invalid to query number of elements of spirv::RuntimeArray type");
142 llvm_unreachable(
"invalid composite type");
152 std::optional<StorageClass> storage) {
156 [&](
auto type) { type.getExtensions(extensions, storage); })
157 .Case<VectorType>([&](VectorType type) {
158 return llvm::cast<ScalarType>(type.getElementType())
159 .getExtensions(extensions, storage);
161 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
166 std::optional<StorageClass> storage) {
170 [&](
auto type) { type.getCapabilities(capabilities, storage); })
171 .Case<VectorType>([&](VectorType type) {
173 if (vecSize == 8 || vecSize == 16) {
174 static const Capability caps[] = {Capability::Vector16};
176 capabilities.push_back(ref);
178 return llvm::cast<ScalarType>(type.getElementType())
179 .getCapabilities(capabilities, storage);
181 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
185 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
186 return arrayType.getSizeInBytes();
187 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
188 return structType.getSizeInBytes();
189 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this)) {
190 std::optional<int64_t> elementSize =
191 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
194 return *elementSize * vectorType.getNumElements();
205 std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
214 return key ==
KeyTy(elementType, rows, columns, scope, use);
218 : elementType(std::
get<0>(key)), rows(std::
get<1>(key)),
219 columns(std::
get<2>(key)), scope(std::
get<3>(key)),
220 use(std::
get<4>(key)) {}
226 CooperativeMatrixUseKHR
use;
231 uint32_t columns, Scope scope,
232 CooperativeMatrixUseKHR use) {
255 std::optional<StorageClass> storage) {
256 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
257 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
258 extensions.push_back(exts);
263 std::optional<StorageClass> storage) {
265 .getCapabilities(capabilities, storage);
266 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
267 capabilities.push_back(caps);
275 using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
284 return key ==
KeyTy(elementType, scope, rows, columns);
288 : elementType(std::
get<0>(key)), rows(std::
get<2>(key)),
289 columns(std::
get<3>(key)), scope(std::
get<1>(key)) {}
298 Scope scope,
unsigned rows,
317 std::optional<StorageClass> storage) {
318 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
319 static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
321 extensions.push_back(ref);
326 std::optional<StorageClass> storage) {
328 .getCapabilities(capabilities, storage);
329 static const Capability caps[] = {Capability::CooperativeMatrixNV};
331 capabilities.push_back(ref);
339 using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
348 return key ==
KeyTy(elementType, rows, columns, matrixLayout, scope);
352 : elementType(std::
get<0>(key)), rows(std::
get<1>(key)),
353 columns(std::
get<2>(key)), scope(std::
get<4>(key)),
354 matrixLayout(std::
get<3>(key)) {}
364 unsigned rows,
unsigned columns,
365 MatrixLayout matrixLayout) {
367 matrixLayout, scope);
381 return getImpl()->matrixLayout;
386 std::optional<StorageClass> storage) {
387 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
388 static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
390 extensions.push_back(ref);
395 std::optional<StorageClass> storage) {
397 .getCapabilities(capabilities, storage);
398 static const Capability caps[] = {Capability::JointMatrixINTEL};
400 capabilities.push_back(ref);
407 template <
typename T>
413 static_assert((1 << 3) > getMaxEnumValForDim(),
414 "Not enough bits to encode Dim value");
419 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
420 "Not enough bits to encode ImageDepthInfo value");
425 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
426 "Not enough bits to encode ImageArrayedInfo value");
431 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
432 "Not enough bits to encode ImageSamplingInfo value");
437 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
438 "Not enough bits to encode ImageSamplerUseInfo value");
443 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
444 "Not enough bits to encode ImageFormat value");
450 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
451 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
459 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
460 samplerUseInfo, format);
464 : elementType(std::
get<0>(key)), dim(std::
get<1>(key)),
465 depthInfo(std::
get<2>(key)), arrayedInfo(std::
get<3>(key)),
466 samplingInfo(std::
get<4>(key)), samplerUseInfo(std::
get<5>(key)),
467 format(std::
get<6>(key)) {}
480 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
496 return getImpl()->samplingInfo;
500 return getImpl()->samplerUseInfo;
506 std::optional<StorageClass>) {
512 std::optional<StorageClass>) {
513 if (
auto dimCaps = spirv::getCapabilities(
getDim()))
514 capabilities.push_back(*dimCaps);
517 capabilities.push_back(*fmtCaps);
527 using KeyTy = std::pair<Type, StorageClass>;
536 return key ==
KeyTy(pointeeType, storageClass);
540 : pointeeType(key.first), storageClass(key.second) {}
553 return getImpl()->storageClass;
557 std::optional<StorageClass> storage) {
564 extensions.push_back(*scExts);
569 std::optional<StorageClass> storage) {
576 capabilities.push_back(*scCaps);
584 using KeyTy = std::pair<Type, unsigned>;
593 return key ==
KeyTy(elementType, stride);
597 : elementType(key.first), stride(key.second) {}
617 std::optional<StorageClass> storage) {
618 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
623 std::optional<StorageClass> storage) {
625 static const Capability caps[] = {Capability::Shader};
627 capabilities.push_back(ref);
630 .getCapabilities(capabilities, storage);
638 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
641 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
648 return llvm::is_contained({16u, 32u, 64u}, type.
getWidth()) && !type.
isBF16();
652 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
656 std::optional<StorageClass> storage) {
664 case StorageClass::PushConstant:
665 case StorageClass::StorageBuffer:
666 case StorageClass::Uniform:
668 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
670 extensions.push_back(ref);
673 case StorageClass::Input:
674 case StorageClass::Output:
676 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
678 extensions.push_back(ref);
688 std::optional<StorageClass> storage) {
695 #define STORAGE_CASE(storage, cap8, cap16) \
696 case StorageClass::storage: { \
697 if (bitwidth == 8) { \
698 static const Capability caps[] = {Capability::cap8}; \
699 ArrayRef<Capability> ref(caps, std::size(caps)); \
700 capabilities.push_back(ref); \
703 if (bitwidth == 16) { \
704 static const Capability caps[] = {Capability::cap16}; \
705 ArrayRef<Capability> ref(caps, std::size(caps)); \
706 capabilities.push_back(ref); \
717 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
719 StorageBuffer16BitAccess);
720 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
722 case StorageClass::Input:
723 case StorageClass::Output: {
724 if (bitwidth == 16) {
725 static const Capability caps[] = {Capability::StorageInputOutput16};
727 capabilities.push_back(ref);
741 #define WIDTH_CASE(type, width) \
743 static const Capability caps[] = {Capability::type##width}; \
744 ArrayRef<Capability> ref(caps, std::size(caps)); \
745 capabilities.push_back(ref); \
748 if (
auto intType = llvm::dyn_cast<IntegerType>(*
this)) {
757 llvm_unreachable(
"invalid bitwidth to getCapabilities");
760 assert(llvm::isa<FloatType>(*
this));
767 llvm_unreachable(
"invalid bitwidth to getCapabilities");
793 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
795 if (llvm::isa<ScalarType>(type))
797 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
807 std::optional<StorageClass> storage) {
808 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
809 scalarType.getExtensions(extensions, storage);
810 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
811 compositeType.getExtensions(extensions, storage);
812 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
813 imageType.getExtensions(extensions, storage);
814 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
815 sampledImageType.getExtensions(extensions, storage);
816 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
817 matrixType.getExtensions(extensions, storage);
818 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
819 ptrType.getExtensions(extensions, storage);
821 llvm_unreachable(
"invalid SPIR-V Type to getExtensions");
827 std::optional<StorageClass> storage) {
828 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
829 scalarType.getCapabilities(capabilities, storage);
830 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
831 compositeType.getCapabilities(capabilities, storage);
832 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
833 imageType.getCapabilities(capabilities, storage);
834 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
835 sampledImageType.getCapabilities(capabilities, storage);
836 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
837 matrixType.getCapabilities(capabilities, storage);
838 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
839 ptrType.getCapabilities(capabilities, storage);
841 llvm_unreachable(
"invalid SPIR-V Type to getCapabilities");
846 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this))
847 return scalarType.getSizeInBytes();
848 if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this))
849 return compositeType.getSizeInBytes();
887 if (!llvm::isa<ImageType>(imageType))
888 return emitError() <<
"expected image type";
895 std::optional<StorageClass> storage) {
896 llvm::cast<ImageType>(
getImageType()).getExtensions(extensions, storage);
901 std::optional<StorageClass> storage) {
902 llvm::cast<ImageType>(
getImageType()).getCapabilities(capabilities, storage);
928 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
929 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
930 identifier(identifier) {}
935 unsigned numMembers,
Type const *memberTypes,
938 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
939 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
940 memberDecorationsInfo(memberDecorationsInfo) {}
966 if (isIdentified()) {
968 return getIdentifier() == std::get<0>(key);
971 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
972 getMemberDecorationsInfo());
983 StringRef keyIdentifier = std::get<0>(key);
985 if (!keyIdentifier.empty()) {
986 StringRef identifier = allocator.
copyInto(keyIdentifier);
997 const Type *typesList =
nullptr;
998 if (!keyTypes.empty()) {
999 typesList = allocator.
copyInto(keyTypes).data();
1003 if (!std::get<2>(key).empty()) {
1005 assert(keyOffsetInfo.size() == keyTypes.size() &&
1006 "size of offset information must be same as the size of number of "
1008 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
1012 unsigned numMemberDecorations = 0;
1013 if (!std::get<3>(key).empty()) {
1014 auto keyMemberDecorations = std::get<3>(key);
1015 numMemberDecorations = keyMemberDecorations.size();
1016 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
1021 numMemberDecorations, memberDecorationList);
1025 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
1036 if (memberDecorationsInfo) {
1038 numMemberDecorations);
1059 if (!isIdentified())
1062 if (memberTypesAndIsBodySet.getInt() &&
1063 (getMemberTypes() != structMemberTypes ||
1064 getOffsetInfo() != structOffsetInfo ||
1065 getMemberDecorationsInfo() != structMemberDecorationInfo))
1068 memberTypesAndIsBodySet.setInt(
true);
1069 numMembers = structMemberTypes.size();
1072 if (!structMemberTypes.empty())
1073 memberTypesAndIsBodySet.setPointer(
1074 allocator.
copyInto(structMemberTypes).data());
1076 if (!structOffsetInfo.empty()) {
1077 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1078 "size of offset information must be same as the size of number of "
1080 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
1083 if (!structMemberDecorationInfo.empty()) {
1084 numMemberDecorations = structMemberDecorationInfo.size();
1085 memberDecorationsInfo =
1086 allocator.
copyInto(structMemberDecorationInfo).data();
1104 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
1107 memberDecorations.begin(), memberDecorations.end());
1108 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1109 return Base::get(memberTypes.vec().front().getContext(),
1110 StringRef(), memberTypes, offsetInfo,
1115 StringRef identifier) {
1116 assert(!identifier.empty() &&
1117 "StructType identifier must be non-empty string");
1135 return newStructType;
1146 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1158 return getImpl()->offsetInfo[index];
1164 memberDecorations.clear();
1165 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1166 memberDecorations.append(implMemberDecorations.begin(),
1167 implMemberDecorations.end());
1174 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1175 decorationsInfo.clear();
1176 for (
const auto &memberDecoration : memberDecorations) {
1177 if (memberDecoration.memberIndex == index) {
1178 decorationsInfo.push_back(memberDecoration);
1180 if (memberDecoration.memberIndex > index) {
1191 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1195 std::optional<StorageClass> storage) {
1197 llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1202 std::optional<StorageClass> storage) {
1204 llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1209 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1219 : columnType(columnType), columnCount(columnCount) {}
1221 using KeyTy = std::tuple<Type, uint32_t>;
1232 return key ==
KeyTy(columnType, columnCount);
1244 Type columnType, uint32_t columnCount) {
1250 Type columnType, uint32_t columnCount) {
1251 if (columnCount < 2 || columnCount > 4)
1252 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1255 return emitError() <<
"matrix columns must be vectors of floats";
1259 if (columnShape.size() != 1)
1260 return emitError() <<
"matrix columns must be 1D vectors";
1262 if (columnShape[0] < 2 || columnShape[0] > 4)
1263 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1270 if (
auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1271 if (llvm::isa<FloatType>(vectorType.getElementType()))
1280 return llvm::cast<VectorType>(
getImpl()->columnType).getElementType();
1286 return llvm::cast<VectorType>(
getImpl()->columnType).getShape()[0];
1294 std::optional<StorageClass> storage) {
1295 llvm::cast<SPIRVType>(
getColumnType()).getExtensions(extensions, storage);
1300 std::optional<StorageClass> storage) {
1302 static const Capability caps[] = {Capability::Matrix};
1304 capabilities.push_back(ref);
1307 llvm::cast<SPIRVType>(
getColumnType()).getCapabilities(capabilities, storage);
1314 void SPIRVDialect::registerTypes() {
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
unsigned getRows() const
Returns the number of rows of the matrix.
unsigned getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Scope getScope() const
Returns the scope of the matrix.
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Scope getScope() const
Return the scope of the joint matrix.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getColumns() const
return the number of columns of the matrix.
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
unsigned getRows() const
return the number of rows of the matrix.
MatrixLayout getMatrixLayout() const
return the layout of the matrix
Type getElementType() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixNVTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
static CooperativeMatrixNVTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixTypeStorage(const KeyTy &key)
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::tuple< Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR > KeyTy
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
JointMatrixTypeStorage(const KeyTy &key)
MatrixLayout matrixLayout
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const