MLIR 23.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#include <cstdint>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::spirv;
27
28namespace {
29// Helper class to collect extensions implied by a type by visiting all its
30// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
31//
32// Serves as the source-of-truth for type extension information. All extension
33// logic should be added to this class, while the
34// `SPIRVType::getExtensions` function should not handle extension-related logic
35// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
// Visitor that appends the extensions implied by a SPIR-V type (and by the
// types nested inside it) to a caller-provided vector. Each visited
// (type, storage class) pair is recorded in `seen` so it is processed once.
36class TypeExtensionVisitor {
37public:
// Binds the visitor to the output vector `extensions` and to the optional
// storage class in which the visited types are used.
38 TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
41
42 // Main visitor entry point. Adds all extensions to the vector. Records the
43 // (type, storage) pair as seen and dispatches to the matching handler.
44 void add(SPIRVType type) {
// Bail out when this (type, storage) pair was already visited.
45 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
46 return;
47
// Types with dedicated handling go through `addConcrete`.
49 .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
50 [this](auto concreteType) { addConcrete(concreteType); })
// These types only forward to their element type.
51 .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
52 [this](auto concreteType) { add(concreteType.getElementType()); })
53 .Case([this](SampledImageType concreteType) {
54 add(concreteType.getImageType());
55 })
// Structs visit every member type.
56 .Case([this](StructType concreteType) {
57 for (Type elementType : concreteType.getElementTypes())
58 add(elementType);
59 })
60 .Case<SamplerType>([](auto) { /* no extensions */ })
61 .DefaultUnreachable("Unhandled type");
62 }
63
// Convenience overload: casts `type` to `SPIRVType` and visits it.
64 void add(Type type) { add(cast<SPIRVType>(type)); }
65
66private:
67 // Types that add unique extensions.
68 void addConcrete(CooperativeMatrixType type);
69 void addConcrete(PointerType type);
70 void addConcrete(ScalarType type);
71 void addConcrete(TensorArmType type);
72
// Storage class currently in effect; temporarily replaced while visiting the
// pointee of a pointer type.
74 std::optional<StorageClass> storage;
// (type, storage class) pairs that have already been visited.
75 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
76};
77
78// Helper class to collect capabilities implied by a type by visiting all its
79// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
80//
81// Serves as the source-of-truth for type capability information. All capability
82// logic should be added to this class, while the
83// `SPIRVType::getCapabilities` function should not handle capability-related
84// logic directly and only invoke `TypeCapabilityVisitor::add(Type *)`.
// Visitor that appends the capabilities implied by a SPIR-V type (and by the
// types nested inside it) to a caller-provided vector. Each visited
// (type, storage class) pair is recorded in `seen` so it is processed once.
85class TypeCapabilityVisitor {
86public:
// Binds the visitor to the output vector `capabilities` and to the optional
// storage class in which the visited types are used.
87 TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
88 std::optional<StorageClass> storage)
89 : capabilities(capabilities), storage(storage) {}
90
91 // Main visitor entry point. Adds all capabilities to the vector. Records the
92 // (type, storage) pair as seen and dispatches to the matching handler.
93 void add(SPIRVType type) {
// Bail out when this (type, storage) pair was already visited.
94 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
95 return;
96
// Types with dedicated handling go through `addConcrete`.
98 .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
99 RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
100 [this](auto concreteType) { addConcrete(concreteType); })
// Fixed-size arrays only forward to their element type.
101 .Case([this](ArrayType concreteType) {
102 add(concreteType.getElementType());
103 })
104 .Case([this](SampledImageType concreteType) {
105 add(concreteType.getImageType());
106 })
// Structs visit every member type.
107 .Case([this](StructType concreteType) {
108 for (Type elementType : concreteType.getElementTypes())
109 add(elementType);
110 })
111 .Case<SamplerType>([](auto) { /* no capabilities */ })
112 .DefaultUnreachable("Unhandled type");
113 }
114
// Convenience overload: casts `type` to `SPIRVType` and visits it.
115 void add(Type type) { add(cast<SPIRVType>(type)); }
116
117private:
118 // Types that add unique capabilities.
119 void addConcrete(CooperativeMatrixType type);
120 void addConcrete(ImageType type);
121 void addConcrete(MatrixType type);
122 void addConcrete(PointerType type);
123 void addConcrete(RuntimeArrayType type);
124 void addConcrete(ScalarType type);
125 void addConcrete(TensorArmType type);
126 void addConcrete(VectorType type);
127
// Storage class currently in effect; temporarily replaced while visiting the
// pointee of a pointer type.
129 std::optional<StorageClass> storage;
// (type, storage class) pairs that have already been visited.
130 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
131};
132
133} // namespace
134
135//===----------------------------------------------------------------------===//
136// ArrayType
137//===----------------------------------------------------------------------===//
138
140 using KeyTy = std::tuple<Type, unsigned, unsigned>;
141
143 const KeyTy &key) {
144 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
145 }
146
147 bool operator==(const KeyTy &key) const {
148 return key == KeyTy(elementType, elementCount, stride);
149 }
150
152 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
153 stride(std::get<2>(key)) {}
154
156 unsigned elementCount;
157 unsigned stride;
158};
159
160ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
161 assert(elementCount && "ArrayType needs at least one element");
162 return Base::get(elementType.getContext(), elementType, elementCount,
163 /*stride=*/0);
164}
165
166ArrayType ArrayType::get(Type elementType, unsigned elementCount,
167 unsigned stride) {
168 assert(elementCount && "ArrayType needs at least one element");
169 return Base::get(elementType.getContext(), elementType, elementCount, stride);
170}
171
172unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
173
174Type ArrayType::getElementType() const { return getImpl()->elementType; }
175
176unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
177
178//===----------------------------------------------------------------------===//
179// CompositeType
180//===----------------------------------------------------------------------===//
181
183 if (auto vectorType = dyn_cast<VectorType>(type))
184 return isValid(vectorType);
187 type);
188}
189
190bool CompositeType::isValid(VectorType type) {
191 return type.getRank() == 1 &&
192 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
193 (isa<ScalarType>(type.getElementType()) ||
194 isa<PointerType>(type.getElementType()));
195}
196
198 return TypeSwitch<Type, Type>(*this)
200 TensorArmType>([](auto type) { return type.getElementType(); })
201 .Case([](MatrixType type) { return type.getColumnType(); })
202 .Case([index](StructType type) { return type.getElementType(index); })
203 .DefaultUnreachable("Invalid composite type");
204}
205
208 .Case<ArrayType, StructType, TensorArmType, VectorType>(
209 [](auto type) { return type.getNumElements(); })
210 .Case([](MatrixType type) { return type.getNumColumns(); })
211 .DefaultUnreachable("Invalid type for number of elements query");
212}
213
215 return !isa<CooperativeMatrixType, RuntimeArrayType>(*this);
216}
217
218void TypeCapabilityVisitor::addConcrete(VectorType type) {
219 add(type.getElementType());
220
221 int64_t vecSize = type.getNumElements();
222 if (vecSize == 8 || vecSize == 16) {
223 static constexpr auto cap = Capability::Vector16;
224 capabilities.push_back(cap);
225 }
226}
227
228//===----------------------------------------------------------------------===//
229// CooperativeMatrixType
230//===----------------------------------------------------------------------===//
231
233 // In the specification dimensions of the Cooperative Matrix are 32-bit
234 // integers --- the initial implementation kept those values as such. However,
235 // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
236 // as 32-bits and expose it as int64_t through `getShape`, however, this
237 // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
238 // 32-bits integers would require an extra logic and storage. So, we diverge
239 // from the spec and internally represent the dimensions as 64-bit integers,
240 // so we can easily return an `ArrayRef` from `getShape` without any extra
241 // logic. Alternatively, we could store both rows and columns (both 32-bits)
242 // and shape (64-bits), assigning rows and columns to shape whenever
243 // `getShape` is called. This would be at the cost of extra logic and storage.
244 // Note: Because `ArrayRef` is returned we cannot construct an object in
245 // `getShape` on the fly.
246 using KeyTy =
247 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
248
250 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
251 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
253 }
254
255 bool operator==(const KeyTy &key) const {
256 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
257 }
258
260 : elementType(std::get<0>(key)),
261 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
262 use(std::get<4>(key)) {}
263
265 // [#rows, #columns]
266 std::array<int64_t, 2> shape;
267 Scope scope;
268 CooperativeMatrixUseKHR use;
269};
270
272 uint32_t rows,
273 uint32_t columns, Scope scope,
274 CooperativeMatrixUseKHR use) {
275 return Base::get(elementType.getContext(), elementType, rows, columns, scope,
276 use);
277}
278
280 return getImpl()->elementType;
281}
282
284 assert(getImpl()->shape[0] != ShapedType::kDynamic);
285 return static_cast<uint32_t>(getImpl()->shape[0]);
286}
287
289 assert(getImpl()->shape[1] != ShapedType::kDynamic);
290 return static_cast<uint32_t>(getImpl()->shape[1]);
291}
292
296
297Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
298
299CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
300 return getImpl()->use;
301}
302
303void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
304 add(type.getElementType());
305 static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
306 extensions.push_back(ext);
307}
308
309void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
310 add(type.getElementType());
311 static constexpr auto caps = Capability::CooperativeMatrixKHR;
312 capabilities.push_back(caps);
313}
314
315//===----------------------------------------------------------------------===//
316// ImageType
317//===----------------------------------------------------------------------===//
318
/// Primary template: reports 0 bits for any type that has no dedicated
/// specialization.
template <typename T>
static constexpr unsigned getNumBits() {
  constexpr unsigned noBits = 0;
  return noBits;
}
323template <>
324constexpr unsigned getNumBits<Dim>() {
325 static_assert((1 << 3) > getMaxEnumValForDim(),
326 "Not enough bits to encode Dim value");
327 return 3;
328}
329template <>
330constexpr unsigned getNumBits<ImageDepthInfo>() {
331 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
332 "Not enough bits to encode ImageDepthInfo value");
333 return 2;
334}
335template <>
336constexpr unsigned getNumBits<ImageArrayedInfo>() {
337 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
338 "Not enough bits to encode ImageArrayedInfo value");
339 return 1;
340}
341template <>
342constexpr unsigned getNumBits<ImageSamplingInfo>() {
343 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
344 "Not enough bits to encode ImageSamplingInfo value");
345 return 1;
346}
347template <>
349 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
350 "Not enough bits to encode ImageSamplerUseInfo value");
351 return 2;
352}
353template <>
354constexpr unsigned getNumBits<ImageFormat>() {
355 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
356 "Not enough bits to encode ImageFormat value");
357 return 6;
358}
359
361public:
362 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
363 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
364
366 const KeyTy &key) {
367 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
368 }
369
370 bool operator==(const KeyTy &key) const {
373 }
374
376 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
377 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
378 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
379 format(std::get<6>(key)) {}
380
388};
389
391ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
392 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
393 value) {
394 return Base::get(std::get<0>(value).getContext(), value);
395}
396
397Type ImageType::getElementType() const { return getImpl()->elementType; }
398
399Dim ImageType::getDim() const { return getImpl()->dim; }
400
401ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
402
403ImageArrayedInfo ImageType::getArrayedInfo() const {
404 return getImpl()->arrayedInfo;
405}
406
407ImageSamplingInfo ImageType::getSamplingInfo() const {
408 return getImpl()->samplingInfo;
409}
410
411ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
412 return getImpl()->samplerUseInfo;
413}
414
415ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
416
417void TypeCapabilityVisitor::addConcrete(ImageType type) {
418 if (auto dimCaps = spirv::getCapabilities(type.getDim()))
419 capabilities.push_back(*dimCaps);
420
421 if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
422 capabilities.push_back(*fmtCaps);
423
424 add(type.getElementType());
425}
426
427//===----------------------------------------------------------------------===//
428// PointerType
429//===----------------------------------------------------------------------===//
430
432 // (Type, StorageClass) as the key: Type stored in this struct, and
433 // StorageClass stored as TypeStorage's subclass data.
434 using KeyTy = std::pair<Type, StorageClass>;
435
437 const KeyTy &key) {
438 return new (allocator.allocate<PointerTypeStorage>())
440 }
441
442 bool operator==(const KeyTy &key) const {
443 return key == KeyTy(pointeeType, storageClass);
444 }
445
447 : pointeeType(key.first), storageClass(key.second) {}
448
450 StorageClass storageClass;
451};
452
453PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
454 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
455}
456
457Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
458
459StorageClass PointerType::getStorageClass() const {
460 return getImpl()->storageClass;
461}
462
463void TypeExtensionVisitor::addConcrete(PointerType type) {
464 // Use this pointer type's storage class because this pointer indicates we are
465 // using the pointee type in that specific storage class.
466 std::optional<StorageClass> oldStorageClass = storage;
467 storage = type.getStorageClass();
468 add(type.getPointeeType());
469 storage = oldStorageClass;
470
471 if (auto scExts = spirv::getExtensions(type.getStorageClass()))
472 extensions.push_back(*scExts);
473}
474
475void TypeCapabilityVisitor::addConcrete(PointerType type) {
476 // Use this pointer type's storage class because this pointer indicates we are
477 // using the pointee type in that specific storage class.
478 std::optional<StorageClass> oldStorageClass = storage;
479 storage = type.getStorageClass();
480 add(type.getPointeeType());
481 storage = oldStorageClass;
482
483 if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
484 capabilities.push_back(*scCaps);
485}
486
487//===----------------------------------------------------------------------===//
488// RuntimeArrayType
489//===----------------------------------------------------------------------===//
490
492 using KeyTy = std::pair<Type, unsigned>;
493
495 const KeyTy &key) {
496 return new (allocator.allocate<RuntimeArrayTypeStorage>())
498 }
499
500 bool operator==(const KeyTy &key) const {
501 return key == KeyTy(elementType, stride);
502 }
503
505 : elementType(key.first), stride(key.second) {}
506
508 unsigned stride;
509};
510
512 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
513}
514
515RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
516 return Base::get(elementType.getContext(), elementType, stride);
517}
518
519Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
520
521unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
522
523void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
524 add(type.getElementType());
525 static constexpr auto cap = Capability::Shader;
526 capabilities.push_back(cap);
527}
528
529//===----------------------------------------------------------------------===//
530// ScalarType
531//===----------------------------------------------------------------------===//
532
534 if (auto floatType = dyn_cast<FloatType>(type)) {
535 return isValid(floatType);
536 }
537 if (auto intType = dyn_cast<IntegerType>(type)) {
538 return isValid(intType);
539 }
540 return false;
541}
542
543bool ScalarType::isValid(FloatType type) {
544 if (type.isF8E4M3FN() || type.isF8E5M2())
545 return true;
546 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
547}
548
549bool ScalarType::isValid(IntegerType type) {
550 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
551}
552
553void TypeExtensionVisitor::addConcrete(ScalarType type) {
554 if (isa<BFloat16Type>(type)) {
555 static constexpr auto ext = Extension::SPV_KHR_bfloat16;
556 extensions.push_back(ext);
557 }
558
559 if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
560 static constexpr auto ext = Extension::SPV_EXT_float8;
561 extensions.push_back(ext);
562 }
563
564 // 8- or 16-bit integer/floating-point numbers will require extra extensions
565 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
566 // SPV_KHR_8bit_storage for more details.
567 if (!storage)
568 return;
569
570 switch (*storage) {
571 case StorageClass::PushConstant:
572 case StorageClass::StorageBuffer:
573 case StorageClass::Uniform:
574 if (type.getIntOrFloatBitWidth() == 8) {
575 static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
576 extensions.push_back(ext);
577 }
578 [[fallthrough]];
579 case StorageClass::Input:
580 case StorageClass::Output:
581 if (type.getIntOrFloatBitWidth() == 16) {
582 static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
583 extensions.push_back(ext);
584 }
585 break;
586 default:
587 break;
588 }
589}
590
// Appends the capability implied by scalar `type`. When a storage class is set
// and the type is 8- or 16-bit in one of the interface storage classes, only
// the matching storage capability is appended and the function returns.
// Otherwise the capability is chosen from the bit width and from whether the
// type is an integer or a float; 1- and 32-bit integers and 32-bit floats add
// nothing.
591void TypeCapabilityVisitor::addConcrete(ScalarType type) {
592 unsigned bitwidth = type.getIntOrFloatBitWidth();
593
594 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
595 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
596 // SPV_KHR_8bit_storage for more details.
597
// Expands to one `case` that appends `cap8` or `cap16` and returns from the
// enclosing function for 8-/16-bit types; other widths leave the switch.
598#define STORAGE_CASE(storage, cap8, cap16) \
599 case StorageClass::storage: { \
600 if (bitwidth == 8) { \
601 static constexpr auto cap = Capability::cap8; \
602 capabilities.push_back(cap); \
603 return; \
604 } \
605 if (bitwidth == 16) { \
606 static constexpr auto cap = Capability::cap16; \
607 capabilities.push_back(cap); \
608 return; \
609 } \
610 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
611 /* storage classes. Fall through to the next section. */ \
612 } break
613
614 // This part only handles the cases where special bitwidths appearing in
615 // interface storage classes.
616 if (storage) {
617 switch (*storage) {
618 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
619 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
620 StorageBuffer16BitAccess);
621 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
622 StorageUniform16);
// Input/Output only have a 16-bit storage capability; 8-bit types in these
// classes continue to the width-based section below.
623 case StorageClass::Input:
624 case StorageClass::Output: {
625 if (bitwidth == 16) {
626 static constexpr auto cap = Capability::StorageInputOutput16;
627 capabilities.push_back(cap);
628 return;
629 }
630 break;
631 }
632 default:
633 break;
634 }
635 }
636#undef STORAGE_CASE
637
638 // For other non-interface storage classes, require a different set of
639 // capabilities for special bitwidths.
640
// Expands to one `case` that appends `Capability::<type><width>`, e.g.
// WIDTH_CASE(Int, 8) appends Capability::Int8.
641#define WIDTH_CASE(type, width) \
642 case width: { \
643 static constexpr auto cap = Capability::type##width; \
644 capabilities.push_back(cap); \
645 } break
646
647 if (auto intType = dyn_cast<IntegerType>(type)) {
648 switch (bitwidth) {
649 WIDTH_CASE(Int, 8);
650 WIDTH_CASE(Int, 16);
651 WIDTH_CASE(Int, 64);
// 1-bit and 32-bit integers add no capability.
652 case 1:
653 case 32:
654 break;
655 default:
656 llvm_unreachable("invalid bitwidth to getCapabilities");
657 }
658 } else {
659 assert(isa<FloatType>(type));
660 switch (bitwidth) {
661 case 8: {
// Only the two float8 formats below are handled; any other 8-bit float is
// treated as unreachable.
662 if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
663 static constexpr auto cap = Capability::Float8EXT;
664 capabilities.push_back(cap);
665 } else {
666 llvm_unreachable("invalid 8-bit float type to getCapabilities");
667 }
668 break;
669 }
670 case 16: {
// bfloat16 has its own capability; every other 16-bit float uses Float16.
671 if (isa<BFloat16Type>(type)) {
672 static constexpr auto cap = Capability::BFloat16TypeKHR;
673 capabilities.push_back(cap);
674 } else {
675 static constexpr auto cap = Capability::Float16;
676 capabilities.push_back(cap);
677 }
678 break;
679 }
680 WIDTH_CASE(Float, 64);
// 32-bit floats add no capability.
681 case 32:
682 break;
683 default:
684 llvm_unreachable("invalid bitwidth to getCapabilities");
685 }
686 }
687
688#undef WIDTH_CASE
689}
690
691//===----------------------------------------------------------------------===//
692// SPIRVType
693//===----------------------------------------------------------------------===//
694
696 // Allow SPIR-V dialect types
697 if (isa<SPIRVDialect>(type.getDialect()))
698 return true;
699 if (isa<ScalarType>(type))
700 return true;
701 if (auto vectorType = dyn_cast<VectorType>(type))
702 return CompositeType::isValid(vectorType);
703 if (auto tensorArmType = dyn_cast<TensorArmType>(type))
704 return isa<ScalarType>(tensorArmType.getElementType());
705 return false;
706}
707
709 return isIntOrFloat() || isa<VectorType>(*this);
710}
711
713 std::optional<StorageClass> storage) {
714 TypeExtensionVisitor{extensions, storage}.add(*this);
715}
716
719 std::optional<StorageClass> storage) {
720 TypeCapabilityVisitor{capabilities, storage}.add(*this);
721}
722
// Returns the size of this type in bytes for scalar, array, vector and
// TensorArm types; every other type, and any type whose element size is
// unknown, yields std::nullopt.
723std::optional<int64_t> SPIRVType::getSizeInBytes() {
725 .Case([](ScalarType type) -> std::optional<int64_t> {
726 // According to the SPIR-V spec:
727 // "There is no physical size or bit pattern defined for values with
728 // boolean type. If they are stored (in conjunction with OpVariable),
729 // they can only be used with logical addressing operations, not
730 // physical, and only with non-externally visible shader Storage
731 // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
732 // Output."
733 int64_t bitWidth = type.getIntOrFloatBitWidth();
734 if (bitWidth == 1)
735 return std::nullopt;
736 return bitWidth / 8;
737 })
738 .Case([](ArrayType type) -> std::optional<int64_t> {
739 // Since array type may have an explicit stride declaration (in bytes),
740 // we also include it in the calculation.
// NOTE(review): the stride is added to the element size for every element.
// If the stride is meant to be the byte distance between consecutive
// elements, this over-counts -- confirm the intended semantics.
741 auto elementType = cast<SPIRVType>(type.getElementType());
742 if (std::optional<int64_t> size = elementType.getSizeInBytes())
743 return (*size + type.getArrayStride()) * type.getNumElements();
744 return std::nullopt;
745 })
746 .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
// The element type is cast to ScalarType unconditionally here.
747 if (std::optional<int64_t> elementSize =
748 cast<ScalarType>(type.getElementType()).getSizeInBytes())
749 return *elementSize * type.getNumElements();
750 return std::nullopt;
751 })
752 .Default(std::nullopt);
753}
754
755//===----------------------------------------------------------------------===//
756// SampledImageType
757//===----------------------------------------------------------------------===//
759 using KeyTy = Type;
760
762
763 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
764
766 const KeyTy &key) {
767 return new (allocator.allocate<SampledImageTypeStorage>())
769 }
770
772};
773
775 return Base::get(imageType.getContext(), imageType);
776}
777
783
784Type SampledImageType::getImageType() const { return getImpl()->imageType; }
785
// Checks that `imageType` is an `ImageType` whose Dim is neither SubpassData
// nor Buffer. Emits an error through `emitError` and returns its result on
// violation; returns success otherwise.
786LogicalResult
788 Type imageType) {
789 auto image = dyn_cast<ImageType>(imageType);
790 if (!image)
791 return emitError() << "expected image type";
792
793 // As per SPIR-V spec: "It [ImageType] must not have a Dim of SubpassData.
794 // Additionally, starting with version 1.6, it must not have a Dim of Buffer."
795 // ("3.3.6. Type-Declaration Instructions")
// Buffer is rejected unconditionally here; no version check is made.
796 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
797 return emitError() << "Dim must not be SubpassData or Buffer";
798
799 return success();
800}
801
802//===----------------------------------------------------------------------===//
803// SamplerType
804//===----------------------------------------------------------------------===//
805
807 return Base::get(context);
808}
809
810//===----------------------------------------------------------------------===//
811// StructType
812//===----------------------------------------------------------------------===//
813
814/// Type storage for SPIR-V structure types:
815///
816/// Structures are uniqued using:
817/// - for identified structs:
818/// - a string identifier;
819/// - for literal structs:
820/// - a list of member types;
821/// - a list of member offset info;
822/// - a list of member decoration info;
823/// - a list of struct decoration info.
824///
825/// Identified structures only have a mutable component consisting of:
826/// - a list of member types;
827/// - a list of member offset info;
828/// - a list of member decoration info;
829/// - a list of struct decoration info.
831 /// Construct a storage object for an identified struct type. A struct type
832 /// associated with such storage must call StructType::trySetBody(...) later
833 /// in order to mutate the storage object providing the actual content.
839
840 /// Construct a storage object for a literal struct type. A struct type
841 /// associated with such storage is immutable.
853
854 /// A storage key is divided into 2 parts:
855 /// - for identified structs:
856 /// - a StringRef representing the struct identifier;
857 /// - for literal structs:
858 /// - an ArrayRef<Type> for member types;
859 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
860 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
861 /// info;
862 /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
863 /// info.
864 ///
865 /// An identified struct type is uniqued only by the first part (field 0)
866 /// of the key.
867 ///
868 /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
869 /// and 4) of the key. The identifier field (field 0) must be empty.
870 using KeyTy =
871 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
874
875 /// For identified structs, return true if the given key contains the same
876 /// identifier.
877 ///
878 /// For literal structs, return true if the given key contains a matching list
879 /// of member types + offset info + decoration info.
// Compares this storage against `key`: by identifier only for identified
// structs, by the non-identifier key fields for literal structs.
880 bool operator==(const KeyTy &key) const {
881 if (isIdentified()) {
882 // Identified types are uniqued by their identifier.
883 return getIdentifier() == std::get<0>(key);
884 }
885
// Literal structs: compare against a key built with an empty identifier.
886 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
888 }
889
890 /// If the given key contains a non-empty identifier, this method constructs
891 /// an identified struct and leaves the rest of the struct type data to be set
892 /// through a later call to StructType::trySetBody(...).
893 ///
894 /// If, on the other hand, the key contains an empty identifier, a literal
895 /// struct is constructed using the other fields of the key.
897 const KeyTy &key) {
898 StringRef keyIdentifier = std::get<0>(key);
899
900 if (!keyIdentifier.empty()) {
901 StringRef identifier = allocator.copyInto(keyIdentifier);
902
903 // Identified StructType body/members will be set through trySetBody(...)
904 // later.
905 return new (allocator.allocate<StructTypeStorage>())
907 }
908
909 ArrayRef<Type> keyTypes = std::get<1>(key);
910
911 // Copy the member type and layout information into the bump pointer
912 const Type *typesList = nullptr;
913 if (!keyTypes.empty()) {
914 typesList = allocator.copyInto(keyTypes).data();
915 }
916
917 const StructType::OffsetInfo *offsetInfoList = nullptr;
918 if (!std::get<2>(key).empty()) {
919 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
920 assert(keyOffsetInfo.size() == keyTypes.size() &&
921 "size of offset information must be same as the size of number of "
922 "elements");
923 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
924 }
925
926 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
927 unsigned numMemberDecorations = 0;
928 if (!std::get<3>(key).empty()) {
929 auto keyMemberDecorations = std::get<3>(key);
930 numMemberDecorations = keyMemberDecorations.size();
931 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
932 }
933
934 const StructType::StructDecorationInfo *structDecorationList = nullptr;
935 unsigned numStructDecorations = 0;
936 if (!std::get<4>(key).empty()) {
937 auto keyStructDecorations = std::get<4>(key);
938 numStructDecorations = keyStructDecorations.size();
939 structDecorationList = allocator.copyInto(keyStructDecorations).data();
940 }
941
942 return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
943 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
944 memberDecorationList, numStructDecorations, structDecorationList);
945 }
946
950
957
965
972
973 StringRef getIdentifier() const { return identifier; }
974
975 bool isIdentified() const { return !identifier.empty(); }
976
977 /// Sets the struct type content for identified structs. Calling this method
978 /// is only valid for identified structs.
979 ///
980 /// Fails under the following conditions:
981 /// - If called for a literal struct;
982 /// - If called for an identified struct whose body was set before (through a
983 /// call to this method) but with different contents from the passed
984 /// arguments.
// Sets the body of an identified struct, copying the non-empty lists into
// `allocator`. Returns failure for literal structs and when a body was already
// set with different contents; returns success otherwise.
985 LogicalResult
986 mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
987 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
988 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
989 ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
// Literal structs cannot be mutated.
990 if (!isIdentified())
991 return failure();
992
// A body that is already set may only be "re-set" to identical contents.
993 if (memberTypesAndIsBodySet.getInt() &&
994 (getMemberTypes() != structMemberTypes ||
995 getOffsetInfo() != structOffsetInfo ||
996 getMemberDecorationsInfo() != structMemberDecorationInfo ||
997 getStructDecorationsInfo() != structDecorationInfo))
998 return failure();
999
// Mark the body as set, even when the member list is empty.
1000 memberTypesAndIsBodySet.setInt(true);
1001 numMembers = structMemberTypes.size();
1002
1003 // Copy the member type and layout information into the bump pointer.
1004 if (!structMemberTypes.empty())
1005 memberTypesAndIsBodySet.setPointer(
1006 allocator.copyInto(structMemberTypes).data());
1007
1008 if (!structOffsetInfo.empty()) {
1009 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1010 "size of offset information must be same as the size of number of "
1011 "elements");
1012 offsetInfo = allocator.copyInto(structOffsetInfo).data();
1013 }
1014
1015 if (!structMemberDecorationInfo.empty()) {
1016 numMemberDecorations = structMemberDecorationInfo.size();
1018 allocator.copyInto(structMemberDecorationInfo).data();
1019 }
1020
1021 if (!structDecorationInfo.empty()) {
1022 numStructDecorations = structDecorationInfo.size();
1023 structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
1024 }
1025
1026 return success();
1027 }
1028
1029 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1031 unsigned numMembers;
1036 StringRef identifier;
1037};
1038
1044 assert(!memberTypes.empty() && "Struct needs at least one member type");
1045 // Sort the decorations.
1047 memberDecorations);
1048 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1049 sortedMemberDecorations.end());
1051 structDecorations);
1052 llvm::array_pod_sort(sortedStructDecorations.begin(),
1053 sortedStructDecorations.end());
1054
1055 return Base::get(memberTypes.vec().front().getContext(),
1056 /*identifier=*/StringRef(), memberTypes, offsetInfo,
1057 sortedMemberDecorations, sortedStructDecorations);
1058}
1059
1061 StringRef identifier) {
1062 assert(!identifier.empty() &&
1063 "StructType identifier must be non-empty string");
1064
1065 return Base::get(context, identifier, ArrayRef<Type>(),
1069}
1070
1071StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1072 StructType newStructType = Base::get(
1073 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1076 // Set an empty body in case this is an identified struct.
1077 if (newStructType.isIdentified() &&
1078 failed(newStructType.trySetBody(
1082 return StructType();
1083
1084 return newStructType;
1085}
1086
1087StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1088
1089bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1090
1091unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1092
1094 assert(getNumElements() > index && "member index out of range");
1095 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1096}
1097
1099 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1100 getNumElements());
1101}
1102
1103bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1104
1105bool StructType::hasDecoration(spirv::Decoration decoration) const {
1107 getImpl()->getStructDecorationsInfo())
1108 if (info.decoration == decoration)
1109 return true;
1110
1111 return false;
1112}
1113
1114uint64_t StructType::getMemberOffset(unsigned index) const {
1115 assert(getNumElements() > index && "member index out of range");
1116 return getImpl()->offsetInfo[index];
1117}
1118
1121 const {
1122 memberDecorations.clear();
1123 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1124 memberDecorations.append(implMemberDecorations.begin(),
1125 implMemberDecorations.end());
1126}
1127
1129 unsigned index,
1131 assert(getNumElements() > index && "member index out of range");
1132 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1133 decorationsInfo.clear();
1134 for (const auto &memberDecoration : memberDecorations) {
1135 if (memberDecoration.memberIndex == index) {
1136 decorationsInfo.push_back(memberDecoration);
1137 }
1138 if (memberDecoration.memberIndex > index) {
1139 // Early exit since the decorations are stored sorted.
1140 return;
1141 }
1142 }
1143}
1144
1147 const {
1148 structDecorations.clear();
1149 auto implDecorations = getImpl()->getStructDecorationsInfo();
1150 structDecorations.append(implDecorations.begin(), implDecorations.end());
1151}
1152
1153LogicalResult
1155 ArrayRef<OffsetInfo> offsetInfo,
1156 ArrayRef<MemberDecorationInfo> memberDecorations,
1157 ArrayRef<StructDecorationInfo> structDecorations) {
1158 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1159 structDecorations);
1160}
1161
1162llvm::hash_code spirv::hash_value(
1163 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1164 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1165 memberDecorationInfo.decoration);
1166}
1167
1168llvm::hash_code spirv::hash_value(
1169 const StructType::StructDecorationInfo &structDecorationInfo) {
1170 return llvm::hash_value(structDecorationInfo.decoration);
1171}
1172
1173//===----------------------------------------------------------------------===//
1174// MatrixType
1175//===----------------------------------------------------------------------===//
1176
1178 // Use a 64-bit integer as a column count internally to better support a
1179 // `ShapedType` interface. See comment in `CooperativeMatrixType` for more
1180 // context.
1181 using KeyTy = std::tuple<Type, int64_t>;
1182
1184 : columnType(std::get<0>(key)),
1185 shape({cast<VectorType>(std::get<0>(key)).getShape()[0],
1186 std::get<1>(key)}) {}
1187
1189 const KeyTy &key) {
1190
1191 // Initialize the memory using placement new.
1192 return new (allocator.allocate<MatrixTypeStorage>()) MatrixTypeStorage(key);
1193 }
1194
1195 bool operator==(const KeyTy &key) const {
1196 return key == KeyTy(columnType, shape[1]);
1197 }
1198
1200 // [#rows, #columns]
1201 std::array<int64_t, 2> shape;
1202};
1203
1204MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1205 return Base::get(columnType.getContext(), columnType, columnCount);
1206}
1207
1209 Type columnType, uint32_t columnCount) {
1210 return Base::getChecked(emitError, columnType.getContext(), columnType,
1211 columnCount);
1212}
1213
1214LogicalResult
1216 Type columnType, uint32_t columnCount) {
1217 if (columnCount < 2 || columnCount > 4)
1218 return emitError() << "matrix can have 2, 3, or 4 columns only";
1219
1220 if (!isValidColumnType(columnType))
1221 return emitError() << "matrix columns must be vectors of floats";
1222
1223 /// The underlying vectors (columns) must be of size 2, 3, or 4
1224 ArrayRef<int64_t> columnShape = cast<VectorType>(columnType).getShape();
1225 if (columnShape.size() != 1)
1226 return emitError() << "matrix columns must be 1D vectors";
1227
1228 if (columnShape[0] < 2 || columnShape[0] > 4)
1229 return emitError() << "matrix columns must be of size 2, 3, or 4";
1230
1231 return success();
1232}
1233
1234/// Returns true if the matrix elements are vectors of float elements
1236 if (auto vectorType = dyn_cast<VectorType>(columnType)) {
1237 if (isa<FloatType>(vectorType.getElementType()))
1238 return true;
1239 }
1240 return false;
1241}
1242
1243Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1244
1246 return cast<VectorType>(getImpl()->columnType).getElementType();
1247}
1248
1250 assert(getImpl()->shape[1] >= 0); // Also includes ShapedType::kDynamic.
1251 assert(getImpl()->shape[1] <= std::numeric_limits<unsigned>::max());
1252 return static_cast<uint32_t>(getImpl()->shape[1]);
1253}
1254
1255unsigned MatrixType::getNumRows() const {
1256 assert(getImpl()->shape[0] >= 0); // Also includes ShapedType::kDynamic.
1257 assert(getImpl()->shape[0] <= std::numeric_limits<unsigned>::max());
1258 return static_cast<uint32_t>(getImpl()->shape[0]);
1259}
1260
1262 return getNumColumns() * getNumRows();
1263}
1264
1266
1267void TypeCapabilityVisitor::addConcrete(MatrixType type) {
1268 add(type.getColumnType());
1269 static constexpr auto cap = Capability::Matrix;
1270 capabilities.push_back(cap);
1271}
1272
1273//===----------------------------------------------------------------------===//
1274// TensorArmType
1275//===----------------------------------------------------------------------===//
1276
1278 using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1279
1281 const KeyTy &key) {
1282 auto [shape, elementType] = key;
1283 shape = allocator.copyInto(shape);
1284 return new (allocator.allocate<TensorArmTypeStorage>())
1286 }
1287
1288 static llvm::hash_code hashKey(const KeyTy &key) {
1289 auto [shape, elementType] = key;
1290 return llvm::hash_combine(shape, elementType);
1291 }
1292
1293 bool operator==(const KeyTy &key) const {
1294 return key == KeyTy(shape, elementType);
1295 }
1296
1299
1302};
1303
1305 return Base::get(elementType.getContext(), shape, elementType);
1306}
1307
1309 Type elementType) const {
1310 return TensorArmType::get(shape.value_or(getShape()), elementType);
1311}
1312
1313Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1315
1316void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1317 add(type.getElementType());
1318 static constexpr auto ext = Extension::SPV_ARM_tensors;
1319 extensions.push_back(ext);
1320}
1321
1322void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
1323 add(type.getElementType());
1324 static constexpr auto cap = Capability::TensorsARM;
1325 capabilities.push_back(cap);
1326}
1327
1328LogicalResult
1330 ArrayRef<int64_t> shape, Type elementType) {
1331 if (llvm::is_contained(shape, 0))
1332 return emitError() << "arm.tensor do not support dimensions = 0";
1333 if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1334 llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1335 return emitError()
1336 << "arm.tensor shape dimensions must be either fully dynamic or "
1337 "completed shaped";
1338 return success();
1339}
1340
1341//===----------------------------------------------------------------------===//
1342// SPIR-V Dialect
1343//===----------------------------------------------------------------------===//
1344
1345void SPIRVDialect::registerTypes() {
1348 TensorArmType>();
1349}
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
false
Parses a map_entries map type from a string format back into its numeric value.
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
#define add(a, b)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
TypeStorage()
This constructor is used by derived classes as part of the TypeUniquer.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:148
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
ArrayRef< int64_t > getShape() const
unsigned getNumRows() const
Returns the number of rows.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
constexpr Type()=default
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:66
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:55
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:263
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:498
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
StorageUniquer::StorageAllocator TypeStorageAllocator
This is a utility allocator used to allocate memory for instances of derived Types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
bool operator==(const KeyTy &key) const
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
bool operator==(const KeyTy &key) const
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t > KeyTy
bool operator==(const KeyTy &key) const
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
StructType::OffsetInfo const * offsetInfo
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
ArrayRef< Type > getMemberTypes() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const