18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/ErrorHandling.h"
36 class TypeExtensionVisitor {
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
45 if (
auto [_it, inserted] = seen.insert({type, storage}); !inserted)
50 [
this](
auto concreteType) { addConcrete(concreteType); })
51 .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
52 [
this](
auto concreteType) {
add(concreteType.getElementType()); })
56 .Case<StructType>([
this](
StructType concreteType) {
60 .DefaultUnreachable(
"Unhandled type");
63 void add(
Type type) {
add(cast<SPIRVType>(type)); }
73 std::optional<StorageClass> storage;
74 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
84 class TypeCapabilityVisitor {
87 std::optional<StorageClass> storage)
88 : capabilities(capabilities), storage(storage) {}
93 if (
auto [_it, inserted] = seen.insert({type, storage}); !inserted)
99 [
this](
auto concreteType) { addConcrete(concreteType); })
100 .Case<ArrayType>([
this](
ArrayType concreteType) {
106 .Case<StructType>([
this](
StructType concreteType) {
110 .DefaultUnreachable(
"Unhandled type");
113 void add(
Type type) {
add(cast<SPIRVType>(type)); }
124 void addConcrete(VectorType type);
127 std::optional<StorageClass> storage;
128 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
138 using KeyTy = std::tuple<Type, unsigned, unsigned>;
146 return key ==
KeyTy(elementType, elementCount, stride);
150 : elementType(std::
get<0>(key)), elementCount(std::
get<1>(key)),
151 stride(std::
get<2>(key)) {}
159 assert(elementCount &&
"ArrayType needs at least one element");
166 assert(elementCount &&
"ArrayType needs at least one element");
181 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
189 return type.getRank() == 1 &&
190 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
191 llvm::isa<ScalarType>(type.getElementType());
197 TensorArmType>([](
auto type) {
return type.getElementType(); })
201 .DefaultUnreachable(
"Invalid composite type");
207 [](
auto type) {
return type.getNumElements(); })
209 .DefaultUnreachable(
"Invalid type for number of elements query");
213 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*
this);
216 void TypeCapabilityVisitor::addConcrete(VectorType type) {
217 add(type.getElementType());
219 int64_t vecSize = type.getNumElements();
220 if (vecSize == 8 || vecSize == 16) {
221 static constexpr
auto cap = Capability::Vector16;
222 capabilities.push_back(cap);
245 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
254 return key ==
KeyTy(elementType, shape[0], shape[1], scope, use);
258 : elementType(std::
get<0>(key)),
259 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
260 use(std::get<4>(key)) {}
266 CooperativeMatrixUseKHR
use;
271 uint32_t columns, Scope scope,
272 CooperativeMatrixUseKHR use) {
282 assert(
getImpl()->shape[0] != ShapedType::kDynamic);
283 return static_cast<uint32_t
>(
getImpl()->shape[0]);
287 assert(
getImpl()->shape[1] != ShapedType::kDynamic);
288 return static_cast<uint32_t
>(
getImpl()->shape[1]);
303 static constexpr
auto ext = Extension::SPV_KHR_cooperative_matrix;
304 extensions.push_back(ext);
309 static constexpr
auto caps = Capability::CooperativeMatrixKHR;
310 capabilities.push_back(caps);
317 template <
typename T>
323 static_assert((1 << 3) > getMaxEnumValForDim(),
324 "Not enough bits to encode Dim value");
329 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
330 "Not enough bits to encode ImageDepthInfo value");
335 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
336 "Not enough bits to encode ImageArrayedInfo value");
341 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
342 "Not enough bits to encode ImageSamplingInfo value");
347 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
348 "Not enough bits to encode ImageSamplerUseInfo value");
353 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
354 "Not enough bits to encode ImageFormat value");
360 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
361 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
369 return key ==
KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
370 samplerUseInfo, format);
374 : elementType(std::
get<0>(key)), dim(std::
get<1>(key)),
375 depthInfo(std::
get<2>(key)), arrayedInfo(std::
get<3>(key)),
376 samplingInfo(std::
get<4>(key)), samplerUseInfo(std::
get<5>(key)),
377 format(std::
get<6>(key)) {}
390 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
406 return getImpl()->samplingInfo;
410 return getImpl()->samplerUseInfo;
415 void TypeCapabilityVisitor::addConcrete(
ImageType type) {
416 if (
auto dimCaps = spirv::getCapabilities(type.
getDim()))
417 capabilities.push_back(*dimCaps);
420 capabilities.push_back(*fmtCaps);
432 using KeyTy = std::pair<Type, StorageClass>;
441 return key ==
KeyTy(pointeeType, storageClass);
445 : pointeeType(key.first), storageClass(key.second) {}
458 return getImpl()->storageClass;
461 void TypeExtensionVisitor::addConcrete(
PointerType type) {
464 std::optional<StorageClass> oldStorageClass = storage;
467 storage = oldStorageClass;
470 extensions.push_back(*scExts);
473 void TypeCapabilityVisitor::addConcrete(
PointerType type) {
476 std::optional<StorageClass> oldStorageClass = storage;
479 storage = oldStorageClass;
482 capabilities.push_back(*scCaps);
490 using KeyTy = std::pair<Type, unsigned>;
499 return key ==
KeyTy(elementType, stride);
503 : elementType(key.first), stride(key.second) {}
523 static constexpr
auto cap = Capability::Shader;
524 capabilities.push_back(cap);
532 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
535 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
542 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
546 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
549 void TypeExtensionVisitor::addConcrete(
ScalarType type) {
550 if (isa<BFloat16Type>(type)) {
551 static constexpr
auto ext = Extension::SPV_KHR_bfloat16;
552 extensions.push_back(ext);
562 case StorageClass::PushConstant:
563 case StorageClass::StorageBuffer:
564 case StorageClass::Uniform:
566 static constexpr
auto ext = Extension::SPV_KHR_8bit_storage;
567 extensions.push_back(ext);
570 case StorageClass::Input:
571 case StorageClass::Output:
573 static constexpr
auto ext = Extension::SPV_KHR_16bit_storage;
574 extensions.push_back(ext);
582 void TypeCapabilityVisitor::addConcrete(
ScalarType type) {
589 #define STORAGE_CASE(storage, cap8, cap16) \
590 case StorageClass::storage: { \
591 if (bitwidth == 8) { \
592 static constexpr auto cap = Capability::cap8; \
593 capabilities.push_back(cap); \
596 if (bitwidth == 16) { \
597 static constexpr auto cap = Capability::cap16; \
598 capabilities.push_back(cap); \
609 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
611 StorageBuffer16BitAccess);
612 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
614 case StorageClass::Input:
615 case StorageClass::Output: {
616 if (bitwidth == 16) {
617 static constexpr
auto cap = Capability::StorageInputOutput16;
618 capabilities.push_back(cap);
632 #define WIDTH_CASE(type, width) \
634 static constexpr auto cap = Capability::type##width; \
635 capabilities.push_back(cap); \
638 if (
auto intType = dyn_cast<IntegerType>(type)) {
647 llvm_unreachable(
"invalid bitwidth to getCapabilities");
650 assert(isa<FloatType>(type));
653 if (isa<BFloat16Type>(type)) {
654 static constexpr
auto cap = Capability::BFloat16TypeKHR;
655 capabilities.push_back(cap);
657 static constexpr
auto cap = Capability::Float16;
658 capabilities.push_back(cap);
666 llvm_unreachable(
"invalid bitwidth to getCapabilities");
679 if (llvm::isa<SPIRVDialect>(type.
getDialect()))
681 if (llvm::isa<ScalarType>(type))
683 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
685 if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
686 return llvm::isa<ScalarType>(tensorArmType.getElementType());
695 std::optional<StorageClass> storage) {
696 TypeExtensionVisitor{extensions, storage}.add(*
this);
701 std::optional<StorageClass> storage) {
702 TypeCapabilityVisitor{capabilities, storage}.add(*
this);
720 .Case<ArrayType>([](
ArrayType type) -> std::optional<int64_t> {
724 if (std::optional<int64_t> size = elementType.getSizeInBytes())
728 .Case<VectorType, TensorArmType>([](
auto type) -> std::optional<int64_t> {
729 if (std::optional<int64_t> elementSize =
734 .Default(std::optional<int64_t>());
771 auto image = dyn_cast<ImageType>(imageType);
773 return emitError() <<
"expected image type";
778 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
779 return emitError() <<
"Dim must not be SubpassData or Buffer";
809 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
810 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
811 numStructDecorations(0), structDecorationsInfo(nullptr),
812 identifier(identifier) {}
817 unsigned numMembers,
Type const *memberTypes,
820 unsigned numStructDecorations,
822 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
823 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
824 memberDecorationsInfo(memberDecorationsInfo),
825 numStructDecorations(numStructDecorations),
826 structDecorationsInfo(structDecorationsInfo) {}
855 if (isIdentified()) {
857 return getIdentifier() == std::get<0>(key);
860 return key ==
KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
861 getMemberDecorationsInfo(), getStructDecorationsInfo());
872 StringRef keyIdentifier = std::get<0>(key);
874 if (!keyIdentifier.empty()) {
875 StringRef identifier = allocator.
copyInto(keyIdentifier);
886 const Type *typesList =
nullptr;
887 if (!keyTypes.empty()) {
888 typesList = allocator.
copyInto(keyTypes).data();
892 if (!std::get<2>(key).empty()) {
894 assert(keyOffsetInfo.size() == keyTypes.size() &&
895 "size of offset information must be same as the size of number of "
897 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
901 unsigned numMemberDecorations = 0;
902 if (!std::get<3>(key).empty()) {
903 auto keyMemberDecorations = std::get<3>(key);
904 numMemberDecorations = keyMemberDecorations.size();
905 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
909 unsigned numStructDecorations = 0;
910 if (!std::get<4>(key).empty()) {
911 auto keyStructDecorations = std::get<4>(key);
912 numStructDecorations = keyStructDecorations.size();
913 structDecorationList = allocator.
copyInto(keyStructDecorations).data();
917 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
918 memberDecorationList, numStructDecorations, structDecorationList);
922 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
933 if (memberDecorationsInfo) {
935 numMemberDecorations);
941 if (structDecorationsInfo)
943 numStructDecorations);
967 if (memberTypesAndIsBodySet.getInt() &&
968 (getMemberTypes() != structMemberTypes ||
969 getOffsetInfo() != structOffsetInfo ||
970 getMemberDecorationsInfo() != structMemberDecorationInfo ||
971 getStructDecorationsInfo() != structDecorationInfo))
974 memberTypesAndIsBodySet.setInt(
true);
975 numMembers = structMemberTypes.size();
978 if (!structMemberTypes.empty())
979 memberTypesAndIsBodySet.setPointer(
980 allocator.
copyInto(structMemberTypes).data());
982 if (!structOffsetInfo.empty()) {
983 assert(structOffsetInfo.size() == structMemberTypes.size() &&
984 "size of offset information must be same as the size of number of "
986 offsetInfo = allocator.
copyInto(structOffsetInfo).data();
989 if (!structMemberDecorationInfo.empty()) {
990 numMemberDecorations = structMemberDecorationInfo.size();
991 memberDecorationsInfo =
992 allocator.
copyInto(structMemberDecorationInfo).data();
995 if (!structDecorationInfo.empty()) {
996 numStructDecorations = structDecorationInfo.size();
997 structDecorationsInfo = allocator.
copyInto(structDecorationInfo).data();
1018 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
1022 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1023 sortedMemberDecorations.end());
1026 llvm::array_pod_sort(sortedStructDecorations.begin(),
1027 sortedStructDecorations.end());
1029 return Base::get(memberTypes.vec().front().getContext(),
1030 StringRef(), memberTypes, offsetInfo,
1031 sortedMemberDecorations, sortedStructDecorations);
1035 StringRef identifier) {
1036 assert(!identifier.empty() &&
1037 "StructType identifier must be non-empty string");
1058 return newStructType;
1069 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1081 getImpl()->getStructDecorationsInfo())
1082 if (info.decoration == decoration)
1090 return getImpl()->offsetInfo[index];
1096 memberDecorations.clear();
1097 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1098 memberDecorations.append(implMemberDecorations.begin(),
1099 implMemberDecorations.end());
1106 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1107 decorationsInfo.clear();
1108 for (
const auto &memberDecoration : memberDecorations) {
1109 if (memberDecoration.memberIndex == index) {
1110 decorationsInfo.push_back(memberDecoration);
1112 if (memberDecoration.memberIndex > index) {
1122 structDecorations.clear();
1123 auto implDecorations =
getImpl()->getStructDecorationsInfo();
1124 structDecorations.append(implDecorations.begin(), implDecorations.end());
1132 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1138 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1153 : columnType(columnType), columnCount(columnCount) {}
1155 using KeyTy = std::tuple<Type, uint32_t>;
1166 return key ==
KeyTy(columnType, columnCount);
1178 Type columnType, uint32_t columnCount) {
1185 Type columnType, uint32_t columnCount) {
1186 if (columnCount < 2 || columnCount > 4)
1187 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1190 return emitError() <<
"matrix columns must be vectors of floats";
1194 if (columnShape.size() != 1)
1195 return emitError() <<
"matrix columns must be 1D vectors";
1197 if (columnShape[0] < 2 || columnShape[0] > 4)
1198 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1205 if (
auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1206 if (llvm::isa<FloatType>(vectorType.getElementType()))
1215 return llvm::cast<VectorType>(
getImpl()->columnType).getElementType();
1221 return llvm::cast<VectorType>(
getImpl()->columnType).getShape()[0];
1228 void TypeCapabilityVisitor::addConcrete(
MatrixType type) {
1230 static constexpr
auto cap = Capability::Matrix;
1231 capabilities.push_back(cap);
1243 auto [shape, elementType] = key;
1250 auto [shape, elementType] = key;
1251 return llvm::hash_combine(shape, elementType);
1255 return key ==
KeyTy(shape, elementType);
1259 : shape(shape), elementType(elementType) {}
1270 Type elementType)
const {
1277 void TypeExtensionVisitor::addConcrete(
TensorArmType type) {
1279 static constexpr
auto ext = Extension::SPV_ARM_tensors;
1280 extensions.push_back(ext);
1283 void TypeCapabilityVisitor::addConcrete(
TensorArmType type) {
1285 static constexpr
auto cap = Capability::TensorsARM;
1286 capabilities.push_back(cap);
1292 if (llvm::is_contained(shape, 0))
1293 return emitError() <<
"arm.tensor do not support dimensions = 0";
1294 if (llvm::any_of(shape, [](int64_t dim) {
return dim < 0; }) &&
1295 llvm::any_of(shape, [](int64_t dim) {
return dim > 0; }))
1297 <<
"arm.tensor shape dimensions must be either fully dynamic or "
1306 void SPIRVDialect::registerTypes() {
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
static SampledImageType get(Type imageType)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::array< int64_t, 2 > shape
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and leaves the rest of the struct type data to be set through a later call to StructType::trySetBody.
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
unsigned numStructDecorations
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const
ArrayRef< int64_t > shape