MLIR 23.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#include <cstdint>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::spirv;
27
28namespace {
// Helper class that collects the extensions implied by a type by visiting all
// of its subtypes. It keeps a set of already-visited (type, storage class)
// pairs in `seen`, so recursive structs do not cause infinite recursion.
//
// This class is the source of truth for type extension information. All
// extension logic belongs here. `SPIRVType::getExtensions` should not handle
// extension-related logic directly; it should only invoke
// `TypeExtensionVisitor::add(SPIRVType)`.
36class TypeExtensionVisitor {
37public:
38 TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
41
42 // Main visitor entry point. Adds all extensions to the vector. Saves `type`
43 // as seen and dispatches to the right concrete `.add` function.
44 void add(SPIRVType type) {
45 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
46 return;
47
49 .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
50 [this](auto concreteType) { addConcrete(concreteType); })
51 .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
52 [this](auto concreteType) { add(concreteType.getElementType()); })
53 .Case([this](SampledImageType concreteType) {
54 add(concreteType.getImageType());
55 })
56 .Case([this](StructType concreteType) {
57 for (Type elementType : concreteType.getElementTypes())
58 add(elementType);
59 })
60 .DefaultUnreachable("Unhandled type");
61 }
62
63 void add(Type type) { add(cast<SPIRVType>(type)); }
64
65private:
66 // Types that add unique extensions.
67 void addConcrete(CooperativeMatrixType type);
68 void addConcrete(PointerType type);
69 void addConcrete(ScalarType type);
70 void addConcrete(TensorArmType type);
71
73 std::optional<StorageClass> storage;
74 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
75};
76
// Helper class that collects the capabilities implied by a type by visiting
// all of its subtypes. It keeps a set of already-visited (type, storage class)
// pairs in `seen`, so recursive structs do not cause infinite recursion.
//
// This class is the source of truth for type capability information. All
// capability logic belongs here. `SPIRVType::getCapabilities` should not
// handle capability-related logic directly; it should only invoke
// `TypeCapabilityVisitor::add(SPIRVType)`.
84class TypeCapabilityVisitor {
85public:
86 TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
87 std::optional<StorageClass> storage)
88 : capabilities(capabilities), storage(storage) {}
89
90 // Main visitor entry point. Adds all extensions to the vector. Saves `type`
91 // as seen and dispatches to the right concrete `.add` function.
92 void add(SPIRVType type) {
93 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
94 return;
95
97 .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
98 RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
99 [this](auto concreteType) { addConcrete(concreteType); })
100 .Case([this](ArrayType concreteType) {
101 add(concreteType.getElementType());
102 })
103 .Case([this](SampledImageType concreteType) {
104 add(concreteType.getImageType());
105 })
106 .Case([this](StructType concreteType) {
107 for (Type elementType : concreteType.getElementTypes())
108 add(elementType);
109 })
110 .DefaultUnreachable("Unhandled type");
111 }
112
113 void add(Type type) { add(cast<SPIRVType>(type)); }
114
115private:
116 // Types that add unique extensions.
117 void addConcrete(CooperativeMatrixType type);
118 void addConcrete(ImageType type);
119 void addConcrete(MatrixType type);
120 void addConcrete(PointerType type);
121 void addConcrete(RuntimeArrayType type);
122 void addConcrete(ScalarType type);
123 void addConcrete(TensorArmType type);
124 void addConcrete(VectorType type);
125
127 std::optional<StorageClass> storage;
128 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
129};
130
131} // namespace
132
133//===----------------------------------------------------------------------===//
134// ArrayType
135//===----------------------------------------------------------------------===//
136
138 using KeyTy = std::tuple<Type, unsigned, unsigned>;
139
141 const KeyTy &key) {
142 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
143 }
144
145 bool operator==(const KeyTy &key) const {
146 return key == KeyTy(elementType, elementCount, stride);
147 }
148
150 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
151 stride(std::get<2>(key)) {}
152
154 unsigned elementCount;
155 unsigned stride;
156};
157
158ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
159 assert(elementCount && "ArrayType needs at least one element");
160 return Base::get(elementType.getContext(), elementType, elementCount,
161 /*stride=*/0);
162}
163
164ArrayType ArrayType::get(Type elementType, unsigned elementCount,
165 unsigned stride) {
166 assert(elementCount && "ArrayType needs at least one element");
167 return Base::get(elementType.getContext(), elementType, elementCount, stride);
168}
169
170unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
171
172Type ArrayType::getElementType() const { return getImpl()->elementType; }
173
174unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
175
176//===----------------------------------------------------------------------===//
177// CompositeType
178//===----------------------------------------------------------------------===//
179
181 if (auto vectorType = dyn_cast<VectorType>(type))
182 return isValid(vectorType);
185 type);
186}
187
188bool CompositeType::isValid(VectorType type) {
189 return type.getRank() == 1 &&
190 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
191 isa<ScalarType>(type.getElementType());
192}
193
195 return TypeSwitch<Type, Type>(*this)
197 TensorArmType>([](auto type) { return type.getElementType(); })
198 .Case([](MatrixType type) { return type.getColumnType(); })
199 .Case([index](StructType type) { return type.getElementType(index); })
200 .DefaultUnreachable("Invalid composite type");
201}
202
205 .Case<ArrayType, StructType, TensorArmType, VectorType>(
206 [](auto type) { return type.getNumElements(); })
207 .Case([](MatrixType type) { return type.getNumColumns(); })
208 .DefaultUnreachable("Invalid type for number of elements query");
209}
210
212 return !isa<CooperativeMatrixType, RuntimeArrayType>(*this);
213}
214
215void TypeCapabilityVisitor::addConcrete(VectorType type) {
216 add(type.getElementType());
217
218 int64_t vecSize = type.getNumElements();
219 if (vecSize == 8 || vecSize == 16) {
220 static constexpr auto cap = Capability::Vector16;
221 capabilities.push_back(cap);
222 }
223}
224
225//===----------------------------------------------------------------------===//
226// CooperativeMatrixType
227//===----------------------------------------------------------------------===//
228
230 // In the specification dimensions of the Cooperative Matrix are 32-bit
231 // integers --- the initial implementation kept those values as such. However,
232 // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
233 // as 32-bits and expose it as int64_t through `getShape`, however, this
234 // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
235 // 32-bits integers would require an extra logic and storage. So, we diverge
236 // from the spec and internally represent the dimensions as 64-bit integers,
237 // so we can easily return an `ArrayRef` from `getShape` without any extra
238 // logic. Alternatively, we could store both rows and columns (both 32-bits)
239 // and shape (64-bits), assigning rows and columns to shape whenever
240 // `getShape` is called. This would be at the cost of extra logic and storage.
241 // Note: Because `ArrayRef` is returned we cannot construct an object in
242 // `getShape` on the fly.
243 using KeyTy =
244 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
245
247 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
248 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
250 }
251
252 bool operator==(const KeyTy &key) const {
253 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
254 }
255
257 : elementType(std::get<0>(key)),
258 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
259 use(std::get<4>(key)) {}
260
262 // [#rows, #columns]
263 std::array<int64_t, 2> shape;
264 Scope scope;
265 CooperativeMatrixUseKHR use;
266};
267
269 uint32_t rows,
270 uint32_t columns, Scope scope,
271 CooperativeMatrixUseKHR use) {
272 return Base::get(elementType.getContext(), elementType, rows, columns, scope,
273 use);
274}
275
277 return getImpl()->elementType;
278}
279
281 assert(getImpl()->shape[0] != ShapedType::kDynamic);
282 return static_cast<uint32_t>(getImpl()->shape[0]);
283}
284
286 assert(getImpl()->shape[1] != ShapedType::kDynamic);
287 return static_cast<uint32_t>(getImpl()->shape[1]);
288}
289
293
294Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
295
296CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
297 return getImpl()->use;
298}
299
300void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
301 add(type.getElementType());
302 static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
303 extensions.push_back(ext);
304}
305
306void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
307 add(type.getElementType());
308 static constexpr auto caps = Capability::CooperativeMatrixKHR;
309 capabilities.push_back(caps);
310}
311
312//===----------------------------------------------------------------------===//
313// ImageType
314//===----------------------------------------------------------------------===//
315
316template <typename T>
317static constexpr unsigned getNumBits() {
318 return 0;
319}
320template <>
321constexpr unsigned getNumBits<Dim>() {
322 static_assert((1 << 3) > getMaxEnumValForDim(),
323 "Not enough bits to encode Dim value");
324 return 3;
325}
326template <>
327constexpr unsigned getNumBits<ImageDepthInfo>() {
328 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
329 "Not enough bits to encode ImageDepthInfo value");
330 return 2;
331}
332template <>
333constexpr unsigned getNumBits<ImageArrayedInfo>() {
334 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
335 "Not enough bits to encode ImageArrayedInfo value");
336 return 1;
337}
338template <>
339constexpr unsigned getNumBits<ImageSamplingInfo>() {
340 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
341 "Not enough bits to encode ImageSamplingInfo value");
342 return 1;
343}
344template <>
346 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
347 "Not enough bits to encode ImageSamplerUseInfo value");
348 return 2;
349}
350template <>
351constexpr unsigned getNumBits<ImageFormat>() {
352 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
353 "Not enough bits to encode ImageFormat value");
354 return 6;
355}
356
358public:
359 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
360 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
361
363 const KeyTy &key) {
364 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
365 }
366
367 bool operator==(const KeyTy &key) const {
370 }
371
373 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
374 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
375 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
376 format(std::get<6>(key)) {}
377
385};
386
388ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
389 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
390 value) {
391 return Base::get(std::get<0>(value).getContext(), value);
392}
393
394Type ImageType::getElementType() const { return getImpl()->elementType; }
395
396Dim ImageType::getDim() const { return getImpl()->dim; }
397
398ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
399
400ImageArrayedInfo ImageType::getArrayedInfo() const {
401 return getImpl()->arrayedInfo;
402}
403
404ImageSamplingInfo ImageType::getSamplingInfo() const {
405 return getImpl()->samplingInfo;
406}
407
408ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
409 return getImpl()->samplerUseInfo;
410}
411
412ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
413
414void TypeCapabilityVisitor::addConcrete(ImageType type) {
415 if (auto dimCaps = spirv::getCapabilities(type.getDim()))
416 capabilities.push_back(*dimCaps);
417
418 if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
419 capabilities.push_back(*fmtCaps);
420
421 add(type.getElementType());
422}
423
424//===----------------------------------------------------------------------===//
425// PointerType
426//===----------------------------------------------------------------------===//
427
429 // (Type, StorageClass) as the key: Type stored in this struct, and
430 // StorageClass stored as TypeStorage's subclass data.
431 using KeyTy = std::pair<Type, StorageClass>;
432
434 const KeyTy &key) {
435 return new (allocator.allocate<PointerTypeStorage>())
437 }
438
439 bool operator==(const KeyTy &key) const {
440 return key == KeyTy(pointeeType, storageClass);
441 }
442
444 : pointeeType(key.first), storageClass(key.second) {}
445
447 StorageClass storageClass;
448};
449
450PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
451 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
452}
453
454Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
455
456StorageClass PointerType::getStorageClass() const {
457 return getImpl()->storageClass;
458}
459
460void TypeExtensionVisitor::addConcrete(PointerType type) {
461 // Use this pointer type's storage class because this pointer indicates we are
462 // using the pointee type in that specific storage class.
463 std::optional<StorageClass> oldStorageClass = storage;
464 storage = type.getStorageClass();
465 add(type.getPointeeType());
466 storage = oldStorageClass;
467
468 if (auto scExts = spirv::getExtensions(type.getStorageClass()))
469 extensions.push_back(*scExts);
470}
471
472void TypeCapabilityVisitor::addConcrete(PointerType type) {
473 // Use this pointer type's storage class because this pointer indicates we are
474 // using the pointee type in that specific storage class.
475 std::optional<StorageClass> oldStorageClass = storage;
476 storage = type.getStorageClass();
477 add(type.getPointeeType());
478 storage = oldStorageClass;
479
480 if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
481 capabilities.push_back(*scCaps);
482}
483
484//===----------------------------------------------------------------------===//
485// RuntimeArrayType
486//===----------------------------------------------------------------------===//
487
489 using KeyTy = std::pair<Type, unsigned>;
490
492 const KeyTy &key) {
493 return new (allocator.allocate<RuntimeArrayTypeStorage>())
495 }
496
497 bool operator==(const KeyTy &key) const {
498 return key == KeyTy(elementType, stride);
499 }
500
502 : elementType(key.first), stride(key.second) {}
503
505 unsigned stride;
506};
507
509 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
510}
511
512RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
513 return Base::get(elementType.getContext(), elementType, stride);
514}
515
516Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
517
518unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
519
520void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
521 add(type.getElementType());
522 static constexpr auto cap = Capability::Shader;
523 capabilities.push_back(cap);
524}
525
526//===----------------------------------------------------------------------===//
527// ScalarType
528//===----------------------------------------------------------------------===//
529
531 if (auto floatType = dyn_cast<FloatType>(type)) {
532 return isValid(floatType);
533 }
534 if (auto intType = dyn_cast<IntegerType>(type)) {
535 return isValid(intType);
536 }
537 return false;
538}
539
540bool ScalarType::isValid(FloatType type) {
541 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
542}
543
544bool ScalarType::isValid(IntegerType type) {
545 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
546}
547
548void TypeExtensionVisitor::addConcrete(ScalarType type) {
549 if (isa<BFloat16Type>(type)) {
550 static constexpr auto ext = Extension::SPV_KHR_bfloat16;
551 extensions.push_back(ext);
552 }
553
554 if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
555 static constexpr auto ext = Extension::SPV_EXT_float8;
556 extensions.push_back(ext);
557 }
558
559 // 8- or 16-bit integer/floating-point numbers will require extra extensions
560 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
561 // SPV_KHR_8bit_storage for more details.
562 if (!storage)
563 return;
564
565 switch (*storage) {
566 case StorageClass::PushConstant:
567 case StorageClass::StorageBuffer:
568 case StorageClass::Uniform:
569 if (type.getIntOrFloatBitWidth() == 8) {
570 static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
571 extensions.push_back(ext);
572 }
573 [[fallthrough]];
574 case StorageClass::Input:
575 case StorageClass::Output:
576 if (type.getIntOrFloatBitWidth() == 16) {
577 static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
578 extensions.push_back(ext);
579 }
580 break;
581 default:
582 break;
583 }
584}
585
586void TypeCapabilityVisitor::addConcrete(ScalarType type) {
587 unsigned bitwidth = type.getIntOrFloatBitWidth();
588
589 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
590 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
591 // SPV_KHR_8bit_storage for more details.
592
593#define STORAGE_CASE(storage, cap8, cap16) \
594 case StorageClass::storage: { \
595 if (bitwidth == 8) { \
596 static constexpr auto cap = Capability::cap8; \
597 capabilities.push_back(cap); \
598 return; \
599 } \
600 if (bitwidth == 16) { \
601 static constexpr auto cap = Capability::cap16; \
602 capabilities.push_back(cap); \
603 return; \
604 } \
605 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
606 /* storage classes. Fall through to the next section. */ \
607 } break
608
609 // This part only handles the cases where special bitwidths appearing in
610 // interface storage classes.
611 if (storage) {
612 switch (*storage) {
613 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
614 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
615 StorageBuffer16BitAccess);
616 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
617 StorageUniform16);
618 case StorageClass::Input:
619 case StorageClass::Output: {
620 if (bitwidth == 16) {
621 static constexpr auto cap = Capability::StorageInputOutput16;
622 capabilities.push_back(cap);
623 return;
624 }
625 break;
626 }
627 default:
628 break;
629 }
630 }
631#undef STORAGE_CASE
632
633 // For other non-interface storage classes, require a different set of
634 // capabilities for special bitwidths.
635
636#define WIDTH_CASE(type, width) \
637 case width: { \
638 static constexpr auto cap = Capability::type##width; \
639 capabilities.push_back(cap); \
640 } break
641
642 if (auto intType = dyn_cast<IntegerType>(type)) {
643 switch (bitwidth) {
644 WIDTH_CASE(Int, 8);
645 WIDTH_CASE(Int, 16);
646 WIDTH_CASE(Int, 64);
647 case 1:
648 case 32:
649 break;
650 default:
651 llvm_unreachable("invalid bitwidth to getCapabilities");
652 }
653 } else {
654 assert(isa<FloatType>(type));
655 switch (bitwidth) {
656 case 8: {
657 if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
658 static constexpr auto cap = Capability::Float8EXT;
659 capabilities.push_back(cap);
660 } else {
661 llvm_unreachable("invalid 8-bit float type to getCapabilities");
662 }
663 break;
664 }
665 case 16: {
666 if (isa<BFloat16Type>(type)) {
667 static constexpr auto cap = Capability::BFloat16TypeKHR;
668 capabilities.push_back(cap);
669 } else {
670 static constexpr auto cap = Capability::Float16;
671 capabilities.push_back(cap);
672 }
673 break;
674 }
675 WIDTH_CASE(Float, 64);
676 case 32:
677 break;
678 default:
679 llvm_unreachable("invalid bitwidth to getCapabilities");
680 }
681 }
682
683#undef WIDTH_CASE
684}
685
686//===----------------------------------------------------------------------===//
687// SPIRVType
688//===----------------------------------------------------------------------===//
689
691 // Allow SPIR-V dialect types
692 if (isa<SPIRVDialect>(type.getDialect()))
693 return true;
694 if (isa<ScalarType>(type))
695 return true;
696 if (auto vectorType = dyn_cast<VectorType>(type))
697 return CompositeType::isValid(vectorType);
698 if (auto tensorArmType = dyn_cast<TensorArmType>(type))
699 return isa<ScalarType>(tensorArmType.getElementType());
700 return false;
701}
702
704 return isIntOrFloat() || isa<VectorType>(*this);
705}
706
708 std::optional<StorageClass> storage) {
709 TypeExtensionVisitor{extensions, storage}.add(*this);
710}
711
714 std::optional<StorageClass> storage) {
715 TypeCapabilityVisitor{capabilities, storage}.add(*this);
716}
717
718std::optional<int64_t> SPIRVType::getSizeInBytes() {
720 .Case([](ScalarType type) -> std::optional<int64_t> {
721 // According to the SPIR-V spec:
722 // "There is no physical size or bit pattern defined for values with
723 // boolean type. If they are stored (in conjunction with OpVariable),
724 // they can only be used with logical addressing operations, not
725 // physical, and only with non-externally visible shader Storage
726 // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
727 // Output."
728 int64_t bitWidth = type.getIntOrFloatBitWidth();
729 if (bitWidth == 1)
730 return std::nullopt;
731 return bitWidth / 8;
732 })
733 .Case([](ArrayType type) -> std::optional<int64_t> {
734 // Since array type may have an explicit stride declaration (in bytes),
735 // we also include it in the calculation.
736 auto elementType = cast<SPIRVType>(type.getElementType());
737 if (std::optional<int64_t> size = elementType.getSizeInBytes())
738 return (*size + type.getArrayStride()) * type.getNumElements();
739 return std::nullopt;
740 })
741 .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
742 if (std::optional<int64_t> elementSize =
743 cast<ScalarType>(type.getElementType()).getSizeInBytes())
744 return *elementSize * type.getNumElements();
745 return std::nullopt;
746 })
747 .Default(std::nullopt);
748}
749
750//===----------------------------------------------------------------------===//
751// SampledImageType
752//===----------------------------------------------------------------------===//
754 using KeyTy = Type;
755
757
758 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
759
761 const KeyTy &key) {
762 return new (allocator.allocate<SampledImageTypeStorage>())
764 }
765
767};
768
770 return Base::get(imageType.getContext(), imageType);
771}
772
778
779Type SampledImageType::getImageType() const { return getImpl()->imageType; }
780
781LogicalResult
783 Type imageType) {
784 auto image = dyn_cast<ImageType>(imageType);
785 if (!image)
786 return emitError() << "expected image type";
787
788 // As per SPIR-V spec: "It [ImageType] must not have a Dim of SubpassData.
789 // Additionally, starting with version 1.6, it must not have a Dim of Buffer.
790 // ("3.3.6. Type-Declaration Instructions")
791 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
792 return emitError() << "Dim must not be SubpassData or Buffer";
793
794 return success();
795}
796
797//===----------------------------------------------------------------------===//
798// StructType
799//===----------------------------------------------------------------------===//
800
801/// Type storage for SPIR-V structure types:
802///
803/// Structures are uniqued using:
804/// - for identified structs:
805/// - a string identifier;
806/// - for literal structs:
807/// - a list of member types;
808/// - a list of member offset info;
809/// - a list of member decoration info;
810/// - a list of struct decoration info.
811///
812/// Identified structures only have a mutable component consisting of:
813/// - a list of member types;
814/// - a list of member offset info;
815/// - a list of member decoration info;
816/// - a list of struct decoration info.
818 /// Construct a storage object for an identified struct type. A struct type
819 /// associated with such storage must call StructType::trySetBody(...) later
820 /// in order to mutate the storage object providing the actual content.
826
827 /// Construct a storage object for a literal struct type. A struct type
828 /// associated with such storage is immutable.
840
841 /// A storage key is divided into 2 parts:
842 /// - for identified structs:
843 /// - a StringRef representing the struct identifier;
844 /// - for literal structs:
845 /// - an ArrayRef<Type> for member types;
846 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
847 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
848 /// info;
849 /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
850 /// info.
851 ///
852 /// An identified struct type is uniqued only by the first part (field 0)
853 /// of the key.
854 ///
855 /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
856 /// and 4) of the key. The identifier field (field 0) must be empty.
857 using KeyTy =
858 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
861
862 /// For identified structs, return true if the given key contains the same
863 /// identifier.
864 ///
865 /// For literal structs, return true if the given key contains a matching list
866 /// of member types + offset info + decoration info.
867 bool operator==(const KeyTy &key) const {
868 if (isIdentified()) {
869 // Identified types are uniqued by their identifier.
870 return getIdentifier() == std::get<0>(key);
871 }
872
873 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
875 }
876
877 /// If the given key contains a non-empty identifier, this method constructs
878 /// an identified struct and leaves the rest of the struct type data to be set
879 /// through a later call to StructType::trySetBody(...).
880 ///
881 /// If, on the other hand, the key contains an empty identifier, a literal
882 /// struct is constructed using the other fields of the key.
884 const KeyTy &key) {
885 StringRef keyIdentifier = std::get<0>(key);
886
887 if (!keyIdentifier.empty()) {
888 StringRef identifier = allocator.copyInto(keyIdentifier);
889
890 // Identified StructType body/members will be set through trySetBody(...)
891 // later.
892 return new (allocator.allocate<StructTypeStorage>())
894 }
895
896 ArrayRef<Type> keyTypes = std::get<1>(key);
897
898 // Copy the member type and layout information into the bump pointer
899 const Type *typesList = nullptr;
900 if (!keyTypes.empty()) {
901 typesList = allocator.copyInto(keyTypes).data();
902 }
903
904 const StructType::OffsetInfo *offsetInfoList = nullptr;
905 if (!std::get<2>(key).empty()) {
906 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
907 assert(keyOffsetInfo.size() == keyTypes.size() &&
908 "size of offset information must be same as the size of number of "
909 "elements");
910 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
911 }
912
913 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
914 unsigned numMemberDecorations = 0;
915 if (!std::get<3>(key).empty()) {
916 auto keyMemberDecorations = std::get<3>(key);
917 numMemberDecorations = keyMemberDecorations.size();
918 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
919 }
920
921 const StructType::StructDecorationInfo *structDecorationList = nullptr;
922 unsigned numStructDecorations = 0;
923 if (!std::get<4>(key).empty()) {
924 auto keyStructDecorations = std::get<4>(key);
925 numStructDecorations = keyStructDecorations.size();
926 structDecorationList = allocator.copyInto(keyStructDecorations).data();
927 }
928
929 return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
930 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
931 memberDecorationList, numStructDecorations, structDecorationList);
932 }
933
937
944
952
959
960 StringRef getIdentifier() const { return identifier; }
961
962 bool isIdentified() const { return !identifier.empty(); }
963
964 /// Sets the struct type content for identified structs. Calling this method
965 /// is only valid for identified structs.
966 ///
967 /// Fails under the following conditions:
968 /// - If called for a literal struct;
969 /// - If called for an identified struct whose body was set before (through a
970 /// call to this method) but with different contents from the passed
971 /// arguments.
972 LogicalResult
973 mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
974 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
975 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
976 ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
977 if (!isIdentified())
978 return failure();
979
980 if (memberTypesAndIsBodySet.getInt() &&
981 (getMemberTypes() != structMemberTypes ||
982 getOffsetInfo() != structOffsetInfo ||
983 getMemberDecorationsInfo() != structMemberDecorationInfo ||
984 getStructDecorationsInfo() != structDecorationInfo))
985 return failure();
986
987 memberTypesAndIsBodySet.setInt(true);
988 numMembers = structMemberTypes.size();
989
990 // Copy the member type and layout information into the bump pointer.
991 if (!structMemberTypes.empty())
992 memberTypesAndIsBodySet.setPointer(
993 allocator.copyInto(structMemberTypes).data());
994
995 if (!structOffsetInfo.empty()) {
996 assert(structOffsetInfo.size() == structMemberTypes.size() &&
997 "size of offset information must be same as the size of number of "
998 "elements");
999 offsetInfo = allocator.copyInto(structOffsetInfo).data();
1000 }
1001
1002 if (!structMemberDecorationInfo.empty()) {
1003 numMemberDecorations = structMemberDecorationInfo.size();
1005 allocator.copyInto(structMemberDecorationInfo).data();
1006 }
1007
1008 if (!structDecorationInfo.empty()) {
1009 numStructDecorations = structDecorationInfo.size();
1010 structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
1011 }
1012
1013 return success();
1014 }
1015
1016 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1018 unsigned numMembers;
1023 StringRef identifier;
1024};
1025
1031 assert(!memberTypes.empty() && "Struct needs at least one member type");
1032 // Sort the decorations.
1034 memberDecorations);
1035 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1036 sortedMemberDecorations.end());
1038 structDecorations);
1039 llvm::array_pod_sort(sortedStructDecorations.begin(),
1040 sortedStructDecorations.end());
1041
1042 return Base::get(memberTypes.vec().front().getContext(),
1043 /*identifier=*/StringRef(), memberTypes, offsetInfo,
1044 sortedMemberDecorations, sortedStructDecorations);
1045}
1046
1048 StringRef identifier) {
1049 assert(!identifier.empty() &&
1050 "StructType identifier must be non-empty string");
1051
1052 return Base::get(context, identifier, ArrayRef<Type>(),
1056}
1057
1058StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1059 StructType newStructType = Base::get(
1060 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1063 // Set an empty body in case this is a identified struct.
1064 if (newStructType.isIdentified() &&
1065 failed(newStructType.trySetBody(
1069 return StructType();
1070
1071 return newStructType;
1072}
1073
1074StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1075
1076bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1077
1078unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1079
1081 assert(getNumElements() > index && "member index out of range");
1082 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1083}
1084
1086 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1087 getNumElements());
1088}
1089
1090bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1091
1092bool StructType::hasDecoration(spirv::Decoration decoration) const {
1094 getImpl()->getStructDecorationsInfo())
1095 if (info.decoration == decoration)
1096 return true;
1097
1098 return false;
1099}
1100
1101uint64_t StructType::getMemberOffset(unsigned index) const {
1102 assert(getNumElements() > index && "member index out of range");
1103 return getImpl()->offsetInfo[index];
1104}
1105
1108 const {
1109 memberDecorations.clear();
1110 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1111 memberDecorations.append(implMemberDecorations.begin(),
1112 implMemberDecorations.end());
1113}
1114
1116 unsigned index,
1118 assert(getNumElements() > index && "member index out of range");
1119 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1120 decorationsInfo.clear();
1121 for (const auto &memberDecoration : memberDecorations) {
1122 if (memberDecoration.memberIndex == index) {
1123 decorationsInfo.push_back(memberDecoration);
1124 }
1125 if (memberDecoration.memberIndex > index) {
1126 // Early exit since the decorations are stored sorted.
1127 return;
1128 }
1129 }
1130}
1131
1134 const {
1135 structDecorations.clear();
1136 auto implDecorations = getImpl()->getStructDecorationsInfo();
1137 structDecorations.append(implDecorations.begin(), implDecorations.end());
1138}
1139
1140LogicalResult
1142 ArrayRef<OffsetInfo> offsetInfo,
1143 ArrayRef<MemberDecorationInfo> memberDecorations,
1144 ArrayRef<StructDecorationInfo> structDecorations) {
1145 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1146 structDecorations);
1147}
1148
1149llvm::hash_code spirv::hash_value(
1150 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1151 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1152 memberDecorationInfo.decoration);
1153}
1154
1155llvm::hash_code spirv::hash_value(
1156 const StructType::StructDecorationInfo &structDecorationInfo) {
1157 return llvm::hash_value(structDecorationInfo.decoration);
1158}
1159
1160//===----------------------------------------------------------------------===//
1161// MatrixType
1162//===----------------------------------------------------------------------===//
1163
1167
1168 using KeyTy = std::tuple<Type, uint32_t>;
1169
1171 const KeyTy &key) {
1172
1173 // Initialize the memory using placement new.
1174 return new (allocator.allocate<MatrixTypeStorage>())
1175 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1176 }
1177
1178 bool operator==(const KeyTy &key) const {
1179 return key == KeyTy(columnType, columnCount);
1180 }
1181
1183 const uint32_t columnCount;
1184};
1185
1186MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1187 return Base::get(columnType.getContext(), columnType, columnCount);
1188}
1189
1191 Type columnType, uint32_t columnCount) {
1192 return Base::getChecked(emitError, columnType.getContext(), columnType,
1193 columnCount);
1194}
1195
1196LogicalResult
1198 Type columnType, uint32_t columnCount) {
1199 if (columnCount < 2 || columnCount > 4)
1200 return emitError() << "matrix can have 2, 3, or 4 columns only";
1201
1202 if (!isValidColumnType(columnType))
1203 return emitError() << "matrix columns must be vectors of floats";
1204
1205 /// The underlying vectors (columns) must be of size 2, 3, or 4
1206 ArrayRef<int64_t> columnShape = cast<VectorType>(columnType).getShape();
1207 if (columnShape.size() != 1)
1208 return emitError() << "matrix columns must be 1D vectors";
1209
1210 if (columnShape[0] < 2 || columnShape[0] > 4)
1211 return emitError() << "matrix columns must be of size 2, 3, or 4";
1212
1213 return success();
1214}
1215
1216/// Returns true if the matrix elements are vectors of float elements
1218 if (auto vectorType = dyn_cast<VectorType>(columnType)) {
1219 if (isa<FloatType>(vectorType.getElementType()))
1220 return true;
1221 }
1222 return false;
1223}
1224
1225Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1226
1228 return cast<VectorType>(getImpl()->columnType).getElementType();
1229}
1230
1231unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1232
1233unsigned MatrixType::getNumRows() const {
1234 return cast<VectorType>(getImpl()->columnType).getShape()[0];
1235}
1236
1238 return (getImpl()->columnCount) * getNumRows();
1239}
1240
1241void TypeCapabilityVisitor::addConcrete(MatrixType type) {
1242 add(type.getColumnType());
1243 static constexpr auto cap = Capability::Matrix;
1244 capabilities.push_back(cap);
1245}
1246
1247//===----------------------------------------------------------------------===//
1248// TensorArmType
1249//===----------------------------------------------------------------------===//
1250
1252 using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1253
1255 const KeyTy &key) {
1256 auto [shape, elementType] = key;
1257 shape = allocator.copyInto(shape);
1258 return new (allocator.allocate<TensorArmTypeStorage>())
1260 }
1261
1262 static llvm::hash_code hashKey(const KeyTy &key) {
1263 auto [shape, elementType] = key;
1264 return llvm::hash_combine(shape, elementType);
1265 }
1266
1267 bool operator==(const KeyTy &key) const {
1268 return key == KeyTy(shape, elementType);
1269 }
1270
1273
1276};
1277
1279 return Base::get(elementType.getContext(), shape, elementType);
1280}
1281
1283 Type elementType) const {
1284 return TensorArmType::get(shape.value_or(getShape()), elementType);
1285}
1286
1287Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1289
1290void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1291 add(type.getElementType());
1292 static constexpr auto ext = Extension::SPV_ARM_tensors;
1293 extensions.push_back(ext);
1294}
1295
1296void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
1297 add(type.getElementType());
1298 static constexpr auto cap = Capability::TensorsARM;
1299 capabilities.push_back(cap);
1300}
1301
1302LogicalResult
1304 ArrayRef<int64_t> shape, Type elementType) {
1305 if (llvm::is_contained(shape, 0))
1306 return emitError() << "arm.tensor do not support dimensions = 0";
1307 if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1308 llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1309 return emitError()
1310 << "arm.tensor shape dimensions must be either fully dynamic or "
1311 "completed shaped";
1312 return success();
1313}
1314
1315//===----------------------------------------------------------------------===//
1316// SPIR-V Dialect
1317//===----------------------------------------------------------------------===//
1318
1319void SPIRVDialect::registerTypes() {
1322}
return success()
b getContext())
if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method
false
Parses a map_entries map type from a string format back into its numeric value.
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
#define add(a, b)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
TypeStorage()
This constructor is used by derived classes as part of the TypeUniquer.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the element type (i.e., the type of a single matrix element).
unsigned getNumRows() const
Returns the number of rows.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
constexpr Type()=default
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type follow the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:65
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type follow the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType get(Type imageType)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, returns an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:465
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
StorageUniquer::StorageAllocator TypeStorageAllocator
This is a utility allocator used to allocate memory for instances of derived Types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
bool operator==(const KeyTy &key) const
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
bool operator==(const KeyTy &key) const
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, uint32_t > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
StructType::OffsetInfo const * offsetInfo
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
ArrayRef< Type > getMemberTypes() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const