18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/ErrorHandling.h"
36class TypeExtensionVisitor {
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
44 void add(SPIRVType type) {
49 .Case<CooperativeMatrixType, ImageType, PointerType, ScalarType,
51 [
this](
auto concreteType) { addConcrete(concreteType); })
52 .Case<ArrayType, MatrixType, RuntimeArrayType, VectorType>(
53 [
this](
auto concreteType) {
add(concreteType.getElementType()); })
54 .Case([
this](SampledImageType concreteType) {
57 .Case([
this](StructType concreteType) {
61 .Case<SamplerType, NamedBarrierType>([](
auto) { })
62 .DefaultUnreachable(
"Unhandled type");
65 void add(Type type) {
add(cast<SPIRVType>(type)); }
69 void addConcrete(CooperativeMatrixType type);
70 void addConcrete(ImageType type);
71 void addConcrete(PointerType type);
72 void addConcrete(ScalarType type);
73 void addConcrete(TensorArmType type);
75 template <Extension... Es>
77 static constexpr Extension exts[] = {Es...};
78 extensions.push_back(exts);
82 std::optional<StorageClass> storage;
83 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
93class TypeCapabilityVisitor {
96 std::optional<StorageClass> storage)
97 : capabilities(capabilities), storage(storage) {}
101 void add(SPIRVType type) {
106 .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
107 RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
108 [
this](
auto concreteType) { addConcrete(concreteType); })
109 .Case([
this](ArrayType concreteType) {
112 .Case([
this](SampledImageType concreteType) {
115 .Case([
this](StructType concreteType) {
119 .Case([](SamplerType) { })
121 [
this](NamedBarrierType) { pushCaps<Capability::NamedBarrier>(); })
122 .DefaultUnreachable(
"Unhandled type");
125 void add(Type type) {
add(cast<SPIRVType>(type)); }
129 void addConcrete(CooperativeMatrixType type);
130 void addConcrete(ImageType type);
131 void addConcrete(MatrixType type);
132 void addConcrete(PointerType type);
133 void addConcrete(RuntimeArrayType type);
134 void addConcrete(ScalarType type);
135 void addConcrete(TensorArmType type);
136 void addConcrete(VectorType type);
138 template <Capability... Cs>
140 static constexpr Capability caps[] = {Cs...};
141 capabilities.push_back(caps);
145 std::optional<StorageClass> storage;
146 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
156 using KeyTy = std::tuple<Type, unsigned, unsigned>;
177 assert(elementCount &&
"ArrayType needs at least one element");
184 assert(elementCount &&
"ArrayType needs at least one element");
199 if (
auto vectorType = dyn_cast<VectorType>(type))
207 return type.getRank() == 1 &&
208 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
209 (isa<ScalarType>(type.getElementType()) ||
210 isa<PointerType>(type.getElementType()));
216 TensorArmType>([](
auto type) {
return type.getElementType(); })
219 .DefaultUnreachable(
"Invalid composite type");
225 [](
auto type) {
return type.getNumElements(); })
227 .DefaultUnreachable(
"Invalid type for number of elements query");
231 return !isa<CooperativeMatrixType, RuntimeArrayType>(*
this);
234void TypeCapabilityVisitor::addConcrete(VectorType type) {
235 add(type.getElementType());
237 int64_t vecSize = type.getNumElements();
238 if (vecSize == 8 || vecSize == 16)
239 pushCaps<Capability::Vector16>();
261 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
275 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
276 use(std::get<4>(key)) {}
282 CooperativeMatrixUseKHR
use;
287 uint32_t columns, Scope scope,
288 CooperativeMatrixUseKHR use) {
299 return static_cast<uint32_t
>(
getImpl()->shape[0]);
304 return static_cast<uint32_t
>(
getImpl()->shape[1]);
319 pushExts<Extension::SPV_KHR_cooperative_matrix>();
325 pushCaps<Capability::CooperativeMatrixKHR>();
327 pushCaps<Capability::BFloat16CooperativeMatrixKHR>();
329 pushCaps<Capability::Float8CooperativeMatrixEXT>();
342 static_assert((1 << 3) > getMaxEnumValForDim(),
343 "Not enough bits to encode Dim value");
348 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
349 "Not enough bits to encode ImageDepthInfo value");
354 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
355 "Not enough bits to encode ImageArrayedInfo value");
360 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
361 "Not enough bits to encode ImageSamplingInfo value");
366 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
367 "Not enough bits to encode ImageSamplerUseInfo value");
372 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
373 "Not enough bits to encode ImageFormat value");
379 using KeyTy = std::tuple<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
380 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
409 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
425 return getImpl()->samplingInfo;
429 return getImpl()->samplerUseInfo;
434void TypeExtensionVisitor::addConcrete(
ImageType type) {
438 intTy && intTy.getWidth() == 64)
439 pushExts<Extension::SPV_EXT_shader_image_int64>();
443void TypeCapabilityVisitor::addConcrete(
ImageType type) {
448 bool isMultisampled =
450 bool isArrayed = type.
getArrayedInfo() == ImageArrayedInfo::Arrayed;
452 bool noSampler = sampler == ImageSamplerUseInfo::NoSampler;
453 bool needSampler = sampler == ImageSamplerUseInfo::NeedSampler;
458 pushCaps<Capability::Sampled1D>();
460 pushCaps<Capability::Image1D>();
462 pushCaps<Capability::Image1D, Capability::Sampled1D>();
465 if (isMultisampled && noSampler)
466 pushCaps<Capability::StorageImageMultisample>();
467 if (isMultisampled && isArrayed)
468 pushCaps<Capability::ImageMSArray>();
473 pushCaps<Capability::Shader>();
475 pushCaps<Capability::ImageCubeArray>();
478 pushCaps<Capability::ImageRect, Capability::SampledRect>();
482 pushCaps<Capability::SampledBuffer>();
484 pushCaps<Capability::ImageBuffer>();
486 pushCaps<Capability::ImageBuffer, Capability::SampledBuffer>();
488 case Dim::SubpassData:
489 pushCaps<Capability::InputAttachment>();
494 capabilities.push_back(*fmtCaps);
498 intTy && intTy.getWidth() == 64)
499 pushCaps<Capability::Int64ImageEXT>();
511 using KeyTy = std::pair<Type, StorageClass>;
537 return getImpl()->storageClass;
540void TypeExtensionVisitor::addConcrete(
PointerType type) {
543 std::optional<StorageClass> oldStorageClass = storage;
546 storage = oldStorageClass;
549 extensions.push_back(*scExts);
552void TypeCapabilityVisitor::addConcrete(
PointerType type) {
555 std::optional<StorageClass> oldStorageClass = storage;
558 storage = oldStorageClass;
561 capabilities.push_back(*scCaps);
569 using KeyTy = std::pair<Type, unsigned>;
602 pushCaps<Capability::Shader>();
610 if (
auto floatType = dyn_cast<FloatType>(type)) {
613 if (
auto intType = dyn_cast<IntegerType>(type)) {
620 if (type.isF8E4M3FN() || type.isF8E5M2())
622 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
626 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
629void TypeExtensionVisitor::addConcrete(
ScalarType type) {
631 pushExts<Extension::SPV_KHR_bfloat16>();
634 pushExts<Extension::SPV_EXT_float8>();
643 case StorageClass::PushConstant:
644 case StorageClass::StorageBuffer:
645 case StorageClass::Uniform:
647 pushExts<Extension::SPV_KHR_8bit_storage>();
649 case StorageClass::Input:
650 case StorageClass::Output:
652 pushExts<Extension::SPV_KHR_16bit_storage>();
659void TypeCapabilityVisitor::addConcrete(
ScalarType type) {
666#define STORAGE_CASE(storage, cap8, cap16) \
667 case StorageClass::storage: { \
668 if (bitwidth == 8) { \
669 pushCaps<Capability::cap8>(); \
672 if (bitwidth == 16) { \
673 pushCaps<Capability::cap16>(); \
684 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
686 StorageBuffer16BitAccess);
689 case StorageClass::Input:
690 case StorageClass::Output: {
691 if (bitwidth == 16) {
692 pushCaps<Capability::StorageInputOutput16>();
706#define WIDTH_CASE(type, width) \
708 pushCaps<Capability::type##width>(); \
711 if (
auto intType = dyn_cast<IntegerType>(type)) {
720 llvm_unreachable(
"invalid bitwidth to getCapabilities");
723 assert(isa<FloatType>(type));
727 pushCaps<Capability::Float8EXT>();
729 llvm_unreachable(
"invalid 8-bit float type to getCapabilities");
734 pushCaps<Capability::BFloat16TypeKHR>();
736 pushCaps<Capability::Float16>();
743 llvm_unreachable(
"invalid bitwidth to getCapabilities");
758 if (isa<ScalarType>(type))
760 if (
auto vectorType = dyn_cast<VectorType>(type))
762 if (
auto tensorArmType = dyn_cast<TensorArmType>(type))
763 return isa<ScalarType>(tensorArmType.getElementType());
772 std::optional<StorageClass> storage) {
773 TypeExtensionVisitor{extensions, storage}.add(*
this);
778 std::optional<StorageClass> storage) {
779 TypeCapabilityVisitor{capabilities, storage}.add(*
this);
784 .Case([](
ScalarType type) -> std::optional<int64_t> {
797 .Case([](
ArrayType type) -> std::optional<int64_t> {
800 auto elementType = cast<SPIRVType>(type.getElementType());
801 if (std::optional<int64_t> size = elementType.getSizeInBytes())
802 return (*size + type.getArrayStride()) * type.getNumElements();
805 .Case<VectorType, TensorArmType>([](
auto type) -> std::optional<int64_t> {
806 if (std::optional<int64_t> elementSize =
807 cast<ScalarType>(type.getElementType()).getSizeInBytes())
808 return *elementSize * type.getNumElements();
811 .Default(std::nullopt);
848 auto image = dyn_cast<ImageType>(imageType);
850 return emitError() <<
"expected image type";
855 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
856 return emitError() <<
"Dim must not be SubpassData or Buffer";
965 StringRef keyIdentifier = std::get<0>(key);
967 if (!keyIdentifier.empty()) {
979 const Type *typesList =
nullptr;
980 if (!keyTypes.empty()) {
981 typesList = allocator.
copyInto(keyTypes).data();
985 if (!std::get<2>(key).empty()) {
987 assert(keyOffsetInfo.size() == keyTypes.size() &&
988 "size of offset information must be same as the size of number of "
990 offsetInfoList = allocator.
copyInto(keyOffsetInfo).data();
995 if (!std::get<3>(key).empty()) {
996 auto keyMemberDecorations = std::get<3>(key);
998 memberDecorationList = allocator.
copyInto(keyMemberDecorations).data();
1003 if (!std::get<4>(key).empty()) {
1004 auto keyStructDecorations = std::get<4>(key);
1006 structDecorationList = allocator.
copyInto(keyStructDecorations).data();
1071 if (!structMemberTypes.empty())
1073 allocator.
copyInto(structMemberTypes).data());
1075 if (!structOffsetInfo.empty()) {
1076 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1077 "size of offset information must be same as the size of number of "
1082 if (!structMemberDecorationInfo.empty()) {
1085 allocator.
copyInto(structMemberDecorationInfo).data();
1088 if (!structDecorationInfo.empty()) {
1111 assert(!memberTypes.empty() &&
"Struct needs at least one member type");
1115 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1116 sortedMemberDecorations.end());
1119 llvm::array_pod_sort(sortedStructDecorations.begin(),
1120 sortedStructDecorations.end());
1122 return Base::get(memberTypes.vec().front().getContext(),
1123 StringRef(), memberTypes, offsetInfo,
1124 sortedMemberDecorations, sortedStructDecorations);
1128 StringRef identifier) {
1129 assert(!identifier.empty() &&
1130 "StructType identifier must be non-empty string");
1151 return newStructType;
1162 return getImpl()->memberTypesAndIsBodySet.getPointer()[
index];
1174 getImpl()->getStructDecorationsInfo())
1175 if (info.decoration == decoration)
1189 memberDecorations.clear();
1190 auto implMemberDecorations =
getImpl()->getMemberDecorationsInfo();
1191 memberDecorations.append(implMemberDecorations.begin(),
1192 implMemberDecorations.end());
1199 auto memberDecorations =
getImpl()->getMemberDecorationsInfo();
1200 decorationsInfo.clear();
1201 for (
const auto &memberDecoration : memberDecorations) {
1202 if (memberDecoration.memberIndex ==
index) {
1203 decorationsInfo.push_back(memberDecoration);
1205 if (memberDecoration.memberIndex >
index) {
1215 structDecorations.clear();
1216 auto implDecorations =
getImpl()->getStructDecorationsInfo();
1217 structDecorations.append(implDecorations.begin(), implDecorations.end());
1225 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1231 return llvm::hash_combine(memberDecorationInfo.
memberIndex,
1237 return llvm::hash_value(structDecorationInfo.
decoration);
1248 using KeyTy = std::tuple<Type, int64_t>;
1252 shape({cast<VectorType>(std::get<0>(key)).getShape()[0],
1253 std::get<1>(key)}) {}
1276 Type columnType, uint32_t columnCount) {
1283 Type columnType, uint32_t columnCount) {
1284 if (columnCount < 2 || columnCount > 4)
1285 return emitError() <<
"matrix can have 2, 3, or 4 columns only";
1288 return emitError() <<
"matrix columns must be vectors of floats";
1292 if (columnShape.size() != 1)
1293 return emitError() <<
"matrix columns must be 1D vectors";
1295 if (columnShape[0] < 2 || columnShape[0] > 4)
1296 return emitError() <<
"matrix columns must be of size 2, 3, or 4";
1303 if (
auto vectorType = dyn_cast<VectorType>(columnType)) {
1304 if (isa<FloatType>(vectorType.getElementType()))
1313 return cast<VectorType>(
getImpl()->columnType).getElementType();
1318 assert(
getImpl()->
shape[1] <= std::numeric_limits<unsigned>::max());
1319 return static_cast<uint32_t
>(
getImpl()->shape[1]);
1324 assert(
getImpl()->
shape[0] <= std::numeric_limits<unsigned>::max());
1325 return static_cast<uint32_t
>(
getImpl()->shape[0]);
1334void TypeCapabilityVisitor::addConcrete(
MatrixType type) {
1336 pushCaps<Capability::Matrix>();
1375 Type elementType)
const {
1384 pushExts<Extension::SPV_ARM_tensors>();
1387void TypeCapabilityVisitor::addConcrete(
TensorArmType type) {
1389 pushCaps<Capability::TensorsARM>();
1395 if (llvm::is_contained(
shape, 0))
1396 return emitError() <<
"arm.tensor do not support dimensions = 0";
1397 if (llvm::any_of(
shape, [](
int64_t dim) {
return dim < 0; }) &&
1398 llvm::any_of(
shape, [](
int64_t dim) {
return dim > 0; }))
1400 <<
"arm.tensor shape dimensions must be either fully dynamic or "
1409void SPIRVDialect::registerTypes() {
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
false
Parses a map_entries map type from a string format back into its numeric value.
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
TypeStorage()
This constructor is used by derived classes as part of the TypeUniquer.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static ConcreteType get(MLIRContext *ctx, Args &&...args)
LogicalResult mutate(Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
ArrayRef< int64_t > getShape() const
unsigned getNumRows() const
Returns the number of rows.
static NamedBarrierType get(MLIRContext *context)
Type getPointeeType() const
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
StorageUniquer::StorageAllocator TypeStorageAllocator
This is a utility allocator used to allocate memory for instances of derived Types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
ArrayTypeStorage(const KeyTy &key)
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::array< int64_t, 2 > shape
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
bool operator==(const KeyTy &key) const
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
MatrixTypeStorage(const KeyTy &key)
std::array< int64_t, 2 > shape
std::tuple< Type, int64_t > KeyTy
bool operator==(const KeyTy &key) const
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
RuntimeArrayTypeStorage(const KeyTy &key)
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::pair< Type, unsigned > KeyTy
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
SampledImageTypeStorage(const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
StructType::OffsetInfo const * offsetInfo
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
unsigned numStructDecorations
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
ArrayRef< Type > getMemberTypes() const
unsigned numMemberDecorations
bool isIdentified() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const
ArrayRef< int64_t > shape