MLIR 23.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#include <cstdint>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::spirv;
27
28namespace {
29// Helper function to collect extensions implied by a type by visiting all its
30// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
31//
32// Serves as the source-of-truth for type extension information. All extension
33// logic should be added to this class, while the
34// `SPIRVType::getExtensions` function should not handle extension-related logic
35// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
36class TypeExtensionVisitor {
37public:
38 TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
41
42 // Main visitor entry point. Adds all extensions to the vector. Saves `type`
43 // as seen and dispatches to the right concrete `.add` function.
44 void add(SPIRVType type) {
45 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
46 return;
47
49 .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
50 [this](auto concreteType) { addConcrete(concreteType); })
51 .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
52 [this](auto concreteType) { add(concreteType.getElementType()); })
53 .Case([this](SampledImageType concreteType) {
54 add(concreteType.getImageType());
55 })
56 .Case([this](StructType concreteType) {
57 for (Type elementType : concreteType.getElementTypes())
58 add(elementType);
59 })
60 .DefaultUnreachable("Unhandled type");
61 }
62
63 void add(Type type) { add(cast<SPIRVType>(type)); }
64
65private:
66 // Types that add unique extensions.
67 void addConcrete(CooperativeMatrixType type);
68 void addConcrete(PointerType type);
69 void addConcrete(ScalarType type);
70 void addConcrete(TensorArmType type);
71
73 std::optional<StorageClass> storage;
74 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
75};
76
77// Helper function to collect capabilities implied by a type by visiting all its
78// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
79//
80// Serves as the source-of-truth for type capability information. All capability
81// logic should be added to this class, while the
82// `SPIRVType::getCapabilities` function should not handle capability-related
83// logic directly and only invoke `TypeCapabilityVisitor::add(Type *)`.
84class TypeCapabilityVisitor {
85public:
86 TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
87 std::optional<StorageClass> storage)
88 : capabilities(capabilities), storage(storage) {}
89
90 // Main visitor entry point. Adds all extensions to the vector. Saves `type`
91 // as seen and dispatches to the right concrete `.add` function.
92 void add(SPIRVType type) {
93 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
94 return;
95
97 .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
98 RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
99 [this](auto concreteType) { addConcrete(concreteType); })
100 .Case([this](ArrayType concreteType) {
101 add(concreteType.getElementType());
102 })
103 .Case([this](SampledImageType concreteType) {
104 add(concreteType.getImageType());
105 })
106 .Case([this](StructType concreteType) {
107 for (Type elementType : concreteType.getElementTypes())
108 add(elementType);
109 })
110 .DefaultUnreachable("Unhandled type");
111 }
112
113 void add(Type type) { add(cast<SPIRVType>(type)); }
114
115private:
116 // Types that add unique extensions.
117 void addConcrete(CooperativeMatrixType type);
118 void addConcrete(ImageType type);
119 void addConcrete(MatrixType type);
120 void addConcrete(PointerType type);
121 void addConcrete(RuntimeArrayType type);
122 void addConcrete(ScalarType type);
123 void addConcrete(TensorArmType type);
124 void addConcrete(VectorType type);
125
127 std::optional<StorageClass> storage;
128 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
129};
130
131} // namespace
132
133//===----------------------------------------------------------------------===//
134// ArrayType
135//===----------------------------------------------------------------------===//
136
138 using KeyTy = std::tuple<Type, unsigned, unsigned>;
139
141 const KeyTy &key) {
142 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
143 }
144
145 bool operator==(const KeyTy &key) const {
146 return key == KeyTy(elementType, elementCount, stride);
147 }
148
150 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
151 stride(std::get<2>(key)) {}
152
154 unsigned elementCount;
155 unsigned stride;
156};
157
158ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
159 assert(elementCount && "ArrayType needs at least one element");
160 return Base::get(elementType.getContext(), elementType, elementCount,
161 /*stride=*/0);
162}
163
164ArrayType ArrayType::get(Type elementType, unsigned elementCount,
165 unsigned stride) {
166 assert(elementCount && "ArrayType needs at least one element");
167 return Base::get(elementType.getContext(), elementType, elementCount, stride);
168}
169
170unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
171
172Type ArrayType::getElementType() const { return getImpl()->elementType; }
173
174unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
175
176//===----------------------------------------------------------------------===//
177// CompositeType
178//===----------------------------------------------------------------------===//
179
181 if (auto vectorType = dyn_cast<VectorType>(type))
182 return isValid(vectorType);
185 type);
186}
187
188bool CompositeType::isValid(VectorType type) {
189 return type.getRank() == 1 &&
190 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
191 isa<ScalarType>(type.getElementType());
192}
193
195 return TypeSwitch<Type, Type>(*this)
197 TensorArmType>([](auto type) { return type.getElementType(); })
198 .Case([](MatrixType type) { return type.getColumnType(); })
199 .Case([index](StructType type) { return type.getElementType(index); })
200 .DefaultUnreachable("Invalid composite type");
201}
202
205 .Case<ArrayType, StructType, TensorArmType, VectorType>(
206 [](auto type) { return type.getNumElements(); })
207 .Case([](MatrixType type) { return type.getNumColumns(); })
208 .DefaultUnreachable("Invalid type for number of elements query");
209}
210
212 return !isa<CooperativeMatrixType, RuntimeArrayType>(*this);
213}
214
215void TypeCapabilityVisitor::addConcrete(VectorType type) {
216 add(type.getElementType());
217
218 int64_t vecSize = type.getNumElements();
219 if (vecSize == 8 || vecSize == 16) {
220 static constexpr auto cap = Capability::Vector16;
221 capabilities.push_back(cap);
222 }
223}
224
225//===----------------------------------------------------------------------===//
226// CooperativeMatrixType
227//===----------------------------------------------------------------------===//
228
230 // In the specification dimensions of the Cooperative Matrix are 32-bit
231 // integers --- the initial implementation kept those values as such. However,
232 // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
233 // as 32-bits and expose it as int64_t through `getShape`, however, this
234 // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
235 // 32-bits integers would require an extra logic and storage. So, we diverge
236 // from the spec and internally represent the dimensions as 64-bit integers,
237 // so we can easily return an `ArrayRef` from `getShape` without any extra
238 // logic. Alternatively, we could store both rows and columns (both 32-bits)
239 // and shape (64-bits), assigning rows and columns to shape whenever
240 // `getShape` is called. This would be at the cost of extra logic and storage.
241 // Note: Because `ArrayRef` is returned we cannot construct an object in
242 // `getShape` on the fly.
243 using KeyTy =
244 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
245
247 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
248 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
250 }
251
252 bool operator==(const KeyTy &key) const {
253 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
254 }
255
257 : elementType(std::get<0>(key)),
258 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
259 use(std::get<4>(key)) {}
260
262 // [#rows, #columns]
263 std::array<int64_t, 2> shape;
264 Scope scope;
265 CooperativeMatrixUseKHR use;
266};
267
269 uint32_t rows,
270 uint32_t columns, Scope scope,
271 CooperativeMatrixUseKHR use) {
272 return Base::get(elementType.getContext(), elementType, rows, columns, scope,
273 use);
274}
275
277 return getImpl()->elementType;
278}
279
281 assert(getImpl()->shape[0] != ShapedType::kDynamic);
282 return static_cast<uint32_t>(getImpl()->shape[0]);
283}
284
286 assert(getImpl()->shape[1] != ShapedType::kDynamic);
287 return static_cast<uint32_t>(getImpl()->shape[1]);
288}
289
293
294Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
295
296CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
297 return getImpl()->use;
298}
299
300void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
301 add(type.getElementType());
302 static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
303 extensions.push_back(ext);
304}
305
306void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
307 add(type.getElementType());
308 static constexpr auto caps = Capability::CooperativeMatrixKHR;
309 capabilities.push_back(caps);
310}
311
312//===----------------------------------------------------------------------===//
313// ImageType
314//===----------------------------------------------------------------------===//
315
/// Number of bits reserved to encode a value of enum type `T`. The primary
/// template yields 0; the explicit specializations that follow provide the
/// real widths for the image-related enums.
template <typename T>
static constexpr unsigned getNumBits() {
  constexpr unsigned noBits = 0;
  return noBits;
}
320template <>
321constexpr unsigned getNumBits<Dim>() {
322 static_assert((1 << 3) > getMaxEnumValForDim(),
323 "Not enough bits to encode Dim value");
324 return 3;
325}
326template <>
327constexpr unsigned getNumBits<ImageDepthInfo>() {
328 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
329 "Not enough bits to encode ImageDepthInfo value");
330 return 2;
331}
332template <>
333constexpr unsigned getNumBits<ImageArrayedInfo>() {
334 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
335 "Not enough bits to encode ImageArrayedInfo value");
336 return 1;
337}
338template <>
339constexpr unsigned getNumBits<ImageSamplingInfo>() {
340 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
341 "Not enough bits to encode ImageSamplingInfo value");
342 return 1;
343}
344template <>
346 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
347 "Not enough bits to encode ImageSamplerUseInfo value");
348 return 2;
349}
350template <>
351constexpr unsigned getNumBits<ImageFormat>() {
352 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
353 "Not enough bits to encode ImageFormat value");
354 return 6;
355}
356
358public:
359 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
360 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
361
363 const KeyTy &key) {
364 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
365 }
366
367 bool operator==(const KeyTy &key) const {
370 }
371
373 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
374 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
375 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
376 format(std::get<6>(key)) {}
377
385};
386
388ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
389 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
390 value) {
391 return Base::get(std::get<0>(value).getContext(), value);
392}
393
394Type ImageType::getElementType() const { return getImpl()->elementType; }
395
396Dim ImageType::getDim() const { return getImpl()->dim; }
397
398ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
399
400ImageArrayedInfo ImageType::getArrayedInfo() const {
401 return getImpl()->arrayedInfo;
402}
403
404ImageSamplingInfo ImageType::getSamplingInfo() const {
405 return getImpl()->samplingInfo;
406}
407
408ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
409 return getImpl()->samplerUseInfo;
410}
411
412ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
413
414void TypeCapabilityVisitor::addConcrete(ImageType type) {
415 if (auto dimCaps = spirv::getCapabilities(type.getDim()))
416 capabilities.push_back(*dimCaps);
417
418 if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
419 capabilities.push_back(*fmtCaps);
420
421 add(type.getElementType());
422}
423
424//===----------------------------------------------------------------------===//
425// PointerType
426//===----------------------------------------------------------------------===//
427
429 // (Type, StorageClass) as the key: Type stored in this struct, and
430 // StorageClass stored as TypeStorage's subclass data.
431 using KeyTy = std::pair<Type, StorageClass>;
432
434 const KeyTy &key) {
435 return new (allocator.allocate<PointerTypeStorage>())
437 }
438
439 bool operator==(const KeyTy &key) const {
440 return key == KeyTy(pointeeType, storageClass);
441 }
442
444 : pointeeType(key.first), storageClass(key.second) {}
445
447 StorageClass storageClass;
448};
449
450PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
451 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
452}
453
454Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
455
456StorageClass PointerType::getStorageClass() const {
457 return getImpl()->storageClass;
458}
459
460void TypeExtensionVisitor::addConcrete(PointerType type) {
461 // Use this pointer type's storage class because this pointer indicates we are
462 // using the pointee type in that specific storage class.
463 std::optional<StorageClass> oldStorageClass = storage;
464 storage = type.getStorageClass();
465 add(type.getPointeeType());
466 storage = oldStorageClass;
467
468 if (auto scExts = spirv::getExtensions(type.getStorageClass()))
469 extensions.push_back(*scExts);
470}
471
472void TypeCapabilityVisitor::addConcrete(PointerType type) {
473 // Use this pointer type's storage class because this pointer indicates we are
474 // using the pointee type in that specific storage class.
475 std::optional<StorageClass> oldStorageClass = storage;
476 storage = type.getStorageClass();
477 add(type.getPointeeType());
478 storage = oldStorageClass;
479
480 if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
481 capabilities.push_back(*scCaps);
482}
483
484//===----------------------------------------------------------------------===//
485// RuntimeArrayType
486//===----------------------------------------------------------------------===//
487
489 using KeyTy = std::pair<Type, unsigned>;
490
492 const KeyTy &key) {
493 return new (allocator.allocate<RuntimeArrayTypeStorage>())
495 }
496
497 bool operator==(const KeyTy &key) const {
498 return key == KeyTy(elementType, stride);
499 }
500
502 : elementType(key.first), stride(key.second) {}
503
505 unsigned stride;
506};
507
509 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
510}
511
512RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
513 return Base::get(elementType.getContext(), elementType, stride);
514}
515
516Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
517
518unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
519
520void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
521 add(type.getElementType());
522 static constexpr auto cap = Capability::Shader;
523 capabilities.push_back(cap);
524}
525
526//===----------------------------------------------------------------------===//
527// ScalarType
528//===----------------------------------------------------------------------===//
529
531 if (auto floatType = dyn_cast<FloatType>(type)) {
532 return isValid(floatType);
533 }
534 if (auto intType = dyn_cast<IntegerType>(type)) {
535 return isValid(intType);
536 }
537 return false;
538}
539
540bool ScalarType::isValid(FloatType type) {
541 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
542}
543
544bool ScalarType::isValid(IntegerType type) {
545 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
546}
547
548void TypeExtensionVisitor::addConcrete(ScalarType type) {
549 if (isa<BFloat16Type>(type)) {
550 static constexpr auto ext = Extension::SPV_KHR_bfloat16;
551 extensions.push_back(ext);
552 }
553
554 // 8- or 16-bit integer/floating-point numbers will require extra extensions
555 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
556 // SPV_KHR_8bit_storage for more details.
557 if (!storage)
558 return;
559
560 switch (*storage) {
561 case StorageClass::PushConstant:
562 case StorageClass::StorageBuffer:
563 case StorageClass::Uniform:
564 if (type.getIntOrFloatBitWidth() == 8) {
565 static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
566 extensions.push_back(ext);
567 }
568 [[fallthrough]];
569 case StorageClass::Input:
570 case StorageClass::Output:
571 if (type.getIntOrFloatBitWidth() == 16) {
572 static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
573 extensions.push_back(ext);
574 }
575 break;
576 default:
577 break;
578 }
579}
580
581void TypeCapabilityVisitor::addConcrete(ScalarType type) {
582 unsigned bitwidth = type.getIntOrFloatBitWidth();
583
584 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
585 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
586 // SPV_KHR_8bit_storage for more details.
587
588#define STORAGE_CASE(storage, cap8, cap16) \
589 case StorageClass::storage: { \
590 if (bitwidth == 8) { \
591 static constexpr auto cap = Capability::cap8; \
592 capabilities.push_back(cap); \
593 return; \
594 } \
595 if (bitwidth == 16) { \
596 static constexpr auto cap = Capability::cap16; \
597 capabilities.push_back(cap); \
598 return; \
599 } \
600 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
601 /* storage classes. Fall through to the next section. */ \
602 } break
603
604 // This part only handles the cases where special bitwidths appearing in
605 // interface storage classes.
606 if (storage) {
607 switch (*storage) {
608 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
609 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
610 StorageBuffer16BitAccess);
611 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
612 StorageUniform16);
613 case StorageClass::Input:
614 case StorageClass::Output: {
615 if (bitwidth == 16) {
616 static constexpr auto cap = Capability::StorageInputOutput16;
617 capabilities.push_back(cap);
618 return;
619 }
620 break;
621 }
622 default:
623 break;
624 }
625 }
626#undef STORAGE_CASE
627
628 // For other non-interface storage classes, require a different set of
629 // capabilities for special bitwidths.
630
631#define WIDTH_CASE(type, width) \
632 case width: { \
633 static constexpr auto cap = Capability::type##width; \
634 capabilities.push_back(cap); \
635 } break
636
637 if (auto intType = dyn_cast<IntegerType>(type)) {
638 switch (bitwidth) {
639 WIDTH_CASE(Int, 8);
640 WIDTH_CASE(Int, 16);
641 WIDTH_CASE(Int, 64);
642 case 1:
643 case 32:
644 break;
645 default:
646 llvm_unreachable("invalid bitwidth to getCapabilities");
647 }
648 } else {
649 assert(isa<FloatType>(type));
650 switch (bitwidth) {
651 case 16: {
652 if (isa<BFloat16Type>(type)) {
653 static constexpr auto cap = Capability::BFloat16TypeKHR;
654 capabilities.push_back(cap);
655 } else {
656 static constexpr auto cap = Capability::Float16;
657 capabilities.push_back(cap);
658 }
659 break;
660 }
661 WIDTH_CASE(Float, 64);
662 case 32:
663 break;
664 default:
665 llvm_unreachable("invalid bitwidth to getCapabilities");
666 }
667 }
668
669#undef WIDTH_CASE
670}
671
672//===----------------------------------------------------------------------===//
673// SPIRVType
674//===----------------------------------------------------------------------===//
675
677 // Allow SPIR-V dialect types
678 if (isa<SPIRVDialect>(type.getDialect()))
679 return true;
680 if (isa<ScalarType>(type))
681 return true;
682 if (auto vectorType = dyn_cast<VectorType>(type))
683 return CompositeType::isValid(vectorType);
684 if (auto tensorArmType = dyn_cast<TensorArmType>(type))
685 return isa<ScalarType>(tensorArmType.getElementType());
686 return false;
687}
688
690 return isIntOrFloat() || isa<VectorType>(*this);
691}
692
694 std::optional<StorageClass> storage) {
695 TypeExtensionVisitor{extensions, storage}.add(*this);
696}
697
700 std::optional<StorageClass> storage) {
701 TypeCapabilityVisitor{capabilities, storage}.add(*this);
702}
703
704std::optional<int64_t> SPIRVType::getSizeInBytes() {
706 .Case([](ScalarType type) -> std::optional<int64_t> {
707 // According to the SPIR-V spec:
708 // "There is no physical size or bit pattern defined for values with
709 // boolean type. If they are stored (in conjunction with OpVariable),
710 // they can only be used with logical addressing operations, not
711 // physical, and only with non-externally visible shader Storage
712 // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
713 // Output."
714 int64_t bitWidth = type.getIntOrFloatBitWidth();
715 if (bitWidth == 1)
716 return std::nullopt;
717 return bitWidth / 8;
718 })
719 .Case([](ArrayType type) -> std::optional<int64_t> {
720 // Since array type may have an explicit stride declaration (in bytes),
721 // we also include it in the calculation.
722 auto elementType = cast<SPIRVType>(type.getElementType());
723 if (std::optional<int64_t> size = elementType.getSizeInBytes())
724 return (*size + type.getArrayStride()) * type.getNumElements();
725 return std::nullopt;
726 })
727 .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
728 if (std::optional<int64_t> elementSize =
729 cast<ScalarType>(type.getElementType()).getSizeInBytes())
730 return *elementSize * type.getNumElements();
731 return std::nullopt;
732 })
733 .Default(std::nullopt);
734}
735
736//===----------------------------------------------------------------------===//
737// SampledImageType
738//===----------------------------------------------------------------------===//
740 using KeyTy = Type;
741
743
744 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
745
747 const KeyTy &key) {
748 return new (allocator.allocate<SampledImageTypeStorage>())
750 }
751
753};
754
756 return Base::get(imageType.getContext(), imageType);
757}
758
764
765Type SampledImageType::getImageType() const { return getImpl()->imageType; }
766
767LogicalResult
769 Type imageType) {
770 auto image = dyn_cast<ImageType>(imageType);
771 if (!image)
772 return emitError() << "expected image type";
773
774 // As per SPIR-V spec: "It [ImageType] must not have a Dim of SubpassData.
775 // Additionally, starting with version 1.6, it must not have a Dim of Buffer.
776 // ("3.3.6. Type-Declaration Instructions")
777 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
778 return emitError() << "Dim must not be SubpassData or Buffer";
779
780 return success();
781}
782
783//===----------------------------------------------------------------------===//
784// StructType
785//===----------------------------------------------------------------------===//
786
787/// Type storage for SPIR-V structure types:
788///
789/// Structures are uniqued using:
790/// - for identified structs:
791/// - a string identifier;
792/// - for literal structs:
793/// - a list of member types;
794/// - a list of member offset info;
795/// - a list of member decoration info;
796/// - a list of struct decoration info.
797///
798/// Identified structures only have a mutable component consisting of:
799/// - a list of member types;
800/// - a list of member offset info;
801/// - a list of member decoration info;
802/// - a list of struct decoration info.
804 /// Construct a storage object for an identified struct type. A struct type
805 /// associated with such storage must call StructType::trySetBody(...) later
806 /// in order to mutate the storage object providing the actual content.
812
813 /// Construct a storage object for a literal struct type. A struct type
814 /// associated with such storage is immutable.
826
827 /// A storage key is divided into 2 parts:
828 /// - for identified structs:
829 /// - a StringRef representing the struct identifier;
830 /// - for literal structs:
831 /// - an ArrayRef<Type> for member types;
832 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
833 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
834 /// info;
835 /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
836 /// info.
837 ///
838 /// An identified struct type is uniqued only by the first part (field 0)
839 /// of the key.
840 ///
841 /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
842 /// and 4) of the key. The identifier field (field 0) must be empty.
843 using KeyTy =
844 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
847
848 /// For identified structs, return true if the given key contains the same
849 /// identifier.
850 ///
851 /// For literal structs, return true if the given key contains a matching list
852 /// of member types + offset info + decoration info.
853 bool operator==(const KeyTy &key) const {
854 if (isIdentified()) {
855 // Identified types are uniqued by their identifier.
856 return getIdentifier() == std::get<0>(key);
857 }
858
859 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
861 }
862
863 /// If the given key contains a non-empty identifier, this method constructs
864 /// an identified struct and leaves the rest of the struct type data to be set
865 /// through a later call to StructType::trySetBody(...).
866 ///
867 /// If, on the other hand, the key contains an empty identifier, a literal
868 /// struct is constructed using the other fields of the key.
870 const KeyTy &key) {
871 StringRef keyIdentifier = std::get<0>(key);
872
873 if (!keyIdentifier.empty()) {
874 StringRef identifier = allocator.copyInto(keyIdentifier);
875
876 // Identified StructType body/members will be set through trySetBody(...)
877 // later.
878 return new (allocator.allocate<StructTypeStorage>())
880 }
881
882 ArrayRef<Type> keyTypes = std::get<1>(key);
883
884 // Copy the member type and layout information into the bump pointer
885 const Type *typesList = nullptr;
886 if (!keyTypes.empty()) {
887 typesList = allocator.copyInto(keyTypes).data();
888 }
889
890 const StructType::OffsetInfo *offsetInfoList = nullptr;
891 if (!std::get<2>(key).empty()) {
892 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
893 assert(keyOffsetInfo.size() == keyTypes.size() &&
894 "size of offset information must be same as the size of number of "
895 "elements");
896 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
897 }
898
899 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
900 unsigned numMemberDecorations = 0;
901 if (!std::get<3>(key).empty()) {
902 auto keyMemberDecorations = std::get<3>(key);
903 numMemberDecorations = keyMemberDecorations.size();
904 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
905 }
906
907 const StructType::StructDecorationInfo *structDecorationList = nullptr;
908 unsigned numStructDecorations = 0;
909 if (!std::get<4>(key).empty()) {
910 auto keyStructDecorations = std::get<4>(key);
911 numStructDecorations = keyStructDecorations.size();
912 structDecorationList = allocator.copyInto(keyStructDecorations).data();
913 }
914
915 return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
916 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
917 memberDecorationList, numStructDecorations, structDecorationList);
918 }
919
923
930
938
945
946 StringRef getIdentifier() const { return identifier; }
947
948 bool isIdentified() const { return !identifier.empty(); }
949
950 /// Sets the struct type content for identified structs. Calling this method
951 /// is only valid for identified structs.
952 ///
953 /// Fails under the following conditions:
954 /// - If called for a literal struct;
955 /// - If called for an identified struct whose body was set before (through a
956 /// call to this method) but with different contents from the passed
957 /// arguments.
958 LogicalResult
959 mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
960 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
961 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
962 ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
963 if (!isIdentified())
964 return failure();
965
966 if (memberTypesAndIsBodySet.getInt() &&
967 (getMemberTypes() != structMemberTypes ||
968 getOffsetInfo() != structOffsetInfo ||
969 getMemberDecorationsInfo() != structMemberDecorationInfo ||
970 getStructDecorationsInfo() != structDecorationInfo))
971 return failure();
972
973 memberTypesAndIsBodySet.setInt(true);
974 numMembers = structMemberTypes.size();
975
976 // Copy the member type and layout information into the bump pointer.
977 if (!structMemberTypes.empty())
978 memberTypesAndIsBodySet.setPointer(
979 allocator.copyInto(structMemberTypes).data());
980
981 if (!structOffsetInfo.empty()) {
982 assert(structOffsetInfo.size() == structMemberTypes.size() &&
983 "size of offset information must be same as the size of number of "
984 "elements");
985 offsetInfo = allocator.copyInto(structOffsetInfo).data();
986 }
987
988 if (!structMemberDecorationInfo.empty()) {
989 numMemberDecorations = structMemberDecorationInfo.size();
991 allocator.copyInto(structMemberDecorationInfo).data();
992 }
993
994 if (!structDecorationInfo.empty()) {
995 numStructDecorations = structDecorationInfo.size();
996 structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
997 }
998
999 return success();
1000 }
1001
1002 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1004 unsigned numMembers;
1009 StringRef identifier;
1010};
1011
1017 assert(!memberTypes.empty() && "Struct needs at least one member type");
1018 // Sort the decorations.
1020 memberDecorations);
1021 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1022 sortedMemberDecorations.end());
1024 structDecorations);
1025 llvm::array_pod_sort(sortedStructDecorations.begin(),
1026 sortedStructDecorations.end());
1027
1028 return Base::get(memberTypes.vec().front().getContext(),
1029 /*identifier=*/StringRef(), memberTypes, offsetInfo,
1030 sortedMemberDecorations, sortedStructDecorations);
1031}
1032
1034 StringRef identifier) {
1035 assert(!identifier.empty() &&
1036 "StructType identifier must be non-empty string");
1037
1038 return Base::get(context, identifier, ArrayRef<Type>(),
1042}
1043
1044StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1045 StructType newStructType = Base::get(
1046 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1049 // Set an empty body in case this is a identified struct.
1050 if (newStructType.isIdentified() &&
1051 failed(newStructType.trySetBody(
1055 return StructType();
1056
1057 return newStructType;
1058}
1059
1060StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1061
1062bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1063
1064unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1065
1067 assert(getNumElements() > index && "member index out of range");
1068 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1069}
1070
1072 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1073 getNumElements());
1074}
1075
1076bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1077
1078bool StructType::hasDecoration(spirv::Decoration decoration) const {
1080 getImpl()->getStructDecorationsInfo())
1081 if (info.decoration == decoration)
1082 return true;
1083
1084 return false;
1085}
1086
1087uint64_t StructType::getMemberOffset(unsigned index) const {
1088 assert(getNumElements() > index && "member index out of range");
1089 return getImpl()->offsetInfo[index];
1090}
1091
1094 const {
1095 memberDecorations.clear();
1096 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1097 memberDecorations.append(implMemberDecorations.begin(),
1098 implMemberDecorations.end());
1099}
1100
1102 unsigned index,
1104 assert(getNumElements() > index && "member index out of range");
1105 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1106 decorationsInfo.clear();
1107 for (const auto &memberDecoration : memberDecorations) {
1108 if (memberDecoration.memberIndex == index) {
1109 decorationsInfo.push_back(memberDecoration);
1110 }
1111 if (memberDecoration.memberIndex > index) {
1112 // Early exit since the decorations are stored sorted.
1113 return;
1114 }
1115 }
1116}
1117
1120 const {
1121 structDecorations.clear();
1122 auto implDecorations = getImpl()->getStructDecorationsInfo();
1123 structDecorations.append(implDecorations.begin(), implDecorations.end());
1124}
1125
1126LogicalResult
1128 ArrayRef<OffsetInfo> offsetInfo,
1129 ArrayRef<MemberDecorationInfo> memberDecorations,
1130 ArrayRef<StructDecorationInfo> structDecorations) {
1131 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1132 structDecorations);
1133}
1134
1135llvm::hash_code spirv::hash_value(
1136 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1137 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1138 memberDecorationInfo.decoration);
1139}
1140
1141llvm::hash_code spirv::hash_value(
1142 const StructType::StructDecorationInfo &structDecorationInfo) {
1143 return llvm::hash_value(structDecorationInfo.decoration);
1144}
1145
1146//===----------------------------------------------------------------------===//
1147// MatrixType
1148//===----------------------------------------------------------------------===//
1149
1153
1154 using KeyTy = std::tuple<Type, uint32_t>;
1155
1157 const KeyTy &key) {
1158
1159 // Initialize the memory using placement new.
1160 return new (allocator.allocate<MatrixTypeStorage>())
1161 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1162 }
1163
1164 bool operator==(const KeyTy &key) const {
1165 return key == KeyTy(columnType, columnCount);
1166 }
1167
1169 const uint32_t columnCount;
1170};
1171
1172MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1173 return Base::get(columnType.getContext(), columnType, columnCount);
1174}
1175
1177 Type columnType, uint32_t columnCount) {
1178 return Base::getChecked(emitError, columnType.getContext(), columnType,
1179 columnCount);
1180}
1181
1182LogicalResult
1184 Type columnType, uint32_t columnCount) {
1185 if (columnCount < 2 || columnCount > 4)
1186 return emitError() << "matrix can have 2, 3, or 4 columns only";
1187
1188 if (!isValidColumnType(columnType))
1189 return emitError() << "matrix columns must be vectors of floats";
1190
1191 /// The underlying vectors (columns) must be of size 2, 3, or 4
1192 ArrayRef<int64_t> columnShape = cast<VectorType>(columnType).getShape();
1193 if (columnShape.size() != 1)
1194 return emitError() << "matrix columns must be 1D vectors";
1195
1196 if (columnShape[0] < 2 || columnShape[0] > 4)
1197 return emitError() << "matrix columns must be of size 2, 3, or 4";
1198
1199 return success();
1200}
1201
1202/// Returns true if the matrix elements are vectors of float elements
1204 if (auto vectorType = dyn_cast<VectorType>(columnType)) {
1205 if (isa<FloatType>(vectorType.getElementType()))
1206 return true;
1207 }
1208 return false;
1209}
1210
1211Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1212
1214 return cast<VectorType>(getImpl()->columnType).getElementType();
1215}
1216
1217unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1218
1219unsigned MatrixType::getNumRows() const {
1220 return cast<VectorType>(getImpl()->columnType).getShape()[0];
1221}
1222
1224 return (getImpl()->columnCount) * getNumRows();
1225}
1226
1227void TypeCapabilityVisitor::addConcrete(MatrixType type) {
1228 add(type.getColumnType());
1229 static constexpr auto cap = Capability::Matrix;
1230 capabilities.push_back(cap);
1231}
1232
1233//===----------------------------------------------------------------------===//
1234// TensorArmType
1235//===----------------------------------------------------------------------===//
1236
1238 using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1239
1241 const KeyTy &key) {
1242 auto [shape, elementType] = key;
1243 shape = allocator.copyInto(shape);
1244 return new (allocator.allocate<TensorArmTypeStorage>())
1246 }
1247
1248 static llvm::hash_code hashKey(const KeyTy &key) {
1249 auto [shape, elementType] = key;
1250 return llvm::hash_combine(shape, elementType);
1251 }
1252
1253 bool operator==(const KeyTy &key) const {
1254 return key == KeyTy(shape, elementType);
1255 }
1256
1259
1262};
1263
1265 return Base::get(elementType.getContext(), shape, elementType);
1266}
1267
1269 Type elementType) const {
1270 return TensorArmType::get(shape.value_or(getShape()), elementType);
1271}
1272
1273Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1275
1276void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1277 add(type.getElementType());
1278 static constexpr auto ext = Extension::SPV_ARM_tensors;
1279 extensions.push_back(ext);
1280}
1281
1282void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
1283 add(type.getElementType());
1284 static constexpr auto cap = Capability::TensorsARM;
1285 capabilities.push_back(cap);
1286}
1287
1288LogicalResult
1290 ArrayRef<int64_t> shape, Type elementType) {
1291 if (llvm::is_contained(shape, 0))
1292 return emitError() << "arm.tensor do not support dimensions = 0";
1293 if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1294 llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1295 return emitError()
1296 << "arm.tensor shape dimensions must be either fully dynamic or "
1297 "completed shaped";
1298 return success();
1299}
1300
1301//===----------------------------------------------------------------------===//
1302// SPIR-V Dialect
1303//===----------------------------------------------------------------------===//
1304
1305void SPIRVDialect::registerTypes() {
1308}
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
false
Parses a map_entries map type from a string format back into its numeric value.
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
#define add(a, b)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
TypeStorage()
This constructor is used by derived classes as part of the TypeUniquer.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
unsigned getNumRows() const
Returns the number of rows.
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
constexpr Type()=default
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:65
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType get(Type imageType)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:465
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
StorageUniquer::StorageAllocator TypeStorageAllocator
This is a utility allocator used to allocate memory for instances of derived Types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
bool operator==(const KeyTy &key) const
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
bool operator==(const KeyTy &key) const
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, uint32_t > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
StructType::OffsetInfo const * offsetInfo
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
ArrayRef< Type > getMemberTypes() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const