MLIR 22.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#include <cstdint>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::spirv;
27
28namespace {
29// Helper function to collect extensions implied by a type by visiting all its
30// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
31//
32// Serves as the source-of-truth for type extension information. All extension
33// logic should be added to this class, while the
34// `SPIRVType::getExtensions` function should not handle extension-related logic
35// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
36class TypeExtensionVisitor {
37public:
38 TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
41
42 // Main visitor entry point. Adds all extensions to the vector. Saves `type`
43 // as seen and dispatches to the right concrete `.add` function.
44 void add(SPIRVType type) {
45 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
46 return;
47
49 .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
50 [this](auto concreteType) { addConcrete(concreteType); })
51 .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
52 [this](auto concreteType) { add(concreteType.getElementType()); })
53 .Case<SampledImageType>([this](SampledImageType concreteType) {
54 add(concreteType.getImageType());
55 })
56 .Case<StructType>([this](StructType concreteType) {
57 for (Type elementType : concreteType.getElementTypes())
58 add(elementType);
59 })
60 .DefaultUnreachable("Unhandled type");
61 }
62
63 void add(Type type) { add(cast<SPIRVType>(type)); }
64
65private:
66 // Types that add unique extensions.
67 void addConcrete(CooperativeMatrixType type);
68 void addConcrete(PointerType type);
69 void addConcrete(ScalarType type);
70 void addConcrete(TensorArmType type);
71
73 std::optional<StorageClass> storage;
74 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
75};
76
77// Helper function to collect capabilities implied by a type by visiting all its
78// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
79//
80// Serves as the source-of-truth for type capability information. All capability
81// logic should be added to this class, while the
82// `SPIRVType::getCapabilities` function should not handle capability-related
83// logic directly and only invoke `TypeCapabilityVisitor::add(Type *)`.
84class TypeCapabilityVisitor {
85public:
86 TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
87 std::optional<StorageClass> storage)
88 : capabilities(capabilities), storage(storage) {}
89
90 // Main visitor entry point. Adds all extensions to the vector. Saves `type`
91 // as seen and dispatches to the right concrete `.add` function.
92 void add(SPIRVType type) {
93 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
94 return;
95
97 .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
98 RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
99 [this](auto concreteType) { addConcrete(concreteType); })
100 .Case<ArrayType>([this](ArrayType concreteType) {
101 add(concreteType.getElementType());
102 })
103 .Case<SampledImageType>([this](SampledImageType concreteType) {
104 add(concreteType.getImageType());
105 })
106 .Case<StructType>([this](StructType concreteType) {
107 for (Type elementType : concreteType.getElementTypes())
108 add(elementType);
109 })
110 .DefaultUnreachable("Unhandled type");
111 }
112
113 void add(Type type) { add(cast<SPIRVType>(type)); }
114
115private:
116 // Types that add unique extensions.
117 void addConcrete(CooperativeMatrixType type);
118 void addConcrete(ImageType type);
119 void addConcrete(MatrixType type);
120 void addConcrete(PointerType type);
121 void addConcrete(RuntimeArrayType type);
122 void addConcrete(ScalarType type);
123 void addConcrete(TensorArmType type);
124 void addConcrete(VectorType type);
125
127 std::optional<StorageClass> storage;
128 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
129};
130
131} // namespace
132
133//===----------------------------------------------------------------------===//
134// ArrayType
135//===----------------------------------------------------------------------===//
136
138 using KeyTy = std::tuple<Type, unsigned, unsigned>;
139
141 const KeyTy &key) {
142 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
143 }
144
145 bool operator==(const KeyTy &key) const {
146 return key == KeyTy(elementType, elementCount, stride);
147 }
148
150 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
151 stride(std::get<2>(key)) {}
152
154 unsigned elementCount;
155 unsigned stride;
156};
157
158ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
159 assert(elementCount && "ArrayType needs at least one element");
160 return Base::get(elementType.getContext(), elementType, elementCount,
161 /*stride=*/0);
162}
163
164ArrayType ArrayType::get(Type elementType, unsigned elementCount,
165 unsigned stride) {
166 assert(elementCount && "ArrayType needs at least one element");
167 return Base::get(elementType.getContext(), elementType, elementCount, stride);
168}
169
170unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
171
172Type ArrayType::getElementType() const { return getImpl()->elementType; }
173
174unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
175
176//===----------------------------------------------------------------------===//
177// CompositeType
178//===----------------------------------------------------------------------===//
179
181 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
182 return isValid(vectorType);
186}
187
188bool CompositeType::isValid(VectorType type) {
189 return type.getRank() == 1 &&
190 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
191 llvm::isa<ScalarType>(type.getElementType());
192}
193
195 return TypeSwitch<Type, Type>(*this)
197 TensorArmType>([](auto type) { return type.getElementType(); })
198 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
199 .Case<StructType>(
200 [index](StructType type) { return type.getElementType(index); })
201 .DefaultUnreachable("Invalid composite type");
202}
203
206 .Case<ArrayType, StructType, TensorArmType, VectorType>(
207 [](auto type) { return type.getNumElements(); })
208 .Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); })
209 .DefaultUnreachable("Invalid type for number of elements query");
210}
211
213 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
214}
215
216void TypeCapabilityVisitor::addConcrete(VectorType type) {
217 add(type.getElementType());
218
219 int64_t vecSize = type.getNumElements();
220 if (vecSize == 8 || vecSize == 16) {
221 static constexpr auto cap = Capability::Vector16;
222 capabilities.push_back(cap);
223 }
224}
225
226//===----------------------------------------------------------------------===//
227// CooperativeMatrixType
228//===----------------------------------------------------------------------===//
229
231 // In the specification dimensions of the Cooperative Matrix are 32-bit
232 // integers --- the initial implementation kept those values as such. However,
233 // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
234 // as 32-bits and expose it as int64_t through `getShape`, however, this
235 // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
236 // 32-bits integers would require an extra logic and storage. So, we diverge
237 // from the spec and internally represent the dimensions as 64-bit integers,
238 // so we can easily return an `ArrayRef` from `getShape` without any extra
239 // logic. Alternatively, we could store both rows and columns (both 32-bits)
240 // and shape (64-bits), assigning rows and columns to shape whenever
241 // `getShape` is called. This would be at the cost of extra logic and storage.
242 // Note: Because `ArrayRef` is returned we cannot construct an object in
243 // `getShape` on the fly.
244 using KeyTy =
245 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
246
248 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
249 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
251 }
252
253 bool operator==(const KeyTy &key) const {
254 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
255 }
256
258 : elementType(std::get<0>(key)),
259 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
260 use(std::get<4>(key)) {}
261
263 // [#rows, #columns]
264 std::array<int64_t, 2> shape;
265 Scope scope;
266 CooperativeMatrixUseKHR use;
267};
268
270 uint32_t rows,
271 uint32_t columns, Scope scope,
272 CooperativeMatrixUseKHR use) {
273 return Base::get(elementType.getContext(), elementType, rows, columns, scope,
274 use);
275}
276
278 return getImpl()->elementType;
279}
280
282 assert(getImpl()->shape[0] != ShapedType::kDynamic);
283 return static_cast<uint32_t>(getImpl()->shape[0]);
284}
285
287 assert(getImpl()->shape[1] != ShapedType::kDynamic);
288 return static_cast<uint32_t>(getImpl()->shape[1]);
289}
290
294
295Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
296
297CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
298 return getImpl()->use;
299}
300
301void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
302 add(type.getElementType());
303 static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
304 extensions.push_back(ext);
305}
306
307void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
308 add(type.getElementType());
309 static constexpr auto caps = Capability::CooperativeMatrixKHR;
310 capabilities.push_back(caps);
311}
312
313//===----------------------------------------------------------------------===//
314// ImageType
315//===----------------------------------------------------------------------===//
316
317template <typename T>
318static constexpr unsigned getNumBits() {
319 return 0;
320}
321template <>
322constexpr unsigned getNumBits<Dim>() {
323 static_assert((1 << 3) > getMaxEnumValForDim(),
324 "Not enough bits to encode Dim value");
325 return 3;
326}
327template <>
328constexpr unsigned getNumBits<ImageDepthInfo>() {
329 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
330 "Not enough bits to encode ImageDepthInfo value");
331 return 2;
332}
333template <>
334constexpr unsigned getNumBits<ImageArrayedInfo>() {
335 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
336 "Not enough bits to encode ImageArrayedInfo value");
337 return 1;
338}
339template <>
340constexpr unsigned getNumBits<ImageSamplingInfo>() {
341 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
342 "Not enough bits to encode ImageSamplingInfo value");
343 return 1;
344}
345template <>
347 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
348 "Not enough bits to encode ImageSamplerUseInfo value");
349 return 2;
350}
351template <>
352constexpr unsigned getNumBits<ImageFormat>() {
353 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
354 "Not enough bits to encode ImageFormat value");
355 return 6;
356}
357
359public:
360 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
361 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
362
364 const KeyTy &key) {
365 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
366 }
367
368 bool operator==(const KeyTy &key) const {
371 }
372
374 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
375 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
376 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
377 format(std::get<6>(key)) {}
378
386};
387
389ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
390 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
391 value) {
392 return Base::get(std::get<0>(value).getContext(), value);
393}
394
395Type ImageType::getElementType() const { return getImpl()->elementType; }
396
397Dim ImageType::getDim() const { return getImpl()->dim; }
398
399ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
400
401ImageArrayedInfo ImageType::getArrayedInfo() const {
402 return getImpl()->arrayedInfo;
403}
404
405ImageSamplingInfo ImageType::getSamplingInfo() const {
406 return getImpl()->samplingInfo;
407}
408
409ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
410 return getImpl()->samplerUseInfo;
411}
412
413ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
414
415void TypeCapabilityVisitor::addConcrete(ImageType type) {
416 if (auto dimCaps = spirv::getCapabilities(type.getDim()))
417 capabilities.push_back(*dimCaps);
418
419 if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
420 capabilities.push_back(*fmtCaps);
421
422 add(type.getElementType());
423}
424
425//===----------------------------------------------------------------------===//
426// PointerType
427//===----------------------------------------------------------------------===//
428
430 // (Type, StorageClass) as the key: Type stored in this struct, and
431 // StorageClass stored as TypeStorage's subclass data.
432 using KeyTy = std::pair<Type, StorageClass>;
433
435 const KeyTy &key) {
436 return new (allocator.allocate<PointerTypeStorage>())
438 }
439
440 bool operator==(const KeyTy &key) const {
441 return key == KeyTy(pointeeType, storageClass);
442 }
443
445 : pointeeType(key.first), storageClass(key.second) {}
446
448 StorageClass storageClass;
449};
450
451PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
452 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
453}
454
455Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
456
457StorageClass PointerType::getStorageClass() const {
458 return getImpl()->storageClass;
459}
460
461void TypeExtensionVisitor::addConcrete(PointerType type) {
462 // Use this pointer type's storage class because this pointer indicates we are
463 // using the pointee type in that specific storage class.
464 std::optional<StorageClass> oldStorageClass = storage;
465 storage = type.getStorageClass();
466 add(type.getPointeeType());
467 storage = oldStorageClass;
468
469 if (auto scExts = spirv::getExtensions(type.getStorageClass()))
470 extensions.push_back(*scExts);
471}
472
473void TypeCapabilityVisitor::addConcrete(PointerType type) {
474 // Use this pointer type's storage class because this pointer indicates we are
475 // using the pointee type in that specific storage class.
476 std::optional<StorageClass> oldStorageClass = storage;
477 storage = type.getStorageClass();
478 add(type.getPointeeType());
479 storage = oldStorageClass;
480
481 if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
482 capabilities.push_back(*scCaps);
483}
484
485//===----------------------------------------------------------------------===//
486// RuntimeArrayType
487//===----------------------------------------------------------------------===//
488
490 using KeyTy = std::pair<Type, unsigned>;
491
493 const KeyTy &key) {
494 return new (allocator.allocate<RuntimeArrayTypeStorage>())
496 }
497
498 bool operator==(const KeyTy &key) const {
499 return key == KeyTy(elementType, stride);
500 }
501
503 : elementType(key.first), stride(key.second) {}
504
506 unsigned stride;
507};
508
510 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
511}
512
513RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
514 return Base::get(elementType.getContext(), elementType, stride);
515}
516
517Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
518
519unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
520
521void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
522 add(type.getElementType());
523 static constexpr auto cap = Capability::Shader;
524 capabilities.push_back(cap);
525}
526
527//===----------------------------------------------------------------------===//
528// ScalarType
529//===----------------------------------------------------------------------===//
530
532 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
533 return isValid(floatType);
534 }
535 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
536 return isValid(intType);
537 }
538 return false;
539}
540
541bool ScalarType::isValid(FloatType type) {
542 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
543}
544
545bool ScalarType::isValid(IntegerType type) {
546 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
547}
548
549void TypeExtensionVisitor::addConcrete(ScalarType type) {
550 if (isa<BFloat16Type>(type)) {
551 static constexpr auto ext = Extension::SPV_KHR_bfloat16;
552 extensions.push_back(ext);
553 }
554
555 // 8- or 16-bit integer/floating-point numbers will require extra extensions
556 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
557 // SPV_KHR_8bit_storage for more details.
558 if (!storage)
559 return;
560
561 switch (*storage) {
562 case StorageClass::PushConstant:
563 case StorageClass::StorageBuffer:
564 case StorageClass::Uniform:
565 if (type.getIntOrFloatBitWidth() == 8) {
566 static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
567 extensions.push_back(ext);
568 }
569 [[fallthrough]];
570 case StorageClass::Input:
571 case StorageClass::Output:
572 if (type.getIntOrFloatBitWidth() == 16) {
573 static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
574 extensions.push_back(ext);
575 }
576 break;
577 default:
578 break;
579 }
580}
581
582void TypeCapabilityVisitor::addConcrete(ScalarType type) {
583 unsigned bitwidth = type.getIntOrFloatBitWidth();
584
585 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
586 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
587 // SPV_KHR_8bit_storage for more details.
588
589#define STORAGE_CASE(storage, cap8, cap16) \
590 case StorageClass::storage: { \
591 if (bitwidth == 8) { \
592 static constexpr auto cap = Capability::cap8; \
593 capabilities.push_back(cap); \
594 return; \
595 } \
596 if (bitwidth == 16) { \
597 static constexpr auto cap = Capability::cap16; \
598 capabilities.push_back(cap); \
599 return; \
600 } \
601 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
602 /* storage classes. Fall through to the next section. */ \
603 } break
604
605 // This part only handles the cases where special bitwidths appearing in
606 // interface storage classes.
607 if (storage) {
608 switch (*storage) {
609 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
610 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
611 StorageBuffer16BitAccess);
612 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
613 StorageUniform16);
614 case StorageClass::Input:
615 case StorageClass::Output: {
616 if (bitwidth == 16) {
617 static constexpr auto cap = Capability::StorageInputOutput16;
618 capabilities.push_back(cap);
619 return;
620 }
621 break;
622 }
623 default:
624 break;
625 }
626 }
627#undef STORAGE_CASE
628
629 // For other non-interface storage classes, require a different set of
630 // capabilities for special bitwidths.
631
632#define WIDTH_CASE(type, width) \
633 case width: { \
634 static constexpr auto cap = Capability::type##width; \
635 capabilities.push_back(cap); \
636 } break
637
638 if (auto intType = dyn_cast<IntegerType>(type)) {
639 switch (bitwidth) {
640 WIDTH_CASE(Int, 8);
641 WIDTH_CASE(Int, 16);
642 WIDTH_CASE(Int, 64);
643 case 1:
644 case 32:
645 break;
646 default:
647 llvm_unreachable("invalid bitwidth to getCapabilities");
648 }
649 } else {
650 assert(isa<FloatType>(type));
651 switch (bitwidth) {
652 case 16: {
653 if (isa<BFloat16Type>(type)) {
654 static constexpr auto cap = Capability::BFloat16TypeKHR;
655 capabilities.push_back(cap);
656 } else {
657 static constexpr auto cap = Capability::Float16;
658 capabilities.push_back(cap);
659 }
660 break;
661 }
662 WIDTH_CASE(Float, 64);
663 case 32:
664 break;
665 default:
666 llvm_unreachable("invalid bitwidth to getCapabilities");
667 }
668 }
669
670#undef WIDTH_CASE
671}
672
673//===----------------------------------------------------------------------===//
674// SPIRVType
675//===----------------------------------------------------------------------===//
676
678 // Allow SPIR-V dialect types
679 if (llvm::isa<SPIRVDialect>(type.getDialect()))
680 return true;
681 if (llvm::isa<ScalarType>(type))
682 return true;
683 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
684 return CompositeType::isValid(vectorType);
685 if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
686 return llvm::isa<ScalarType>(tensorArmType.getElementType());
687 return false;
688}
689
691 return isIntOrFloat() || llvm::isa<VectorType>(*this);
692}
693
695 std::optional<StorageClass> storage) {
696 TypeExtensionVisitor{extensions, storage}.add(*this);
697}
698
701 std::optional<StorageClass> storage) {
702 TypeCapabilityVisitor{capabilities, storage}.add(*this);
703}
704
705std::optional<int64_t> SPIRVType::getSizeInBytes() {
707 .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
708 // According to the SPIR-V spec:
709 // "There is no physical size or bit pattern defined for values with
710 // boolean type. If they are stored (in conjunction with OpVariable),
711 // they can only be used with logical addressing operations, not
712 // physical, and only with non-externally visible shader Storage
713 // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
714 // Output."
715 int64_t bitWidth = type.getIntOrFloatBitWidth();
716 if (bitWidth == 1)
717 return std::nullopt;
718 return bitWidth / 8;
719 })
720 .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
721 // Since array type may have an explicit stride declaration (in bytes),
722 // we also include it in the calculation.
723 auto elementType = cast<SPIRVType>(type.getElementType());
724 if (std::optional<int64_t> size = elementType.getSizeInBytes())
725 return (*size + type.getArrayStride()) * type.getNumElements();
726 return std::nullopt;
727 })
728 .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
729 if (std::optional<int64_t> elementSize =
730 cast<ScalarType>(type.getElementType()).getSizeInBytes())
731 return *elementSize * type.getNumElements();
732 return std::nullopt;
733 })
734 .Default(std::nullopt);
735}
736
737//===----------------------------------------------------------------------===//
738// SampledImageType
739//===----------------------------------------------------------------------===//
741 using KeyTy = Type;
742
744
745 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
746
748 const KeyTy &key) {
749 return new (allocator.allocate<SampledImageTypeStorage>())
751 }
752
754};
755
757 return Base::get(imageType.getContext(), imageType);
758}
759
765
766Type SampledImageType::getImageType() const { return getImpl()->imageType; }
767
768LogicalResult
770 Type imageType) {
771 auto image = dyn_cast<ImageType>(imageType);
772 if (!image)
773 return emitError() << "expected image type";
774
775 // As per SPIR-V spec: "It [ImageType] must not have a Dim of SubpassData.
776 // Additionally, starting with version 1.6, it must not have a Dim of Buffer.
777 // ("3.3.6. Type-Declaration Instructions")
778 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
779 return emitError() << "Dim must not be SubpassData or Buffer";
780
781 return success();
782}
783
784//===----------------------------------------------------------------------===//
785// StructType
786//===----------------------------------------------------------------------===//
787
788/// Type storage for SPIR-V structure types:
789///
790/// Structures are uniqued using:
791/// - for identified structs:
792/// - a string identifier;
793/// - for literal structs:
794/// - a list of member types;
795/// - a list of member offset info;
796/// - a list of member decoration info;
797/// - a list of struct decoration info.
798///
799/// Identified structures only have a mutable component consisting of:
800/// - a list of member types;
801/// - a list of member offset info;
802/// - a list of member decoration info;
803/// - a list of struct decoration info.
805 /// Construct a storage object for an identified struct type. A struct type
806 /// associated with such storage must call StructType::trySetBody(...) later
807 /// in order to mutate the storage object providing the actual content.
813
814 /// Construct a storage object for a literal struct type. A struct type
815 /// associated with such storage is immutable.
827
828 /// A storage key is divided into 2 parts:
829 /// - for identified structs:
830 /// - a StringRef representing the struct identifier;
831 /// - for literal structs:
832 /// - an ArrayRef<Type> for member types;
833 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
834 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
835 /// info;
836 /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
837 /// info.
838 ///
839 /// An identified struct type is uniqued only by the first part (field 0)
840 /// of the key.
841 ///
842 /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
843 /// and 4) of the key. The identifier field (field 0) must be empty.
844 using KeyTy =
845 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
848
849 /// For identified structs, return true if the given key contains the same
850 /// identifier.
851 ///
852 /// For literal structs, return true if the given key contains a matching list
853 /// of member types + offset info + decoration info.
854 bool operator==(const KeyTy &key) const {
855 if (isIdentified()) {
856 // Identified types are uniqued by their identifier.
857 return getIdentifier() == std::get<0>(key);
858 }
859
860 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
862 }
863
864 /// If the given key contains a non-empty identifier, this method constructs
865 /// an identified struct and leaves the rest of the struct type data to be set
866 /// through a later call to StructType::trySetBody(...).
867 ///
868 /// If, on the other hand, the key contains an empty identifier, a literal
869 /// struct is constructed using the other fields of the key.
871 const KeyTy &key) {
872 StringRef keyIdentifier = std::get<0>(key);
873
874 if (!keyIdentifier.empty()) {
875 StringRef identifier = allocator.copyInto(keyIdentifier);
876
877 // Identified StructType body/members will be set through trySetBody(...)
878 // later.
879 return new (allocator.allocate<StructTypeStorage>())
881 }
882
883 ArrayRef<Type> keyTypes = std::get<1>(key);
884
885 // Copy the member type and layout information into the bump pointer
886 const Type *typesList = nullptr;
887 if (!keyTypes.empty()) {
888 typesList = allocator.copyInto(keyTypes).data();
889 }
890
891 const StructType::OffsetInfo *offsetInfoList = nullptr;
892 if (!std::get<2>(key).empty()) {
893 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
894 assert(keyOffsetInfo.size() == keyTypes.size() &&
895 "size of offset information must be same as the size of number of "
896 "elements");
897 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
898 }
899
900 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
901 unsigned numMemberDecorations = 0;
902 if (!std::get<3>(key).empty()) {
903 auto keyMemberDecorations = std::get<3>(key);
904 numMemberDecorations = keyMemberDecorations.size();
905 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
906 }
907
908 const StructType::StructDecorationInfo *structDecorationList = nullptr;
909 unsigned numStructDecorations = 0;
910 if (!std::get<4>(key).empty()) {
911 auto keyStructDecorations = std::get<4>(key);
912 numStructDecorations = keyStructDecorations.size();
913 structDecorationList = allocator.copyInto(keyStructDecorations).data();
914 }
915
916 return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
917 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
918 memberDecorationList, numStructDecorations, structDecorationList);
919 }
920
924
931
939
946
947 StringRef getIdentifier() const { return identifier; }
948
949 bool isIdentified() const { return !identifier.empty(); }
950
951 /// Sets the struct type content for identified structs. Calling this method
952 /// is only valid for identified structs.
953 ///
954 /// Fails under the following conditions:
955 /// - If called for a literal struct;
956 /// - If called for an identified struct whose body was set before (through a
957 /// call to this method) but with different contents from the passed
958 /// arguments.
959 LogicalResult
960 mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
961 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
962 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
963 ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
964 if (!isIdentified())
965 return failure();
966
967 if (memberTypesAndIsBodySet.getInt() &&
968 (getMemberTypes() != structMemberTypes ||
969 getOffsetInfo() != structOffsetInfo ||
970 getMemberDecorationsInfo() != structMemberDecorationInfo ||
971 getStructDecorationsInfo() != structDecorationInfo))
972 return failure();
973
974 memberTypesAndIsBodySet.setInt(true);
975 numMembers = structMemberTypes.size();
976
977 // Copy the member type and layout information into the bump pointer.
978 if (!structMemberTypes.empty())
979 memberTypesAndIsBodySet.setPointer(
980 allocator.copyInto(structMemberTypes).data());
981
982 if (!structOffsetInfo.empty()) {
983 assert(structOffsetInfo.size() == structMemberTypes.size() &&
984 "size of offset information must be same as the size of number of "
985 "elements");
986 offsetInfo = allocator.copyInto(structOffsetInfo).data();
987 }
988
989 if (!structMemberDecorationInfo.empty()) {
990 numMemberDecorations = structMemberDecorationInfo.size();
992 allocator.copyInto(structMemberDecorationInfo).data();
993 }
994
995 if (!structDecorationInfo.empty()) {
996 numStructDecorations = structDecorationInfo.size();
997 structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
998 }
999
1000 return success();
1001 }
1002
1003 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1005 unsigned numMembers;
1010 StringRef identifier;
1011};
1012
1018 assert(!memberTypes.empty() && "Struct needs at least one member type");
1019 // Sort the decorations.
1021 memberDecorations);
1022 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1023 sortedMemberDecorations.end());
1025 structDecorations);
1026 llvm::array_pod_sort(sortedStructDecorations.begin(),
1027 sortedStructDecorations.end());
1028
1029 return Base::get(memberTypes.vec().front().getContext(),
1030 /*identifier=*/StringRef(), memberTypes, offsetInfo,
1031 sortedMemberDecorations, sortedStructDecorations);
1032}
1033
1035 StringRef identifier) {
1036 assert(!identifier.empty() &&
1037 "StructType identifier must be non-empty string");
1038
1039 return Base::get(context, identifier, ArrayRef<Type>(),
1043}
1044
1045StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1046 StructType newStructType = Base::get(
1047 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1050 // Set an empty body in case this is a identified struct.
1051 if (newStructType.isIdentified() &&
1052 failed(newStructType.trySetBody(
1056 return StructType();
1057
1058 return newStructType;
1059}
1060
1061StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1062
1063bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1064
1065unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1066
1068 assert(getNumElements() > index && "member index out of range");
1069 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1070}
1071
1073 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1074 getNumElements());
1075}
1076
1077bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1078
1079bool StructType::hasDecoration(spirv::Decoration decoration) const {
1081 getImpl()->getStructDecorationsInfo())
1082 if (info.decoration == decoration)
1083 return true;
1084
1085 return false;
1086}
1087
1088uint64_t StructType::getMemberOffset(unsigned index) const {
1089 assert(getNumElements() > index && "member index out of range");
1090 return getImpl()->offsetInfo[index];
1091}
1092
1095 const {
1096 memberDecorations.clear();
1097 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1098 memberDecorations.append(implMemberDecorations.begin(),
1099 implMemberDecorations.end());
1100}
1101
1103 unsigned index,
1105 assert(getNumElements() > index && "member index out of range");
1106 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1107 decorationsInfo.clear();
1108 for (const auto &memberDecoration : memberDecorations) {
1109 if (memberDecoration.memberIndex == index) {
1110 decorationsInfo.push_back(memberDecoration);
1111 }
1112 if (memberDecoration.memberIndex > index) {
1113 // Early exit since the decorations are stored sorted.
1114 return;
1115 }
1116 }
1117}
1118
1121 const {
1122 structDecorations.clear();
1123 auto implDecorations = getImpl()->getStructDecorationsInfo();
1124 structDecorations.append(implDecorations.begin(), implDecorations.end());
1125}
1126
1127LogicalResult
1129 ArrayRef<OffsetInfo> offsetInfo,
1130 ArrayRef<MemberDecorationInfo> memberDecorations,
1131 ArrayRef<StructDecorationInfo> structDecorations) {
1132 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1133 structDecorations);
1134}
1135
1136llvm::hash_code spirv::hash_value(
1137 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1138 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1139 memberDecorationInfo.decoration);
1140}
1141
1142llvm::hash_code spirv::hash_value(
1143 const StructType::StructDecorationInfo &structDecorationInfo) {
1144 return llvm::hash_value(structDecorationInfo.decoration);
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// MatrixType
1149//===----------------------------------------------------------------------===//
1150
1154
1155 using KeyTy = std::tuple<Type, uint32_t>;
1156
1158 const KeyTy &key) {
1159
1160 // Initialize the memory using placement new.
1161 return new (allocator.allocate<MatrixTypeStorage>())
1162 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1163 }
1164
1165 bool operator==(const KeyTy &key) const {
1166 return key == KeyTy(columnType, columnCount);
1167 }
1168
1170 const uint32_t columnCount;
1171};
1172
1173MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1174 return Base::get(columnType.getContext(), columnType, columnCount);
1175}
1176
1178 Type columnType, uint32_t columnCount) {
1179 return Base::getChecked(emitError, columnType.getContext(), columnType,
1180 columnCount);
1181}
1182
1183LogicalResult
1185 Type columnType, uint32_t columnCount) {
1186 if (columnCount < 2 || columnCount > 4)
1187 return emitError() << "matrix can have 2, 3, or 4 columns only";
1188
1189 if (!isValidColumnType(columnType))
1190 return emitError() << "matrix columns must be vectors of floats";
1191
1192 /// The underlying vectors (columns) must be of size 2, 3, or 4
1193 ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1194 if (columnShape.size() != 1)
1195 return emitError() << "matrix columns must be 1D vectors";
1196
1197 if (columnShape[0] < 2 || columnShape[0] > 4)
1198 return emitError() << "matrix columns must be of size 2, 3, or 4";
1199
1200 return success();
1201}
1202
1203/// Returns true if the matrix elements are vectors of float elements
1205 if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1206 if (llvm::isa<FloatType>(vectorType.getElementType()))
1207 return true;
1208 }
1209 return false;
1210}
1211
1212Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1213
1215 return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1216}
1217
1218unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1219
1220unsigned MatrixType::getNumRows() const {
1221 return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1222}
1223
1225 return (getImpl()->columnCount) * getNumRows();
1226}
1227
1228void TypeCapabilityVisitor::addConcrete(MatrixType type) {
1229 add(type.getColumnType());
1230 static constexpr auto cap = Capability::Matrix;
1231 capabilities.push_back(cap);
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// TensorArmType
1236//===----------------------------------------------------------------------===//
1237
1239 using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1240
1242 const KeyTy &key) {
1243 auto [shape, elementType] = key;
1244 shape = allocator.copyInto(shape);
1245 return new (allocator.allocate<TensorArmTypeStorage>())
1247 }
1248
1249 static llvm::hash_code hashKey(const KeyTy &key) {
1250 auto [shape, elementType] = key;
1251 return llvm::hash_combine(shape, elementType);
1252 }
1253
1254 bool operator==(const KeyTy &key) const {
1255 return key == KeyTy(shape, elementType);
1256 }
1257
1260
1263};
1264
1266 return Base::get(elementType.getContext(), shape, elementType);
1267}
1268
1270 Type elementType) const {
1271 return TensorArmType::get(shape.value_or(getShape()), elementType);
1272}
1273
1274Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1276
1277void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1278 add(type.getElementType());
1279 static constexpr auto ext = Extension::SPV_ARM_tensors;
1280 extensions.push_back(ext);
1281}
1282
1283void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
1284 add(type.getElementType());
1285 static constexpr auto cap = Capability::TensorsARM;
1286 capabilities.push_back(cap);
1287}
1288
1289LogicalResult
1291 ArrayRef<int64_t> shape, Type elementType) {
1292 if (llvm::is_contained(shape, 0))
1293 return emitError() << "arm.tensor do not support dimensions = 0";
1294 if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1295 llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1296 return emitError()
1297 << "arm.tensor shape dimensions must be either fully dynamic or "
1298 "completed shaped";
1299 return success();
1300}
1301
1302//===----------------------------------------------------------------------===//
1303// SPIR-V Dialect
1304//===----------------------------------------------------------------------===//
1305
1306void SPIRVDialect::registerTypes() {
1309}
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
false
Parses a map_entries map type from a string format back into its numeric value.
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
#define add(a, b)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
TypeStorage()
This constructor is used by derived classes as part of the TypeUniquer.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
unsigned getNumRows() const
Returns the number of rows.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
constexpr Type()=default
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:65
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType get(Type imageType)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given integer type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:465
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
StorageUniquer::StorageAllocator TypeStorageAllocator
This is a utility allocator used to allocate memory for instances of derived Types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
bool operator==(const KeyTy &key) const
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
bool operator==(const KeyTy &key) const
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, uint32_t > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
StructType::OffsetInfo const * offsetInfo
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
ArrayRef< Type > getMemberTypes() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const