17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
30 using KeyTy = std::tuple<Type, unsigned, unsigned>;
38 return key ==
KeyTy(elementType, elementCount, stride);
42 : elementType(std::
get<0>(key)), elementCount(std::
get<1>(key)),
43 stride(std::
get<2>(key)) {}
51 assert(elementCount &&
"ArrayType needs at least one element");
58 assert(elementCount &&
"ArrayType needs at least one element");
69 std::optional<StorageClass> storage) {
70 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
75 std::optional<StorageClass> storage) {
77 .getCapabilities(capabilities, storage);
82 std::optional<int64_t> size = elementType.getSizeInBytes();
93 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
101 return type.getRank() == 1 &&
102 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
103 llvm::isa<ScalarType>(type.getElementType());
109 TensorArmType>([](
auto type) {
return type.getElementType(); })
114 [](
Type) ->
Type { llvm_unreachable(
"invalid composite type"); });
118 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
119 return arrayType.getNumElements();
120 if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this))
121 return matrixType.getNumColumns();
122 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
123 return structType.getNumElements();
124 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this))
125 return vectorType.getNumElements();
126 if (
auto tensorArmType = dyn_cast<TensorArmType>(*
this))
127 return tensorArmType.getNumElements();
128 if (llvm::isa<CooperativeMatrixType>(*
this)) {
130 "invalid to query number of elements of spirv Cooperative Matrix type");
132 if (llvm::isa<RuntimeArrayType>(*
this)) {
134 "invalid to query number of elements of spirv::RuntimeArray type");
136 llvm_unreachable(
"invalid composite type");
140 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*
this);
145 std::optional<StorageClass> storage) {
149 [&](
auto type) { type.getExtensions(extensions, storage); })
150 .Case<VectorType>([&](VectorType type) {
151 return llvm::cast<ScalarType>(type.getElementType())
152 .getExtensions(extensions, storage);
155 static constexpr Extension ext{Extension::SPV_ARM_tensors};
156 extensions.push_back(ext);
158 .getExtensions(extensions, storage);
161 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
166 std::optional<StorageClass> storage) {
170 [&](
auto type) { type.getCapabilities(capabilities, storage); })
171 .Case<VectorType>([&](VectorType type) {
173 if (vecSize == 8 || vecSize == 16) {
174 static const Capability caps[] = {Capability::Vector16};
176 capabilities.push_back(ref);
178 return llvm::cast<ScalarType>(type.getElementType())
179 .getCapabilities(capabilities, storage);
182 static constexpr Capability cap{Capability::TensorsARM};
183 capabilities.push_back(cap);
185 .getCapabilities(capabilities, storage);
187 .Default([](
Type) { llvm_unreachable(
"invalid composite type"); });
191 if (
auto arrayType = llvm::dyn_cast<ArrayType>(*
this))
192 return arrayType.getSizeInBytes();
193 if (
auto structType = llvm::dyn_cast<StructType>(*
this))
194 return structType.getSizeInBytes();
195 if (
auto vectorType = llvm::dyn_cast<VectorType>(*
this)) {
196 std::optional<int64_t> elementSize =
197 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
200 return *elementSize * vectorType.getNumElements();
202 if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(*
this)) {
203 std::optional<int64_t> elementSize =
204 llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
207 return *elementSize * tensorArmType.getNumElements();
231 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
240 return key ==
KeyTy(elementType, shape[0], shape[1], scope, use);
244 : elementType(std::
get<0>(key)),
245 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
246 use(std::get<4>(key)) {}
252 CooperativeMatrixUseKHR
use;
257 uint32_t columns, Scope scope,
258 CooperativeMatrixUseKHR use) {
268 assert(
getImpl()->shape[0] != ShapedType::kDynamic);
269 return static_cast<uint32_t
>(
getImpl()->shape[0]);
273 assert(
getImpl()->shape[1] != ShapedType::kDynamic);
274 return static_cast<uint32_t
>(
getImpl()->shape[1]);
289 std::optional<StorageClass> storage) {
290 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
291 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292 extensions.push_back(exts);
297 std::optional<StorageClass> storage) {
299 .getCapabilities(capabilities, storage);
300 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
301 capabilities.push_back(caps);
308 template <
typename T>
314 static_assert((1 << 3) > getMaxEnumValForDim(),
315 "Not enough bits to encode Dim value");
320 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
321 "Not enough bits to encode ImageDepthInfo value");
326 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
327 "Not enough bits to encode ImageArrayedInfo value");
332 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
333 "Not enough bits to encode ImageSamplingInfo value");
338 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
339 "Not enough bits to encode ImageSamplerUseInfo value");
344 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
345 "Not enough bits to encode ImageFormat value");
351 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
352 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
360 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
361 samplerUseInfo, format);
365 : elementType(std::
get<0>(key)), dim(std::
get<1>(key)),
366 depthInfo(std::
get<2>(key)), arrayedInfo(std::
get<3>(key)),
367 samplingInfo(std::
get<4>(key)), samplerUseInfo(std::
get<5>(key)),
368 format(std::
get<6>(key)) {}
381 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
397 return getImpl()->samplingInfo;
401 return getImpl()->samplerUseInfo;
407 std::optional<StorageClass>) {
413 std::optional<StorageClass>) {
414 if (
auto dimCaps = spirv::getCapabilities(
getDim()))
415 capabilities.push_back(*dimCaps);
418 capabilities.push_back(*fmtCaps);
428 using KeyTy = std::pair<Type, StorageClass>;
437 return key ==
KeyTy(pointeeType, storageClass);
441 : pointeeType(key.first), storageClass(key.second) {}
454 return getImpl()->storageClass;
458 std::optional<StorageClass> storage) {
465 extensions.push_back(*scExts);
470 std::optional<StorageClass> storage) {
477 capabilities.push_back(*scCaps);
485 using KeyTy = std::pair<Type, unsigned>;
494 return key ==
KeyTy(elementType, stride);
498 : elementType(key.first), stride(key.second) {}
518 std::optional<StorageClass> storage) {
519 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
524 std::optional<StorageClass> storage) {
526 static const Capability caps[] = {Capability::Shader};
528 capabilities.push_back(ref);
531 .getCapabilities(capabilities, storage);
539 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
542 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
549 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
553 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
557 std::optional<StorageClass> storage) {
558 if (isa<BFloat16Type>(*
this)) {
559 static const Extension ext = Extension::SPV_KHR_bfloat16;
560 extensions.push_back(ext);
570 case StorageClass::PushConstant:
571 case StorageClass::StorageBuffer:
572 case StorageClass::Uniform:
574 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
576 extensions.push_back(ref);
579 case StorageClass::Input:
580 case StorageClass::Output:
582 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
584 extensions.push_back(ref);
594 std::optional<StorageClass> storage) {
601 #define STORAGE_CASE(storage, cap8, cap16) \
602 case StorageClass::storage: { \
603 if (bitwidth == 8) { \
604 static const Capability caps[] = {Capability::cap8}; \
605 ArrayRef<Capability> ref(caps, std::size(caps)); \
606 capabilities.push_back(ref); \
609 if (bitwidth == 16) { \
610 static const Capability caps[] = {Capability::cap16}; \
611 ArrayRef<Capability> ref(caps, std::size(caps)); \
612 capabilities.push_back(ref); \
623 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
625 StorageBuffer16BitAccess);
628 case StorageClass::Input:
629 case StorageClass::Output: {
630 if (bitwidth == 16) {
631 static const Capability caps[] = {Capability::StorageInputOutput16};
633 capabilities.push_back(ref);
647 #define WIDTH_CASE(type, width) \
649 static const Capability caps[] = {Capability::type##width}; \
650 ArrayRef<Capability> ref(caps, std::size(caps)); \
651 capabilities.push_back(ref); \
654 if (
auto intType = llvm::dyn_cast<IntegerType>(*
this)) {
663 llvm_unreachable(
"invalid bitwidth to getCapabilities");
666 assert(llvm::isa<FloatType>(*
this));
669 if (isa<BFloat16Type>(*
this)) {
670 static const Capability cap = Capability::BFloat16TypeKHR;
671 capabilities.push_back(cap);
673 static const Capability cap = Capability::Float16;
674 capabilities.push_back(cap);
682 llvm_unreachable(
"invalid bitwidth to getCapabilities");
708 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
710 if (llvm::isa<ScalarType>(type))
712 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
714 if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
715 return llvm::isa<ScalarType>(tensorArmType.getElementType());
724 std::optional<StorageClass> storage) {
725 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
726 scalarType.getExtensions(extensions, storage);
727 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
728 compositeType.getExtensions(extensions, storage);
729 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
730 imageType.getExtensions(extensions, storage);
731 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
732 sampledImageType.getExtensions(extensions, storage);
733 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
734 matrixType.getExtensions(extensions, storage);
735 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
736 ptrType.getExtensions(extensions, storage);
737 }
else if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(*
this)) {
738 tensorArmType.getExtensions(extensions, storage);
740 llvm_unreachable(
"invalid SPIR-V Type to getExtensions");
746 std::optional<StorageClass> storage) {
747 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this)) {
748 scalarType.getCapabilities(capabilities, storage);
749 }
else if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this)) {
750 compositeType.getCapabilities(capabilities, storage);
751 }
else if (
auto imageType = llvm::dyn_cast<ImageType>(*
this)) {
752 imageType.getCapabilities(capabilities, storage);
753 }
else if (
auto sampledImageType = llvm::dyn_cast<SampledImageType>(*
this)) {
754 sampledImageType.getCapabilities(capabilities, storage);
755 }
else if (
auto matrixType = llvm::dyn_cast<MatrixType>(*
this)) {
756 matrixType.getCapabilities(capabilities, storage);
757 }
else if (
auto ptrType = llvm::dyn_cast<PointerType>(*
this)) {
758 ptrType.getCapabilities(capabilities, storage);
759 }
else if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(*
this)) {
760 tensorArmType.getCapabilities(capabilities, storage);
762 llvm_unreachable(
"invalid SPIR-V Type to getCapabilities");
767 if (
auto scalarType = llvm::dyn_cast<ScalarType>(*
this))
768 return scalarType.getSizeInBytes();
769 if (
auto compositeType = llvm::dyn_cast<CompositeType>(*
this))
770 return compositeType.getSizeInBytes();
808 auto image = dyn_cast<ImageType>(imageType);
810 return emitError() <<
"expected image type";
815 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
816 return emitError() <<
"Dim must not be SubpassData or Buffer";
823 std::optional<StorageClass> storage) {
824 llvm::cast<ImageType>(
getImageType()).getExtensions(extensions, storage);
829 std::optional<StorageClass> storage) {
830 llvm::cast<ImageType>(
getImageType()).getCapabilities(capabilities, storage);
858 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
859 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
860 numStructDecorations(0), structDecorationsInfo(nullptr),
861 identifier(identifier) {}
866 unsigned numMembers,
Type const *memberTypes,
869 unsigned numStructDecorations,
871 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
872 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
873 memberDecorationsInfo(memberDecorationsInfo),
874 numStructDecorations(numStructDecorations),
875 structDecorationsInfo(structDecorationsInfo) {}
904 if (isIdentified()) {
906 return getIdentifier() == std::get<0>(key);
909 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
910 getMemberDecorationsInfo(), getStructDecorationsInfo());
921 StringRef keyIdentifier = std::get<0>(key);
923 if (!keyIdentifier.empty()) {
924 StringRef identifier = allocator.
copyInto(keyIdentifier);
935 const Type *typesList =
nullptr;
936 if (!keyTypes.empty()) {
937 typesList = allocator.
copyInto(keyTypes).data();
941 if (!std::get<2>(key).empty()) {
943 assert(keyOffsetInfo.size() == keyTypes.size() &&
944 "size of offset information must be same as the size of number of "
946 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
950 unsigned numMemberDecorations = 0;
951 if (!std::get<3>(key).empty()) {
952 auto keyMemberDecorations = std::get<3>(key);
953 numMemberDecorations = keyMemberDecorations.size();
954 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
958 unsigned numStructDecorations = 0;
959 if (!std::get<4>(key).empty()) {
960 auto keyStructDecorations = std::get<4>(key);
961 numStructDecorations = keyStructDecorations.size();
962 structDecorationList = allocator.
copyInto(keyStructDecorations).data();
966 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
967 memberDecorationList, numStructDecorations, structDecorationList);
971 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
982 if (memberDecorationsInfo) {
984 numMemberDecorations);
990 if (structDecorationsInfo)
992 numStructDecorations);
1013 if (!isIdentified())
1016 if (memberTypesAndIsBodySet.getInt() &&
1017 (getMemberTypes() != structMemberTypes ||
1018 getOffsetInfo() != structOffsetInfo ||
1019 getMemberDecorationsInfo() != structMemberDecorationInfo ||
1020 getStructDecorationsInfo() != structDecorationInfo))
1023 memberTypesAndIsBodySet.setInt(
true);
1024 numMembers = structMemberTypes.size();
1027 if (!structMemberTypes.empty())
1028 memberTypesAndIsBodySet.setPointer(
1029 allocator.
copyInto(structMemberTypes).data());
1031 if (!structOffsetInfo.empty()) {
1032 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1033 "size of offset information must be same as the size of number of "
1035 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
1038 if (!structMemberDecorationInfo.empty()) {
1039 numMemberDecorations = structMemberDecorationInfo.size();
1040 memberDecorationsInfo =
1041 allocator.
copyInto(structMemberDecorationInfo).data();
1044 if (!structDecorationInfo.empty()) {
1045 numStructDecorations = structDecorationInfo.size();
1046 structDecorationsInfo = allocator.
copyInto(structDecorationInfo).data();
1067 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
1071 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1072 sortedMemberDecorations.end());
1075 llvm::array_pod_sort(sortedStructDecorations.begin(),
1076 sortedStructDecorations.end());
1078 return Base::get(memberTypes.vec().front().getContext(),
1079 StringRef(), memberTypes, offsetInfo,
1080 sortedMemberDecorations, sortedStructDecorations);
1084 StringRef identifier) {
1085 assert(!identifier.empty() &&
1086 "StructType identifier must be non-empty string");
1107 return newStructType;
1118 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1130 getImpl()->getStructDecorationsInfo())
1131 if (info.decoration == decoration)
1139 return getImpl()->offsetInfo[index];
1145 memberDecorations.clear();
1146 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1147 memberDecorations.append(implMemberDecorations.begin(),
1148 implMemberDecorations.end());
1155 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1156 decorationsInfo.clear();
1157 for (
const auto &memberDecoration : memberDecorations) {
1158 if (memberDecoration.memberIndex == index) {
1159 decorationsInfo.push_back(memberDecoration);
1161 if (memberDecoration.memberIndex > index) {
1171 structDecorations.clear();
1172 auto implDecorations =
getImpl()->getStructDecorationsInfo();
1173 structDecorations.append(implDecorations.begin(), implDecorations.end());
1181 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1186 std::optional<StorageClass> storage) {
1188 llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1193 std::optional<StorageClass> storage) {
1195 llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1200 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1215 : columnType(columnType), columnCount(columnCount) {}
1217 using KeyTy = std::tuple<Type, uint32_t>;
1228 return key ==
KeyTy(columnType, columnCount);
1240 Type columnType, uint32_t columnCount) {
1247 Type columnType, uint32_t columnCount) {
1248 if (columnCount < 2 || columnCount > 4)
1249 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1252 return emitError() <<
"matrix columns must be vectors of floats";
1256 if (columnShape.size() != 1)
1257 return emitError() <<
"matrix columns must be 1D vectors";
1259 if (columnShape[0] < 2 || columnShape[0] > 4)
1260 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1267 if (
auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1268 if (llvm::isa<FloatType>(vectorType.getElementType()))
1277 return llvm::cast<VectorType>(
getImpl()->columnType).getElementType();
1283 return llvm::cast<VectorType>(
getImpl()->columnType).getShape()[0];
1291 std::optional<StorageClass> storage) {
1292 llvm::cast<SPIRVType>(
getColumnType()).getExtensions(extensions, storage);
1297 std::optional<StorageClass> storage) {
1299 static const Capability caps[] = {Capability::Matrix};
1301 capabilities.push_back(ref);
1304 llvm::cast<SPIRVType>(
getColumnType()).getCapabilities(capabilities, storage);
1316 auto [shape, elementType] = key;
1323 auto [shape, elementType] = key;
1324 return llvm::hash_combine(shape, elementType);
1328 return key ==
KeyTy(shape, elementType);
1332 : shape(shape), elementType(elementType) {}
1343 Type elementType)
const {
1352 std::optional<StorageClass> storage) {
1354 llvm::cast<SPIRVType>(
getElementType()).getExtensions(extensions, storage);
1355 static constexpr Extension ext{Extension::SPV_ARM_tensors};
1356 extensions.push_back(ext);
1361 std::optional<StorageClass> storage) {
1363 .getCapabilities(capabilities, storage);
1364 static constexpr Capability cap{Capability::TensorsARM};
1365 capabilities.push_back(cap);
1371 if (llvm::is_contained(shape, 0))
1372 return emitError() <<
"arm.tensor do not support dimensions = 0";
1373 if (llvm::any_of(shape, [](int64_t dim) {
return dim < 0; }) &&
1374 llvm::any_of(shape, [](int64_t dim) {
return dim > 0; }))
1376 <<
"arm.tensor shape dimensions must be either fully dynamic or "
1385 void SPIRVDialect::registerTypes() {
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::array< int64_t, 2 > shape
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and leaves the rest of the struct type data to be set through a later call to StructType::trySetBody(...).
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
unsigned numStructDecorations
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const
ArrayRef< int64_t > shape