MLIR 23.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/ErrorHandling.h"
21
22#include <cstdint>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::spirv;
27
28namespace {
29// Helper class that collects the extensions implied by a type by visiting all
30// of its subtypes. Maintains a set of `seen` (type, storage class) pairs to
// avoid recursing forever into structs.
31//
32// Serves as the source-of-truth for type extension information. All extension
33// logic should be added to this class, while the
34// `SPIRVType::getExtensions` function should not handle extension-related logic
35// directly and only invoke `TypeExtensionVisitor::add(Type)`.
36class TypeExtensionVisitor {
37public:
// `extensions` is appended to by every visit; `storage` is the storage class
// the visited type is used in, if known.
38 TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
39 std::optional<StorageClass> storage)
40 : extensions(extensions), storage(storage) {}
41
42 // Main visitor entry point. Appends all extensions required by `type` (and
43 // its subtypes) to the vector. Records `type` as seen together with the
// current storage class, then either dispatches to the matching `addConcrete`
// overload or recurses into element/member types.
44 void add(SPIRVType type) {
45 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
46 return;
47
49 .Case<CooperativeMatrixType, ImageType, PointerType, ScalarType,
50 TensorArmType>(
51 [this](auto concreteType) { addConcrete(concreteType); })
52 .Case<ArrayType, MatrixType, RuntimeArrayType, VectorType>(
53 [this](auto concreteType) { add(concreteType.getElementType()); })
54 .Case([this](SampledImageType concreteType) {
55 add(concreteType.getImageType());
56 })
57 .Case([this](StructType concreteType) {
58 for (Type elementType : concreteType.getElementTypes())
59 add(elementType);
60 })
61 .Case<SamplerType, NamedBarrierType>([](auto) { /* no extensions */ })
62 .DefaultUnreachable("Unhandled type");
63 }
64
// Convenience overload for generic `Type`s; `cast` asserts that the type is a
// SPIRVType.
65 void add(Type type) { add(cast<SPIRVType>(type)); }
66
67private:
68 // Types that contribute extensions of their own (beyond their subtypes').
69 void addConcrete(CooperativeMatrixType type);
70 void addConcrete(ImageType type);
71 void addConcrete(PointerType type);
72 void addConcrete(ScalarType type);
73 void addConcrete(TensorArmType type);
74
// Appends one requirement listing `Es...`. The array is a function-local
// `static constexpr`, so the view stored in `extensions` never dangles.
75 template <Extension... Es>
76 void pushExts() {
77 static constexpr Extension exts[] = {Es...};
78 extensions.push_back(exts);
79 }
80
// Storage class the currently visited type is used in; visiting a PointerType
// temporarily replaces it with the pointer's storage class.
82 std::optional<StorageClass> storage;
// Already visited (type, storage class) pairs. The storage class is part of
// the key because scalars imply different extensions per storage class.
83 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
84};
85
86// Helper class that collects the capabilities implied by a type by visiting
87// all of its subtypes. Maintains a set of `seen` (type, storage class) pairs
// to avoid recursing forever into structs.
88//
89// Serves as the source-of-truth for type capability information. All capability
90// logic should be added to this class, while the
91// `SPIRVType::getCapabilities` function should not handle capability-related
92// logic directly and only invoke `TypeCapabilityVisitor::add(Type)`.
93class TypeCapabilityVisitor {
94public:
// `capabilities` is appended to by every visit; `storage` is the storage class
// the visited type is used in, if known.
95 TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
96 std::optional<StorageClass> storage)
97 : capabilities(capabilities), storage(storage) {}
98
99 // Main visitor entry point. Appends all capabilities required by `type` (and
100 // its subtypes) to the vector. Records `type` as seen together with the
// current storage class, then either dispatches to the matching `addConcrete`
// overload or recurses into element/member types.
101 void add(SPIRVType type) {
102 if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
103 return;
104
106 .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
107 RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
108 [this](auto concreteType) { addConcrete(concreteType); })
109 .Case([this](ArrayType concreteType) {
110 add(concreteType.getElementType());
111 })
112 .Case([this](SampledImageType concreteType) {
113 add(concreteType.getImageType());
114 })
115 .Case([this](StructType concreteType) {
116 for (Type elementType : concreteType.getElementTypes())
117 add(elementType);
118 })
119 .Case([](SamplerType) { /* no capabilities */ })
120 .Case(
121 [this](NamedBarrierType) { pushCaps<Capability::NamedBarrier>(); })
122 .DefaultUnreachable("Unhandled type");
123 }
124
// Convenience overload for generic `Type`s; `cast` asserts that the type is a
// SPIRVType.
125 void add(Type type) { add(cast<SPIRVType>(type)); }
126
127private:
128 // Types that contribute capabilities of their own (beyond their subtypes').
129 void addConcrete(CooperativeMatrixType type);
130 void addConcrete(ImageType type);
131 void addConcrete(MatrixType type);
132 void addConcrete(PointerType type);
133 void addConcrete(RuntimeArrayType type);
134 void addConcrete(ScalarType type);
135 void addConcrete(TensorArmType type);
136 void addConcrete(VectorType type);
137
// Appends one requirement listing `Cs...`. The array is a function-local
// `static constexpr`, so the view stored in `capabilities` never dangles.
138 template <Capability... Cs>
139 void pushCaps() {
140 static constexpr Capability caps[] = {Cs...};
141 capabilities.push_back(caps);
142 }
143
// Storage class the currently visited type is used in; visiting a PointerType
// temporarily replaces it with the pointer's storage class.
145 std::optional<StorageClass> storage;
// Already visited (type, storage class) pairs. The storage class is part of
// the key because scalars imply different capabilities per storage class.
146 llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
147};
148
149} // namespace
150
151//===----------------------------------------------------------------------===//
152// ArrayType
153//===----------------------------------------------------------------------===//
154
// Uniquing key: (element type, element count, stride), in the order the
// constructor reads the tuple fields.
156 using KeyTy = std::tuple<Type, unsigned, unsigned>;
157
159 const KeyTy &key) {
160 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
161 }
162
// Two array types are identical iff element type, count and stride all match.
163 bool operator==(const KeyTy &key) const {
164 return key == KeyTy(elementType, elementCount, stride);
165 }
166
168 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
169 stride(std::get<2>(key)) {}
170
172 unsigned elementCount;
// 0 when the array was created without an explicit stride.
173 unsigned stride;
174};
175
176ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
177 assert(elementCount && "ArrayType needs at least one element");
178 return Base::get(elementType.getContext(), elementType, elementCount,
179 /*stride=*/0);
180}
181
182ArrayType ArrayType::get(Type elementType, unsigned elementCount,
183 unsigned stride) {
184 assert(elementCount && "ArrayType needs at least one element");
185 return Base::get(elementType.getContext(), elementType, elementCount, stride);
186}
187
188unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
189
190Type ArrayType::getElementType() const { return getImpl()->elementType; }
191
192unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
193
194//===----------------------------------------------------------------------===//
195// CompositeType
196//===----------------------------------------------------------------------===//
197
// Builtin vectors only count as SPIR-V composites when isValid(VectorType)
// accepts them.
199 if (auto vectorType = dyn_cast<VectorType>(type))
200 return isValid(vectorType);
203 type);
204}
205
206bool CompositeType::isValid(VectorType type) {
207 return type.getRank() == 1 &&
208 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
209 (isa<ScalarType>(type.getElementType()) ||
210 isa<PointerType>(type.getElementType()));
211}
212
// Homogeneous composites ignore `index`; only structs use it to pick a member.
// Matrices yield their column type.
214 return TypeSwitch<Type, Type>(*this)
216 TensorArmType>([](auto type) { return type.getElementType(); })
217 .Case([](MatrixType type) { return type.getColumnType(); })
218 .Case([index](StructType type) { return type.getElementType(index); })
219 .DefaultUnreachable("Invalid composite type");
220}
221
// Matrices report their column count as their number of elements.
224 .Case<ArrayType, StructType, TensorArmType, VectorType>(
225 [](auto type) { return type.getNumElements(); })
226 .Case([](MatrixType type) { return type.getNumColumns(); })
227 .DefaultUnreachable("Invalid type for number of elements query");
228}
229
// Cooperative matrices and runtime arrays are the composites whose element
// count is not known at compile time.
231 return !isa<CooperativeMatrixType, RuntimeArrayType>(*this);
232}
233
234void TypeCapabilityVisitor::addConcrete(VectorType type) {
235 add(type.getElementType());
236
237 int64_t vecSize = type.getNumElements();
238 if (vecSize == 8 || vecSize == 16)
239 pushCaps<Capability::Vector16>();
240}
241
242//===----------------------------------------------------------------------===//
243// CooperativeMatrixType
244//===----------------------------------------------------------------------===//
245
247 // In the specification, the dimensions of a Cooperative Matrix are 32-bit
248 // integers, and the initial implementation kept them as such. However,
249 // `ShapedType` expects the shape to be `int64_t`. We could keep the shape
250 // as 32-bit values and expose it as int64_t through `getShape`; however, that
251 // method returns an `ArrayRef`, so returning an `ArrayRef<int64_t>` backed by
252 // two 32-bit integers would require extra logic and storage. So we diverge
253 // from the spec and internally represent the dimensions as 64-bit integers,
254 // so we can easily return an `ArrayRef` from `getShape` without any extra
255 // logic. Alternatively, we could store both rows and columns (both 32-bit)
256 // and the shape (64-bit), assigning rows and columns to shape whenever
257 // `getShape` is called. This would cost extra logic and storage.
258 // Note: because an `ArrayRef` is returned, we cannot construct an object in
259 // `getShape` on the fly.
// Key: (element type, #rows, #columns, scope, use).
260 using KeyTy =
261 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
262
264 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
265 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
267 }
268
269 bool operator==(const KeyTy &key) const {
270 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
271 }
272
274 : elementType(std::get<0>(key)),
275 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
276 use(std::get<4>(key)) {}
277
279 // [#rows, #columns]
280 std::array<int64_t, 2> shape;
281 Scope scope;
282 CooperativeMatrixUseKHR use;
283};
284
286 uint32_t rows,
287 uint32_t columns, Scope scope,
288 CooperativeMatrixUseKHR use) {
// `rows` and `columns` are widened to int64_t, matching the storage key.
289 return Base::get(elementType.getContext(), elementType, rows, columns, scope,
290 use);
291}
292
294 return getImpl()->elementType;
295}
296
// The shape is stored as int64_t; narrow back to the 32-bit row count.
298 assert(getImpl()->shape[0] != ShapedType::kDynamic);
299 return static_cast<uint32_t>(getImpl()->shape[0]);
300}
301
// The shape is stored as int64_t; narrow back to the 32-bit column count.
303 assert(getImpl()->shape[1] != ShapedType::kDynamic);
304 return static_cast<uint32_t>(getImpl()->shape[1]);
305}
306
310
311Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
312
313CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
314 return getImpl()->use;
315}
316
317void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
318 add(type.getElementType());
319 pushExts<Extension::SPV_KHR_cooperative_matrix>();
320}
321
322void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
323 Type elementType = type.getElementType();
324 add(elementType);
325 pushCaps<Capability::CooperativeMatrixKHR>();
326 if (elementType.isBF16())
327 pushCaps<Capability::BFloat16CooperativeMatrixKHR>();
328 if (elementType.isF8E4M3FN() || elementType.isF8E5M2())
329 pushCaps<Capability::Float8CooperativeMatrixEXT>();
330}
331
332//===----------------------------------------------------------------------===//
333// ImageType
334//===----------------------------------------------------------------------===//
335
// Number of bits used to encode a value of enum type T. The primary template
// returns 0 for types without a specialization. Each specialization checks at
// compile time that (1 << bits) exceeds the enum's maximum value, i.e. every
// enumerator fits in the chosen width.
336template <typename T>
337static constexpr unsigned getNumBits() {
338 return 0;
339}
340template <>
341constexpr unsigned getNumBits<Dim>() {
342 static_assert((1 << 3) > getMaxEnumValForDim(),
343 "Not enough bits to encode Dim value");
344 return 3;
345}
346template <>
347constexpr unsigned getNumBits<ImageDepthInfo>() {
348 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
349 "Not enough bits to encode ImageDepthInfo value");
350 return 2;
351}
352template <>
353constexpr unsigned getNumBits<ImageArrayedInfo>() {
354 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
355 "Not enough bits to encode ImageArrayedInfo value");
356 return 1;
357}
358template <>
359constexpr unsigned getNumBits<ImageSamplingInfo>() {
360 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
361 "Not enough bits to encode ImageSamplingInfo value");
362 return 1;
363}
364template <>
366 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
367 "Not enough bits to encode ImageSamplerUseInfo value");
368 return 2;
369}
370template <>
371constexpr unsigned getNumBits<ImageFormat>() {
372 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
373 "Not enough bits to encode ImageFormat value");
374 return 6;
375}
376
378public:
// Key fields, in the order the constructor reads them: element (sampled)
// type, dim, depth, arrayed, sampling, sampler use, image format.
379 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
380 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
381
383 const KeyTy &key) {
384 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
385 }
386
387 bool operator==(const KeyTy &key) const {
390 }
391
393 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
394 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
395 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
396 format(std::get<6>(key)) {}
397
405};
406
408ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
409 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
410 value) {
// The context is taken from the element type (tuple field 0).
411 return Base::get(std::get<0>(value).getContext(), value);
412}
413
414Type ImageType::getElementType() const { return getImpl()->elementType; }
415
416Dim ImageType::getDim() const { return getImpl()->dim; }
417
418ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
419
420ImageArrayedInfo ImageType::getArrayedInfo() const {
421 return getImpl()->arrayedInfo;
422}
423
424ImageSamplingInfo ImageType::getSamplingInfo() const {
425 return getImpl()->samplingInfo;
426}
427
428ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
429 return getImpl()->samplerUseInfo;
430}
431
432ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
433
434void TypeExtensionVisitor::addConcrete(ImageType type) {
435 // OpTypeImage with a 64-bit integer Sampled Type requires the
436 // SPV_EXT_shader_image_int64 extension (companion to Int64ImageEXT).
437 if (auto intTy = dyn_cast<IntegerType>(type.getElementType());
438 intTy && intTy.getWidth() == 64)
439 pushExts<Extension::SPV_EXT_shader_image_int64>();
440 add(type.getElementType());
441}
442
443void TypeCapabilityVisitor::addConcrete(ImageType type) {
444 // Capability requirements for OpTypeImage are determined jointly by Dim,
445 // Sampled, MS, and Arrayed - see the SPIR-V spec's "Capabilities" column on
446 // OpTypeImage.
447 Dim dim = type.getDim();
448 bool isMultisampled =
449 type.getSamplingInfo() == ImageSamplingInfo::MultiSampled;
450 bool isArrayed = type.getArrayedInfo() == ImageArrayedInfo::Arrayed;
451 ImageSamplerUseInfo sampler = type.getSamplerUseInfo();
452 bool noSampler = sampler == ImageSamplerUseInfo::NoSampler;
453 bool needSampler = sampler == ImageSamplerUseInfo::NeedSampler;
454
455 switch (dim) {
456 case Dim::Dim1D:
457 if (needSampler)
458 pushCaps<Capability::Sampled1D>();
459 else if (noSampler)
460 pushCaps<Capability::Image1D>();
461 else
462 pushCaps<Capability::Image1D, Capability::Sampled1D>();
463 break;
464 case Dim::Dim2D:
465 if (isMultisampled && noSampler)
466 pushCaps<Capability::StorageImageMultisample>();
467 if (isMultisampled && isArrayed)
468 pushCaps<Capability::ImageMSArray>();
469 break;
470 case Dim::Dim3D:
471 break;
472 case Dim::Cube:
473 pushCaps<Capability::Shader>();
474 if (isArrayed)
475 pushCaps<Capability::ImageCubeArray>();
476 break;
477 case Dim::Rect:
478 pushCaps<Capability::ImageRect, Capability::SampledRect>();
479 break;
480 case Dim::Buffer:
481 if (needSampler)
482 pushCaps<Capability::SampledBuffer>();
483 else if (noSampler)
484 pushCaps<Capability::ImageBuffer>();
485 else
486 pushCaps<Capability::ImageBuffer, Capability::SampledBuffer>();
487 break;
488 case Dim::SubpassData:
489 pushCaps<Capability::InputAttachment>();
490 break;
491 }
492
493 if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
494 capabilities.push_back(*fmtCaps);
495
496 // OpTypeImage with a 64-bit integer Sampled Type requires Int64ImageEXT.
497 if (auto intTy = dyn_cast<IntegerType>(type.getElementType());
498 intTy && intTy.getWidth() == 64)
499 pushCaps<Capability::Int64ImageEXT>();
500
501 add(type.getElementType());
502}
503
504//===----------------------------------------------------------------------===//
505// PointerType
506//===----------------------------------------------------------------------===//
507
509 // (pointee Type, StorageClass) as the key: both are stored directly in this
510 // struct, in the `pointeeType` and `storageClass` members.
511 using KeyTy = std::pair<Type, StorageClass>;
512
514 const KeyTy &key) {
515 return new (allocator.allocate<PointerTypeStorage>())
517 }
518
// Pointer types are equal iff both the pointee type and storage class match.
519 bool operator==(const KeyTy &key) const {
520 return key == KeyTy(pointeeType, storageClass);
521 }
522
524 : pointeeType(key.first), storageClass(key.second) {}
525
527 StorageClass storageClass;
528};
529
530PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
531 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
532}
533
534Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
535
536StorageClass PointerType::getStorageClass() const {
537 return getImpl()->storageClass;
538}
539
540void TypeExtensionVisitor::addConcrete(PointerType type) {
541 // Use this pointer type's storage class because this pointer indicates we are
542 // using the pointee type in that specific storage class.
543 std::optional<StorageClass> oldStorageClass = storage;
544 storage = type.getStorageClass();
545 add(type.getPointeeType());
546 storage = oldStorageClass;
547
548 if (auto scExts = spirv::getExtensions(type.getStorageClass()))
549 extensions.push_back(*scExts);
550}
551
552void TypeCapabilityVisitor::addConcrete(PointerType type) {
553 // Use this pointer type's storage class because this pointer indicates we are
554 // using the pointee type in that specific storage class.
555 std::optional<StorageClass> oldStorageClass = storage;
556 storage = type.getStorageClass();
557 add(type.getPointeeType());
558 storage = oldStorageClass;
559
560 if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
561 capabilities.push_back(*scCaps);
562}
563
564//===----------------------------------------------------------------------===//
565// RuntimeArrayType
566//===----------------------------------------------------------------------===//
567
// Uniquing key: (element type, stride).
569 using KeyTy = std::pair<Type, unsigned>;
570
572 const KeyTy &key) {
573 return new (allocator.allocate<RuntimeArrayTypeStorage>())
575 }
576
577 bool operator==(const KeyTy &key) const {
578 return key == KeyTy(elementType, stride);
579 }
580
582 : elementType(key.first), stride(key.second) {}
583
// 0 when the runtime array was created without an explicit stride.
585 unsigned stride;
586};
587
// Overload without an explicit stride: stored as stride 0.
589 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
590}
591
592RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
593 return Base::get(elementType.getContext(), elementType, stride);
594}
595
596Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
597
598unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
599
600void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
601 add(type.getElementType());
602 pushCaps<Capability::Shader>();
603}
604
605//===----------------------------------------------------------------------===//
606// ScalarType
607//===----------------------------------------------------------------------===//
608
// Only builtin float and integer types can be SPIR-V scalars, and only when
// the corresponding isValid overload accepts their width/encoding.
610 if (auto floatType = dyn_cast<FloatType>(type)) {
611 return isValid(floatType);
612 }
613 if (auto intType = dyn_cast<IntegerType>(type)) {
614 return isValid(intType);
615 }
616 return false;
617}
618
619bool ScalarType::isValid(FloatType type) {
620 if (type.isF8E4M3FN() || type.isF8E5M2())
621 return true;
622 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
623}
624
625bool ScalarType::isValid(IntegerType type) {
626 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
627}
628
629void TypeExtensionVisitor::addConcrete(ScalarType type) {
630 if (type.isBF16())
631 pushExts<Extension::SPV_KHR_bfloat16>();
632
633 if (type.isF8E4M3FN() || type.isF8E5M2())
634 pushExts<Extension::SPV_EXT_float8>();
635
636 // 8- or 16-bit integer/floating-point numbers will require extra extensions
637 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
638 // SPV_KHR_8bit_storage for more details.
639 if (!storage)
640 return;
641
642 switch (*storage) {
643 case StorageClass::PushConstant:
644 case StorageClass::StorageBuffer:
645 case StorageClass::Uniform:
646 if (type.getIntOrFloatBitWidth() == 8)
647 pushExts<Extension::SPV_KHR_8bit_storage>();
648 [[fallthrough]];
649 case StorageClass::Input:
650 case StorageClass::Output:
651 if (type.getIntOrFloatBitWidth() == 16)
652 pushExts<Extension::SPV_KHR_16bit_storage>();
653 break;
654 default:
655 break;
656 }
657}
658
// Collects the capabilities a scalar needs.
// - In PushConstant, StorageBuffer and Uniform, an 8- or 16-bit scalar needs
//   only the matching storage capability, and the function returns early.
// - In Input/Output, the same early return applies to 16-bit scalars.
// - Otherwise the width-based capability is required: Int8/Int16/Int64,
//   Float8EXT, Float16 or BFloat16TypeKHR, Float64. 1-bit and 32-bit scalars
//   need nothing.
659void TypeCapabilityVisitor::addConcrete(ScalarType type) {
660 unsigned bitwidth = type.getIntOrFloatBitWidth();
661
662 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
663 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
664 // SPV_KHR_8bit_storage for more details.
665
666#define STORAGE_CASE(storage, cap8, cap16) \
667 case StorageClass::storage: { \
668 if (bitwidth == 8) { \
669 pushCaps<Capability::cap8>(); \
670 return; \
671 } \
672 if (bitwidth == 16) { \
673 pushCaps<Capability::cap16>(); \
674 return; \
675 } \
676 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
677 /* storage classes. Fall through to the next section. */ \
678 } break
679
680 // This part only handles the cases where special bitwidths appear in
681 // interface storage classes; every matching case returns early.
// NOTE(review): the early returns trigger for any 8-/16-bit scalar, so e.g. a
// bf16 in StorageBuffer only gets StorageBuffer16BitAccess (never
// BFloat16TypeKHR), and an 8-bit float there never gets Float8EXT. Confirm
// this is intended.
682 if (storage) {
683 switch (*storage) {
684 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
685 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
686 StorageBuffer16BitAccess);
687 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
688 StorageUniform16);
689 case StorageClass::Input:
690 case StorageClass::Output: {
691 if (bitwidth == 16) {
692 pushCaps<Capability::StorageInputOutput16>();
693 return;
694 }
695 break;
696 }
697 default:
698 break;
699 }
700 }
701#undef STORAGE_CASE
702
703 // For other non-interface storage classes, require a different set of
704 // capabilities for special bitwidths.
705
706#define WIDTH_CASE(type, width) \
707 case width: \
708 pushCaps<Capability::type##width>(); \
709 break
710
711 if (auto intType = dyn_cast<IntegerType>(type)) {
712 switch (bitwidth) {
713 WIDTH_CASE(Int, 8);
714 WIDTH_CASE(Int, 16);
715 WIDTH_CASE(Int, 64);
716 case 1:
717 case 32:
718 break;
719 default:
720 llvm_unreachable("invalid bitwidth to getCapabilities");
721 }
722 } else {
723 assert(isa<FloatType>(type));
724 switch (bitwidth) {
725 case 8: {
726 if (type.isF8E4M3FN() || type.isF8E5M2())
727 pushCaps<Capability::Float8EXT>();
728 else
729 llvm_unreachable("invalid 8-bit float type to getCapabilities");
730 break;
731 }
732 case 16: {
733 if (type.isBF16())
734 pushCaps<Capability::BFloat16TypeKHR>();
735 else
736 pushCaps<Capability::Float16>();
737 break;
738 }
739 WIDTH_CASE(Float, 64);
740 case 32:
741 break;
742 default:
743 llvm_unreachable("invalid bitwidth to getCapabilities");
744 }
745 }
746
747#undef WIDTH_CASE
748}
749
750//===----------------------------------------------------------------------===//
751// SPIRVType
752//===----------------------------------------------------------------------===//
753
755 // Allow all SPIR-V dialect types.
756 if (isa<SPIRVDialect>(type.getDialect()))
757 return true;
// Builtin types qualify only in restricted forms: valid scalars, valid
// vectors, and tensorArm types whose elements are scalars.
758 if (isa<ScalarType>(type))
759 return true;
760 if (auto vectorType = dyn_cast<VectorType>(type))
761 return CompositeType::isValid(vectorType);
762 if (auto tensorArmType = dyn_cast<TensorArmType>(type))
763 return isa<ScalarType>(tensorArmType.getElementType());
764 return false;
765}
766
// Any builtin vector here is already a valid SPIR-V vector, because classof
// above restricts vectors via CompositeType::isValid.
768 return isIntOrFloat() || isa<VectorType>(*this);
769}
770
772 std::optional<StorageClass> storage) {
// All extension logic lives in TypeExtensionVisitor; this only forwards.
773 TypeExtensionVisitor{extensions, storage}.add(*this);
774}
775
778 std::optional<StorageClass> storage) {
// All capability logic lives in TypeCapabilityVisitor; this only forwards.
779 TypeCapabilityVisitor{capabilities, storage}.add(*this);
780}
781
// Returns the size in bytes of scalars, arrays, vectors and tensorArm types.
// Returns std::nullopt for booleans, for arrays/vectors of elements without a
// known size, and for every other type (Default case).
782std::optional<int64_t> SPIRVType::getSizeInBytes() {
784 .Case([](ScalarType type) -> std::optional<int64_t> {
785 // According to the SPIR-V spec:
786 // "There is no physical size or bit pattern defined for values with
787 // boolean type. If they are stored (in conjunction with OpVariable),
788 // they can only be used with logical addressing operations, not
789 // physical, and only with non-externally visible shader Storage
790 // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
791 // Output."
792 int64_t bitWidth = type.getIntOrFloatBitWidth();
793 if (bitWidth == 1)
794 return std::nullopt;
795 return bitWidth / 8;
796 })
797 .Case([](ArrayType type) -> std::optional<int64_t> {
798 // Since array type may have an explicit stride declaration (in bytes),
799 // we also include it in the calculation.
// NOTE(review): the stride is added to each element's size, i.e. treated as
// per-element padding. SPIR-V's ArrayStride decoration is the byte distance
// between consecutive elements, so `stride * count` (when stride != 0) may be
// what is intended. Confirm.
800 auto elementType = cast<SPIRVType>(type.getElementType());
801 if (std::optional<int64_t> size = elementType.getSizeInBytes())
802 return (*size + type.getArrayStride()) * type.getNumElements();
803 return std::nullopt;
804 })
// NOTE(review): CompositeType::isValid(VectorType) also accepts pointer
// elements, and cast<ScalarType> below would assert on those. Confirm that
// vectors of pointers never reach this query.
805 .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
806 if (std::optional<int64_t> elementSize =
807 cast<ScalarType>(type.getElementType()).getSizeInBytes())
808 return *elementSize * type.getNumElements();
809 return std::nullopt;
810 })
811 .Default(std::nullopt);
812}
813
814//===----------------------------------------------------------------------===//
815// SampledImageType
816//===----------------------------------------------------------------------===//
// Sampled image types are uniqued solely by their underlying image type.
818 using KeyTy = Type;
819
821
822 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
823
825 const KeyTy &key) {
826 return new (allocator.allocate<SampledImageTypeStorage>())
828 }
829
831};
832
// The context is taken from the wrapped image type.
834 return Base::get(imageType.getContext(), imageType);
835}
836
842
843Type SampledImageType::getImageType() const { return getImpl()->imageType; }
844
845LogicalResult
847 Type imageType) {
// The wrapped type must be an ImageType.
848 auto image = dyn_cast<ImageType>(imageType);
849 if (!image)
850 return emitError() << "expected image type";
851
852 // As per the SPIR-V spec ("3.3.6. Type-Declaration Instructions"): "It
853 // [ImageType] must not have a Dim of SubpassData. Additionally, starting with
854 // version 1.6, it must not have a Dim of Buffer."
// Note: Buffer is rejected here unconditionally, regardless of the target
// SPIR-V version.
855 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
856 return emitError() << "Dim must not be SubpassData or Buffer";
857
858 return success();
859}
860
861//===----------------------------------------------------------------------===//
862// SamplerType
863//===----------------------------------------------------------------------===//
864
// SamplerType has no parameters: it is identified by its context alone.
866 return Base::get(context);
867}
868
869//===----------------------------------------------------------------------===//
870// NamedBarrierType
871//===----------------------------------------------------------------------===//
872
// NamedBarrierType has no parameters: it is identified by its context alone.
874 return Base::get(context);
875}
876
877//===----------------------------------------------------------------------===//
878// StructType
879//===----------------------------------------------------------------------===//
880
881/// Type storage for SPIR-V structure types:
882///
883/// Structures are uniqued using:
884/// - for identified structs:
885/// - a string identifier;
886/// - for literal structs:
887/// - a list of member types;
888/// - a list of member offset info;
889/// - a list of member decoration info;
890/// - a list of struct decoration info.
891///
892/// Identified structures only have a mutable component consisting of:
893/// - a list of member types;
894/// - a list of member offset info;
895/// - a list of member decoration info;
896/// - a list of struct decoration info.
898 /// Construct a storage object for an identified struct type. A struct type
899 /// associated with such storage must call StructType::trySetBody(...) later
900 /// in order to mutate the storage object providing the actual content.
906
907 /// Construct a storage object for a literal struct type. A struct type
908 /// associated with such storage is immutable.
920
921 /// A storage key is divided into 2 parts:
922 /// - for identified structs:
923 /// - a StringRef representing the struct identifier;
924 /// - for literal structs:
925 /// - an ArrayRef<Type> for member types;
926 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
927 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
928 /// info;
929 /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
930 /// info.
931 ///
932 /// An identified struct type is uniqued only by the first part (field 0)
933 /// of the key.
934 ///
935 /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
936 /// and 4) of the key. The identifier field (field 0) must be empty.
937 using KeyTy =
938 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
941
942 /// For identified structs, return true if the given key contains the same
943 /// identifier.
944 ///
945 /// For literal structs, return true if the given key contains a matching list
946 /// of member types + offset info + decoration info.
947 bool operator==(const KeyTy &key) const {
948 if (isIdentified()) {
949 // Identified types are uniqued by their identifier.
950 return getIdentifier() == std::get<0>(key);
951 }
952
953 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
955 }
956
957 /// If the given key contains a non-empty identifier, this method constructs
958 /// an identified struct and leaves the rest of the struct type data to be set
959 /// through a later call to StructType::trySetBody(...).
960 ///
961 /// If, on the other hand, the key contains an empty identifier, a literal
962 /// struct is constructed using the other fields of the key.
964 const KeyTy &key) {
965 StringRef keyIdentifier = std::get<0>(key);
966
967 if (!keyIdentifier.empty()) {
968 StringRef identifier = allocator.copyInto(keyIdentifier);
969
970 // Identified StructType body/members will be set through trySetBody(...)
971 // later.
972 return new (allocator.allocate<StructTypeStorage>())
974 }
975
976 ArrayRef<Type> keyTypes = std::get<1>(key);
977
978 // Copy the member type and layout information into the bump pointer
979 const Type *typesList = nullptr;
980 if (!keyTypes.empty()) {
981 typesList = allocator.copyInto(keyTypes).data();
982 }
983
984 const StructType::OffsetInfo *offsetInfoList = nullptr;
985 if (!std::get<2>(key).empty()) {
986 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
987 assert(keyOffsetInfo.size() == keyTypes.size() &&
988 "size of offset information must be same as the size of number of "
989 "elements");
990 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
991 }
992
993 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
994 unsigned numMemberDecorations = 0;
995 if (!std::get<3>(key).empty()) {
996 auto keyMemberDecorations = std::get<3>(key);
997 numMemberDecorations = keyMemberDecorations.size();
998 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
999 }
1000
1001 const StructType::StructDecorationInfo *structDecorationList = nullptr;
1002 unsigned numStructDecorations = 0;
1003 if (!std::get<4>(key).empty()) {
1004 auto keyStructDecorations = std::get<4>(key);
1005 numStructDecorations = keyStructDecorations.size();
1006 structDecorationList = allocator.copyInto(keyStructDecorations).data();
1007 }
1008
1009 return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
1010 keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
1011 memberDecorationList, numStructDecorations, structDecorationList);
1012 }
1013
1017
1024
1032
1039
1040 StringRef getIdentifier() const { return identifier; }
1041
1042 bool isIdentified() const { return !identifier.empty(); }
1043
1044 /// Sets the struct type content for identified structs. Calling this method
1045 /// is only valid for identified structs.
1046 ///
1047 /// Fails under the following conditions:
1048 /// - If called for a literal struct;
1049 /// - If called for an identified struct whose body was set before (through a
1050 /// call to this method) but with different contents from the passed
1051 /// arguments.
1052 LogicalResult
1053 mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
1054 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
1055 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
1056 ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
1057 if (!isIdentified())
1058 return failure();
1059
1060 if (memberTypesAndIsBodySet.getInt() &&
1061 (getMemberTypes() != structMemberTypes ||
1062 getOffsetInfo() != structOffsetInfo ||
1063 getMemberDecorationsInfo() != structMemberDecorationInfo ||
1064 getStructDecorationsInfo() != structDecorationInfo))
1065 return failure();
1066
1067 memberTypesAndIsBodySet.setInt(true);
1068 numMembers = structMemberTypes.size();
1069
1070 // Copy the member type and layout information into the bump pointer.
1071 if (!structMemberTypes.empty())
1072 memberTypesAndIsBodySet.setPointer(
1073 allocator.copyInto(structMemberTypes).data());
1074
1075 if (!structOffsetInfo.empty()) {
1076 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1077 "size of offset information must be same as the size of number of "
1078 "elements");
1079 offsetInfo = allocator.copyInto(structOffsetInfo).data();
1080 }
1081
1082 if (!structMemberDecorationInfo.empty()) {
1083 numMemberDecorations = structMemberDecorationInfo.size();
1085 allocator.copyInto(structMemberDecorationInfo).data();
1086 }
1087
1088 if (!structDecorationInfo.empty()) {
1089 numStructDecorations = structDecorationInfo.size();
1090 structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
1091 }
1092
1093 return success();
1094 }
1095
1096 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1098 unsigned numMembers;
1103 StringRef identifier;
1104};
1105
1111 assert(!memberTypes.empty() && "Struct needs at least one member type");
1112 // Sort the decorations.
1114 memberDecorations);
1115 llvm::array_pod_sort(sortedMemberDecorations.begin(),
1116 sortedMemberDecorations.end());
1118 structDecorations);
1119 llvm::array_pod_sort(sortedStructDecorations.begin(),
1120 sortedStructDecorations.end());
1121
1122 return Base::get(memberTypes.vec().front().getContext(),
1123 /*identifier=*/StringRef(), memberTypes, offsetInfo,
1124 sortedMemberDecorations, sortedStructDecorations);
1125}
1126
1128 StringRef identifier) {
1129 assert(!identifier.empty() &&
1130 "StructType identifier must be non-empty string");
1131
1132 return Base::get(context, identifier, ArrayRef<Type>(),
1136}
1137
1138StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1139 StructType newStructType = Base::get(
1140 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1143 // Set an empty body in case this is a identified struct.
1144 if (newStructType.isIdentified() &&
1145 failed(newStructType.trySetBody(
1149 return StructType();
1150
1151 return newStructType;
1152}
1153
1154StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1155
1156bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1157
1158unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1159
1161 assert(getNumElements() > index && "member index out of range");
1162 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1163}
1164
1166 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1167 getNumElements());
1168}
1169
1170bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1171
1172bool StructType::hasDecoration(spirv::Decoration decoration) const {
1174 getImpl()->getStructDecorationsInfo())
1175 if (info.decoration == decoration)
1176 return true;
1177
1178 return false;
1179}
1180
1181uint64_t StructType::getMemberOffset(unsigned index) const {
1182 assert(getNumElements() > index && "member index out of range");
1183 return getImpl()->offsetInfo[index];
1184}
1185
1188 const {
1189 memberDecorations.clear();
1190 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1191 memberDecorations.append(implMemberDecorations.begin(),
1192 implMemberDecorations.end());
1193}
1194
1196 unsigned index,
1198 assert(getNumElements() > index && "member index out of range");
1199 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1200 decorationsInfo.clear();
1201 for (const auto &memberDecoration : memberDecorations) {
1202 if (memberDecoration.memberIndex == index) {
1203 decorationsInfo.push_back(memberDecoration);
1204 }
1205 if (memberDecoration.memberIndex > index) {
1206 // Early exit since the decorations are stored sorted.
1207 return;
1208 }
1209 }
1210}
1211
1214 const {
1215 structDecorations.clear();
1216 auto implDecorations = getImpl()->getStructDecorationsInfo();
1217 structDecorations.append(implDecorations.begin(), implDecorations.end());
1218}
1219
1220LogicalResult
1222 ArrayRef<OffsetInfo> offsetInfo,
1223 ArrayRef<MemberDecorationInfo> memberDecorations,
1224 ArrayRef<StructDecorationInfo> structDecorations) {
1225 return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1226 structDecorations);
1227}
1228
1229llvm::hash_code spirv::hash_value(
1230 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1231 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1232 memberDecorationInfo.decoration);
1233}
1234
1235llvm::hash_code spirv::hash_value(
1236 const StructType::StructDecorationInfo &structDecorationInfo) {
1237 return llvm::hash_value(structDecorationInfo.decoration);
1238}
1239
1240//===----------------------------------------------------------------------===//
1241// MatrixType
1242//===----------------------------------------------------------------------===//
1243
1245 // Use a 64-bit integer as a column count internally to better support a
1246 // `ShapedType` interface. See comment in `CooperativeMatrixType` for more
1247 // context.
1248 using KeyTy = std::tuple<Type, int64_t>;
1249
1251 : columnType(std::get<0>(key)),
1252 shape({cast<VectorType>(std::get<0>(key)).getShape()[0],
1253 std::get<1>(key)}) {}
1254
1256 const KeyTy &key) {
1257
1258 // Initialize the memory using placement new.
1259 return new (allocator.allocate<MatrixTypeStorage>()) MatrixTypeStorage(key);
1260 }
1261
1262 bool operator==(const KeyTy &key) const {
1263 return key == KeyTy(columnType, shape[1]);
1264 }
1265
1267 // [#rows, #columns]
1268 std::array<int64_t, 2> shape;
1269};
1270
1271MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1272 return Base::get(columnType.getContext(), columnType, columnCount);
1273}
1274
1276 Type columnType, uint32_t columnCount) {
1277 return Base::getChecked(emitError, columnType.getContext(), columnType,
1278 columnCount);
1279}
1280
1281LogicalResult
1283 Type columnType, uint32_t columnCount) {
1284 if (columnCount < 2 || columnCount > 4)
1285 return emitError() << "matrix can have 2, 3, or 4 columns only";
1286
1287 if (!isValidColumnType(columnType))
1288 return emitError() << "matrix columns must be vectors of floats";
1289
1290 /// The underlying vectors (columns) must be of size 2, 3, or 4
1291 ArrayRef<int64_t> columnShape = cast<VectorType>(columnType).getShape();
1292 if (columnShape.size() != 1)
1293 return emitError() << "matrix columns must be 1D vectors";
1294
1295 if (columnShape[0] < 2 || columnShape[0] > 4)
1296 return emitError() << "matrix columns must be of size 2, 3, or 4";
1297
1298 return success();
1299}
1300
1301/// Returns true if the matrix elements are vectors of float elements
1303 if (auto vectorType = dyn_cast<VectorType>(columnType)) {
1304 if (isa<FloatType>(vectorType.getElementType()))
1305 return true;
1306 }
1307 return false;
1308}
1309
1310Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1311
1313 return cast<VectorType>(getImpl()->columnType).getElementType();
1314}
1315
1317 assert(getImpl()->shape[1] >= 0); // Also includes ShapedType::kDynamic.
1318 assert(getImpl()->shape[1] <= std::numeric_limits<unsigned>::max());
1319 return static_cast<uint32_t>(getImpl()->shape[1]);
1320}
1321
1322unsigned MatrixType::getNumRows() const {
1323 assert(getImpl()->shape[0] >= 0); // Also includes ShapedType::kDynamic.
1324 assert(getImpl()->shape[0] <= std::numeric_limits<unsigned>::max());
1325 return static_cast<uint32_t>(getImpl()->shape[0]);
1326}
1327
1329 return getNumColumns() * getNumRows();
1330}
1331
1333
1334void TypeCapabilityVisitor::addConcrete(MatrixType type) {
1335 add(type.getColumnType());
1336 pushCaps<Capability::Matrix>();
1337}
1338
1339//===----------------------------------------------------------------------===//
1340// TensorArmType
1341//===----------------------------------------------------------------------===//
1342
1344 using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1345
1347 const KeyTy &key) {
1348 auto [shape, elementType] = key;
1349 shape = allocator.copyInto(shape);
1350 return new (allocator.allocate<TensorArmTypeStorage>())
1352 }
1353
1354 static llvm::hash_code hashKey(const KeyTy &key) {
1355 auto [shape, elementType] = key;
1356 return llvm::hash_combine(shape, elementType);
1357 }
1358
1359 bool operator==(const KeyTy &key) const {
1360 return key == KeyTy(shape, elementType);
1361 }
1362
1365
1368};
1369
1371 return Base::get(elementType.getContext(), shape, elementType);
1372}
1373
1375 Type elementType) const {
1376 return TensorArmType::get(shape.value_or(getShape()), elementType);
1377}
1378
1379Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1381
1382void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1383 add(type.getElementType());
1384 pushExts<Extension::SPV_ARM_tensors>();
1385}
1386
1387void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
1388 add(type.getElementType());
1389 pushCaps<Capability::TensorsARM>();
1390}
1391
1392LogicalResult
1394 ArrayRef<int64_t> shape, Type elementType) {
1395 if (llvm::is_contained(shape, 0))
1396 return emitError() << "arm.tensor do not support dimensions = 0";
1397 if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1398 llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1399 return emitError()
1400 << "arm.tensor shape dimensions must be either fully dynamic or "
1401 "completed shaped";
1402 return success();
1403}
1404
1405//===----------------------------------------------------------------------===//
1406// SPIR-V Dialect
1407//===----------------------------------------------------------------------===//
1408
1409void SPIRVDialect::registerTypes() {
1413}
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
false
Parses a map_entries map type from a string format back into its numeric value.
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
#define add(a, b)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
TypeStorage()
This constructor is used by derived classes as part of the TypeUniquer.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF8E5M2() const
Definition Types.cpp:45
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:148
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
ArrayRef< int64_t > getShape() const
unsigned getNumRows() const
Returns the number of rows.
static NamedBarrierType get(MLIRContext *context)
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
constexpr Type()=default
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:66
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:55
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:274
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:509
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
StorageUniquer::StorageAllocator TypeStorageAllocator
This is a utility allocator used to allocate memory for instances of derived Types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
bool operator==(const KeyTy &key) const
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
bool operator==(const KeyTy &key) const
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Type, int64_t > KeyTy
bool operator==(const KeyTy &key) const
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
StructType::OffsetInfo const * offsetInfo
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::StructDecorationInfo const * structDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
ArrayRef< Type > getMemberTypes() const
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const