MLIR  14.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::spirv;
22 
23 //===----------------------------------------------------------------------===//
24 // ArrayType
25 //===----------------------------------------------------------------------===//
26 
28  using KeyTy = std::tuple<Type, unsigned, unsigned>;
29 
31  const KeyTy &key) {
32  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
33  }
34 
35  bool operator==(const KeyTy &key) const {
36  return key == KeyTy(elementType, elementCount, stride);
37  }
38 
39  ArrayTypeStorage(const KeyTy &key)
40  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
41  stride(std::get<2>(key)) {}
42 
44  unsigned elementCount;
45  unsigned stride;
46 };
47 
48 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
49  assert(elementCount && "ArrayType needs at least one element");
50  return Base::get(elementType.getContext(), elementType, elementCount,
51  /*stride=*/0);
52 }
53 
54 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
55  unsigned stride) {
56  assert(elementCount && "ArrayType needs at least one element");
57  return Base::get(elementType.getContext(), elementType, elementCount, stride);
58 }
59 
60 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
61 
62 Type ArrayType::getElementType() const { return getImpl()->elementType; }
63 
64 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
65 
67  Optional<StorageClass> storage) {
68  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
69 }
70 
73  Optional<StorageClass> storage) {
74  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
75 }
76 
78  auto elementType = getElementType().cast<SPIRVType>();
79  Optional<int64_t> size = elementType.getSizeInBytes();
80  if (!size)
81  return llvm::None;
82  return (*size + getArrayStride()) * getNumElements();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // CompositeType
87 //===----------------------------------------------------------------------===//
88 
90  if (auto vectorType = type.dyn_cast<VectorType>())
91  return isValid(vectorType);
92  return type
95 }
96 
97 bool CompositeType::isValid(VectorType type) {
98  switch (type.getNumElements()) {
99  case 2:
100  case 3:
101  case 4:
102  case 8:
103  case 16:
104  break;
105  default:
106  return false;
107  }
108  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
109 }
110 
111 Type CompositeType::getElementType(unsigned index) const {
112  return TypeSwitch<Type, Type>(*this)
114  [](auto type) { return type.getElementType(); })
115  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
116  .Case<StructType>(
117  [index](StructType type) { return type.getElementType(index); })
118  .Default(
119  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
120 }
121 
123  if (auto arrayType = dyn_cast<ArrayType>())
124  return arrayType.getNumElements();
125  if (auto matrixType = dyn_cast<MatrixType>())
126  return matrixType.getNumColumns();
127  if (auto structType = dyn_cast<StructType>())
128  return structType.getNumElements();
129  if (auto vectorType = dyn_cast<VectorType>())
130  return vectorType.getNumElements();
131  if (isa<CooperativeMatrixNVType>()) {
132  llvm_unreachable(
133  "invalid to query number of elements of spirv::CooperativeMatrix type");
134  }
135  if (isa<RuntimeArrayType>()) {
136  llvm_unreachable(
137  "invalid to query number of elements of spirv::RuntimeArray type");
138  }
139  llvm_unreachable("invalid composite type");
140 }
141 
143  return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
144 }
145 
148  Optional<StorageClass> storage) {
149  TypeSwitch<Type>(*this)
151  StructType>(
152  [&](auto type) { type.getExtensions(extensions, storage); })
153  .Case<VectorType>([&](VectorType type) {
154  return type.getElementType().cast<ScalarType>().getExtensions(
155  extensions, storage);
156  })
157  .Default([](Type) { llvm_unreachable("invalid composite type"); });
158 }
159 
162  Optional<StorageClass> storage) {
163  TypeSwitch<Type>(*this)
165  StructType>(
166  [&](auto type) { type.getCapabilities(capabilities, storage); })
167  .Case<VectorType>([&](VectorType type) {
168  auto vecSize = getNumElements();
169  if (vecSize == 8 || vecSize == 16) {
170  static const Capability caps[] = {Capability::Vector16};
171  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
172  capabilities.push_back(ref);
173  }
174  return type.getElementType().cast<ScalarType>().getCapabilities(
175  capabilities, storage);
176  })
177  .Default([](Type) { llvm_unreachable("invalid composite type"); });
178 }
179 
181  if (auto arrayType = dyn_cast<ArrayType>())
182  return arrayType.getSizeInBytes();
183  if (auto structType = dyn_cast<StructType>())
184  return structType.getSizeInBytes();
185  if (auto vectorType = dyn_cast<VectorType>()) {
186  Optional<int64_t> elementSize =
187  vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
188  if (!elementSize)
189  return llvm::None;
190  return *elementSize * vectorType.getNumElements();
191  }
192  return llvm::None;
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // CooperativeMatrixType
197 //===----------------------------------------------------------------------===//
198 
200  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
201 
203  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
204  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
206  }
207 
208  bool operator==(const KeyTy &key) const {
209  return key == KeyTy(elementType, scope, rows, columns);
210  }
211 
213  : elementType(std::get<0>(key)), rows(std::get<2>(key)),
214  columns(std::get<3>(key)), scope(std::get<1>(key)) {}
215 
217  unsigned rows;
218  unsigned columns;
219  Scope scope;
220 };
221 
223  Scope scope, unsigned rows,
224  unsigned columns) {
225  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
226 }
227 
229  return getImpl()->elementType;
230 }
231 
232 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
233 
234 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
235 
237  return getImpl()->columns;
238 }
239 
242  Optional<StorageClass> storage) {
243  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
244  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
245  ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
246  extensions.push_back(ref);
247 }
248 
251  Optional<StorageClass> storage) {
252  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
253  static const Capability caps[] = {Capability::CooperativeMatrixNV};
254  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
255  capabilities.push_back(ref);
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // ImageType
260 //===----------------------------------------------------------------------===//
261 
262 template <typename T> static constexpr unsigned getNumBits() { return 0; }
263 template <> constexpr unsigned getNumBits<Dim>() {
264  static_assert((1 << 3) > getMaxEnumValForDim(),
265  "Not enough bits to encode Dim value");
266  return 3;
267 }
268 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
269  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
270  "Not enough bits to encode ImageDepthInfo value");
271  return 2;
272 }
273 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
274  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
275  "Not enough bits to encode ImageArrayedInfo value");
276  return 1;
277 }
278 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
279  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
280  "Not enough bits to encode ImageSamplingInfo value");
281  return 1;
282 }
283 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
284  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
285  "Not enough bits to encode ImageSamplerUseInfo value");
286  return 2;
287 }
288 template <> constexpr unsigned getNumBits<ImageFormat>() {
289  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
290  "Not enough bits to encode ImageFormat value");
291  return 6;
292 }
293 
295 public:
296  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
297  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
298 
300  const KeyTy &key) {
301  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
302  }
303 
304  bool operator==(const KeyTy &key) const {
305  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
306  samplerUseInfo, format);
307  }
308 
310  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
311  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
312  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
313  format(std::get<6>(key)) {}
314 
322 };
323 
324 ImageType
325 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
326  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
327  value) {
328  return Base::get(std::get<0>(value).getContext(), value);
329 }
330 
331 Type ImageType::getElementType() const { return getImpl()->elementType; }
332 
333 Dim ImageType::getDim() const { return getImpl()->dim; }
334 
335 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
336 
337 ImageArrayedInfo ImageType::getArrayedInfo() const {
338  return getImpl()->arrayedInfo;
339 }
340 
341 ImageSamplingInfo ImageType::getSamplingInfo() const {
342  return getImpl()->samplingInfo;
343 }
344 
345 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
346  return getImpl()->samplerUseInfo;
347 }
348 
349 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
350 
353  // Image types do not require extra extensions thus far.
354 }
355 
358  if (auto dimCaps = spirv::getCapabilities(getDim()))
359  capabilities.push_back(*dimCaps);
360 
361  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
362  capabilities.push_back(*fmtCaps);
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // PointerType
367 //===----------------------------------------------------------------------===//
368 
370  // (Type, StorageClass) as the key: Type stored in this struct, and
371  // StorageClass stored as TypeStorage's subclass data.
372  using KeyTy = std::pair<Type, StorageClass>;
373 
375  const KeyTy &key) {
376  return new (allocator.allocate<PointerTypeStorage>())
377  PointerTypeStorage(key);
378  }
379 
380  bool operator==(const KeyTy &key) const {
381  return key == KeyTy(pointeeType, storageClass);
382  }
383 
385  : pointeeType(key.first), storageClass(key.second) {}
386 
388  StorageClass storageClass;
389 };
390 
391 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
392  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
393 }
394 
395 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
396 
397 StorageClass PointerType::getStorageClass() const {
398  return getImpl()->storageClass;
399 }
400 
402  Optional<StorageClass> storage) {
403  // Use this pointer type's storage class because this pointer indicates we are
404  // using the pointee type in that specific storage class.
405  getPointeeType().cast<SPIRVType>().getExtensions(extensions,
406  getStorageClass());
407 
408  if (auto scExts = spirv::getExtensions(getStorageClass()))
409  extensions.push_back(*scExts);
410 }
411 
414  Optional<StorageClass> storage) {
415  // Use this pointer type's storage class because this pointer indicates we are
416  // using the pointee type in that specific storage class.
417  getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
418  getStorageClass());
419 
420  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
421  capabilities.push_back(*scCaps);
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // RuntimeArrayType
426 //===----------------------------------------------------------------------===//
427 
429  using KeyTy = std::pair<Type, unsigned>;
430 
432  const KeyTy &key) {
433  return new (allocator.allocate<RuntimeArrayTypeStorage>())
435  }
436 
437  bool operator==(const KeyTy &key) const {
438  return key == KeyTy(elementType, stride);
439  }
440 
442  : elementType(key.first), stride(key.second) {}
443 
445  unsigned stride;
446 };
447 
449  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
450 }
451 
452 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
453  return Base::get(elementType.getContext(), elementType, stride);
454 }
455 
456 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
457 
458 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
459 
462  Optional<StorageClass> storage) {
463  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
464 }
465 
468  Optional<StorageClass> storage) {
469  {
470  static const Capability caps[] = {Capability::Shader};
471  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
472  capabilities.push_back(ref);
473  }
474  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // ScalarType
479 //===----------------------------------------------------------------------===//
480 
482  if (auto floatType = type.dyn_cast<FloatType>()) {
483  return isValid(floatType);
484  }
485  if (auto intType = type.dyn_cast<IntegerType>()) {
486  return isValid(intType);
487  }
488  return false;
489 }
490 
491 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
492 
493 bool ScalarType::isValid(IntegerType type) {
494  switch (type.getWidth()) {
495  case 1:
496  case 8:
497  case 16:
498  case 32:
499  case 64:
500  return true;
501  default:
502  return false;
503  }
504 }
505 
507  Optional<StorageClass> storage) {
508  // 8- or 16-bit integer/floating-point numbers will require extra extensions
509  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
510  // SPV_KHR_8bit_storage for more details.
511  if (!storage)
512  return;
513 
514  switch (*storage) {
515  case StorageClass::PushConstant:
516  case StorageClass::StorageBuffer:
517  case StorageClass::Uniform:
518  if (getIntOrFloatBitWidth() == 8) {
519  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
520  ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
521  extensions.push_back(ref);
522  }
523  LLVM_FALLTHROUGH;
524  case StorageClass::Input:
525  case StorageClass::Output:
526  if (getIntOrFloatBitWidth() == 16) {
527  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
528  ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
529  extensions.push_back(ref);
530  }
531  break;
532  default:
533  break;
534  }
535 }
536 
539  Optional<StorageClass> storage) {
540  unsigned bitwidth = getIntOrFloatBitWidth();
541 
542  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
543  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
544  // SPV_KHR_8bit_storage for more details.
545 
546 #define STORAGE_CASE(storage, cap8, cap16) \
547  case StorageClass::storage: { \
548  if (bitwidth == 8) { \
549  static const Capability caps[] = {Capability::cap8}; \
550  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
551  capabilities.push_back(ref); \
552  } else if (bitwidth == 16) { \
553  static const Capability caps[] = {Capability::cap16}; \
554  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
555  capabilities.push_back(ref); \
556  } \
557  /* No requirements for other bitwidths */ \
558  return; \
559  }
560 
561  // This part only handles the cases where special bitwidths appearing in
562  // interface storage classes.
563  if (storage) {
564  switch (*storage) {
565  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
566  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
567  StorageBuffer16BitAccess);
568  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
569  StorageUniform16);
570  case StorageClass::Input:
571  case StorageClass::Output: {
572  if (bitwidth == 16) {
573  static const Capability caps[] = {Capability::StorageInputOutput16};
574  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
575  capabilities.push_back(ref);
576  }
577  return;
578  }
579  default:
580  break;
581  }
582  }
583 #undef STORAGE_CASE
584 
585  // For other non-interface storage classes, require a different set of
586  // capabilities for special bitwidths.
587 
588 #define WIDTH_CASE(type, width) \
589  case width: { \
590  static const Capability caps[] = {Capability::type##width}; \
591  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
592  capabilities.push_back(ref); \
593  } break
594 
595  if (auto intType = dyn_cast<IntegerType>()) {
596  switch (bitwidth) {
597  case 32:
598  case 1:
599  break;
600  WIDTH_CASE(Int, 8);
601  WIDTH_CASE(Int, 16);
602  WIDTH_CASE(Int, 64);
603  default:
604  llvm_unreachable("invalid bitwidth to getCapabilities");
605  }
606  } else {
607  assert(isa<FloatType>());
608  switch (bitwidth) {
609  case 32:
610  break;
611  WIDTH_CASE(Float, 16);
612  WIDTH_CASE(Float, 64);
613  default:
614  llvm_unreachable("invalid bitwidth to getCapabilities");
615  }
616  }
617 
618 #undef WIDTH_CASE
619 }
620 
622  auto bitWidth = getIntOrFloatBitWidth();
623  // According to the SPIR-V spec:
624  // "There is no physical size or bit pattern defined for values with boolean
625  // type. If they are stored (in conjunction with OpVariable), they can only
626  // be used with logical addressing operations, not physical, and only with
627  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
628  // Private, Function, Input, and Output."
629  if (bitWidth == 1)
630  return llvm::None;
631  return bitWidth / 8;
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // SPIRVType
636 //===----------------------------------------------------------------------===//
637 
639  // Allow SPIR-V dialect types
640  if (llvm::isa<SPIRVDialect>(type.getDialect()))
641  return true;
642  if (type.isa<ScalarType>())
643  return true;
644  if (auto vectorType = type.dyn_cast<VectorType>())
646  return false;
647 }
648 
650  return isIntOrFloat() || isa<VectorType>();
651 }
652 
654  Optional<StorageClass> storage) {
655  if (auto scalarType = dyn_cast<ScalarType>()) {
656  scalarType.getExtensions(extensions, storage);
657  } else if (auto compositeType = dyn_cast<CompositeType>()) {
658  compositeType.getExtensions(extensions, storage);
659  } else if (auto imageType = dyn_cast<ImageType>()) {
660  imageType.getExtensions(extensions, storage);
661  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
662  sampledImageType.getExtensions(extensions, storage);
663  } else if (auto matrixType = dyn_cast<MatrixType>()) {
664  matrixType.getExtensions(extensions, storage);
665  } else if (auto ptrType = dyn_cast<PointerType>()) {
666  ptrType.getExtensions(extensions, storage);
667  } else {
668  llvm_unreachable("invalid SPIR-V Type to getExtensions");
669  }
670 }
671 
674  Optional<StorageClass> storage) {
675  if (auto scalarType = dyn_cast<ScalarType>()) {
676  scalarType.getCapabilities(capabilities, storage);
677  } else if (auto compositeType = dyn_cast<CompositeType>()) {
678  compositeType.getCapabilities(capabilities, storage);
679  } else if (auto imageType = dyn_cast<ImageType>()) {
680  imageType.getCapabilities(capabilities, storage);
681  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
682  sampledImageType.getCapabilities(capabilities, storage);
683  } else if (auto matrixType = dyn_cast<MatrixType>()) {
684  matrixType.getCapabilities(capabilities, storage);
685  } else if (auto ptrType = dyn_cast<PointerType>()) {
686  ptrType.getCapabilities(capabilities, storage);
687  } else {
688  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
689  }
690 }
691 
693  if (auto scalarType = dyn_cast<ScalarType>())
694  return scalarType.getSizeInBytes();
695  if (auto compositeType = dyn_cast<CompositeType>())
696  return compositeType.getSizeInBytes();
697  return llvm::None;
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // SampledImageType
702 //===----------------------------------------------------------------------===//
704  using KeyTy = Type;
705 
706  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
707 
708  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
709 
711  const KeyTy &key) {
712  return new (allocator.allocate<SampledImageTypeStorage>())
714  }
715 
717 };
718 
720  return Base::get(imageType.getContext(), imageType);
721 }
722 
725  Type imageType) {
726  return Base::getChecked(emitError, imageType.getContext(), imageType);
727 }
728 
729 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
730 
733  Type imageType) {
734  if (!imageType.isa<ImageType>())
735  return emitError() << "expected image type";
736 
737  return success();
738 }
739 
742  Optional<StorageClass> storage) {
743  getImageType().cast<ImageType>().getExtensions(extensions, storage);
744 }
745 
748  Optional<StorageClass> storage) {
749  getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
750 }
751 
752 //===----------------------------------------------------------------------===//
753 // StructType
754 //===----------------------------------------------------------------------===//
755 
756 /// Type storage for SPIR-V structure types:
757 ///
758 /// Structures are uniqued using:
759 /// - for identified structs:
760 /// - a string identifier;
761 /// - for literal structs:
762 /// - a list of member types;
763 /// - a list of member offset info;
764 /// - a list of member decoration info.
765 ///
766 /// Identified structures only have a mutable component consisting of:
767 /// - a list of member types;
768 /// - a list of member offset info;
769 /// - a list of member decoration info.
771  /// Construct a storage object for an identified struct type. A struct type
772  /// associated with such storage must call StructType::trySetBody(...) later
773  /// in order to mutate the storage object providing the actual content.
774  StructTypeStorage(StringRef identifier)
775  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
776  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
777  identifier(identifier) {}
778 
779  /// Construct a storage object for a literal struct type. A struct type
780  /// associated with such storage is immutable.
782  unsigned numMembers, Type const *memberTypes,
783  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
784  StructType::MemberDecorationInfo const *memberDecorationsInfo)
785  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
786  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
787  memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
788 
789  /// A storage key is divided into 2 parts:
790  /// - for identified structs:
791  /// - a StringRef representing the struct identifier;
792  /// - for literal structs:
793  /// - an ArrayRef<Type> for member types;
794  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
795  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
796  /// info.
797  ///
798  /// An identified struct type is uniqued only by the first part (field 0)
799  /// of the key.
800  ///
801  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
802  /// 3) of the key. The identifier field (field 0) must be empty.
803  using KeyTy =
804  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
806 
807  /// For identified structs, return true if the given key contains the same
808  /// identifier.
809  ///
810  /// For literal structs, return true if the given key contains a matching list
811  /// of member types + offset info + decoration info.
812  bool operator==(const KeyTy &key) const {
813  if (isIdentified()) {
814  // Identified types are uniqued by their identifier.
815  return getIdentifier() == std::get<0>(key);
816  }
817 
818  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
819  getMemberDecorationsInfo());
820  }
821 
822  /// If the given key contains a non-empty identifier, this method constructs
823  /// an identified struct and leaves the rest of the struct type data to be set
824  /// through a later call to StructType::trySetBody(...).
825  ///
826  /// If, on the other hand, the key contains an empty identifier, a literal
827  /// struct is constructed using the other fields of the key.
829  const KeyTy &key) {
830  StringRef keyIdentifier = std::get<0>(key);
831 
832  if (!keyIdentifier.empty()) {
833  StringRef identifier = allocator.copyInto(keyIdentifier);
834 
835  // Identified StructType body/members will be set through trySetBody(...)
836  // later.
837  return new (allocator.allocate<StructTypeStorage>())
838  StructTypeStorage(identifier);
839  }
840 
841  ArrayRef<Type> keyTypes = std::get<1>(key);
842 
843  // Copy the member type and layout information into the bump pointer
844  const Type *typesList = nullptr;
845  if (!keyTypes.empty()) {
846  typesList = allocator.copyInto(keyTypes).data();
847  }
848 
849  const StructType::OffsetInfo *offsetInfoList = nullptr;
850  if (!std::get<2>(key).empty()) {
851  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
852  assert(keyOffsetInfo.size() == keyTypes.size() &&
853  "size of offset information must be same as the size of number of "
854  "elements");
855  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
856  }
857 
858  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
859  unsigned numMemberDecorations = 0;
860  if (!std::get<3>(key).empty()) {
861  auto keyMemberDecorations = std::get<3>(key);
862  numMemberDecorations = keyMemberDecorations.size();
863  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
864  }
865 
866  return new (allocator.allocate<StructTypeStorage>())
867  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
868  numMemberDecorations, memberDecorationList);
869  }
870 
872  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
873  }
874 
875  ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
876  if (offsetInfo) {
877  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
878  }
879  return {};
880  }
881 
883  if (memberDecorationsInfo) {
884  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
885  numMemberDecorations);
886  }
887  return {};
888  }
889 
890  StringRef getIdentifier() const { return identifier; }
891 
892  bool isIdentified() const { return !identifier.empty(); }
893 
894  /// Sets the struct type content for identified structs. Calling this method
895  /// is only valid for identified structs.
896  ///
897  /// Fails under the following conditions:
898  /// - If called for a literal struct;
899  /// - If called for an identified struct whose body was set before (through a
900  /// call to this method) but with different contents from the passed
901  /// arguments.
903  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
904  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
905  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
906  if (!isIdentified())
907  return failure();
908 
909  if (memberTypesAndIsBodySet.getInt() &&
910  (getMemberTypes() != structMemberTypes ||
911  getOffsetInfo() != structOffsetInfo ||
912  getMemberDecorationsInfo() != structMemberDecorationInfo))
913  return failure();
914 
915  memberTypesAndIsBodySet.setInt(true);
916  numMembers = structMemberTypes.size();
917 
918  // Copy the member type and layout information into the bump pointer.
919  if (!structMemberTypes.empty())
920  memberTypesAndIsBodySet.setPointer(
921  allocator.copyInto(structMemberTypes).data());
922 
923  if (!structOffsetInfo.empty()) {
924  assert(structOffsetInfo.size() == structMemberTypes.size() &&
925  "size of offset information must be same as the size of number of "
926  "elements");
927  offsetInfo = allocator.copyInto(structOffsetInfo).data();
928  }
929 
930  if (!structMemberDecorationInfo.empty()) {
931  numMemberDecorations = structMemberDecorationInfo.size();
932  memberDecorationsInfo =
933  allocator.copyInto(structMemberDecorationInfo).data();
934  }
935 
936  return success();
937  }
938 
939  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
941  unsigned numMembers;
944  StringRef identifier;
945 };
946 
950  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
951  assert(!memberTypes.empty() && "Struct needs at least one member type");
952  // Sort the decorations.
954  memberDecorations.begin(), memberDecorations.end());
955  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
956  return Base::get(memberTypes.vec().front().getContext(),
957  /*identifier=*/StringRef(), memberTypes, offsetInfo,
958  sortedDecorations);
959 }
960 
962  StringRef identifier) {
963  assert(!identifier.empty() &&
964  "StructType identifier must be non-empty string");
965 
966  return Base::get(context, identifier, ArrayRef<Type>(),
969 }
970 
971 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
972  StructType newStructType = Base::get(
973  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
975  // Set an empty body in case this is a identified struct.
976  if (newStructType.isIdentified() &&
977  failed(newStructType.trySetBody(
980  return StructType();
981 
982  return newStructType;
983 }
984 
985 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
986 
987 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
988 
989 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
990 
991 Type StructType::getElementType(unsigned index) const {
992  assert(getNumElements() > index && "member index out of range");
993  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
994 }
995 
997  return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
998  getNumElements());
999 }
1000 
1001 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1002 
1003 uint64_t StructType::getMemberOffset(unsigned index) const {
1004  assert(getNumElements() > index && "member index out of range");
1005  return getImpl()->offsetInfo[index];
1006 }
1007 
1010  const {
1011  memberDecorations.clear();
1012  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1013  memberDecorations.append(implMemberDecorations.begin(),
1014  implMemberDecorations.end());
1015 }
1016 
1018  unsigned index,
1019  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1020  assert(getNumElements() > index && "member index out of range");
1021  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1022  decorationsInfo.clear();
1023  for (const auto &memberDecoration : memberDecorations) {
1024  if (memberDecoration.memberIndex == index) {
1025  decorationsInfo.push_back(memberDecoration);
1026  }
1027  if (memberDecoration.memberIndex > index) {
1028  // Early exit since the decorations are stored sorted.
1029  return;
1030  }
1031  }
1032 }
1033 
1036  ArrayRef<OffsetInfo> offsetInfo,
1037  ArrayRef<MemberDecorationInfo> memberDecorations) {
1038  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1039 }
1040 
1042  Optional<StorageClass> storage) {
1043  for (Type elementType : getElementTypes())
1044  elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1045 }
1046 
1049  Optional<StorageClass> storage) {
1050  for (Type elementType : getElementTypes())
1051  elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1052 }
1053 
1054 llvm::hash_code spirv::hash_value(
1055  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1056  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1057  memberDecorationInfo.decoration);
1058 }
1059 
1060 //===----------------------------------------------------------------------===//
1061 // MatrixType
1062 //===----------------------------------------------------------------------===//
1063 
1065  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1066  : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1067 
1068  using KeyTy = std::tuple<Type, uint32_t>;
1069 
1071  const KeyTy &key) {
1072 
1073  // Initialize the memory using placement new.
1074  return new (allocator.allocate<MatrixTypeStorage>())
1075  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1076  }
1077 
1078  bool operator==(const KeyTy &key) const {
1079  return key == KeyTy(columnType, columnCount);
1080  }
1081 
1083  const uint32_t columnCount;
1084 };
1085 
1086 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1087  return Base::get(columnType.getContext(), columnType, columnCount);
1088 }
1089 
1091  Type columnType, uint32_t columnCount) {
1092  return Base::getChecked(emitError, columnType.getContext(), columnType,
1093  columnCount);
1094 }
1095 
1097  Type columnType, uint32_t columnCount) {
1098  if (columnCount < 2 || columnCount > 4)
1099  return emitError() << "matrix can have 2, 3, or 4 columns only";
1100 
1101  if (!isValidColumnType(columnType))
1102  return emitError() << "matrix columns must be vectors of floats";
1103 
1104  /// The underlying vectors (columns) must be of size 2, 3, or 4
1105  ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1106  if (columnShape.size() != 1)
1107  return emitError() << "matrix columns must be 1D vectors";
1108 
1109  if (columnShape[0] < 2 || columnShape[0] > 4)
1110  return emitError() << "matrix columns must be of size 2, 3, or 4";
1111 
1112  return success();
1113 }
1114 
1115 /// Returns true if the matrix elements are vectors of float elements
1117  if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1118  if (vectorType.getElementType().isa<FloatType>())
1119  return true;
1120  }
1121  return false;
1122 }
1123 
1124 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1125 
1127  return getImpl()->columnType.cast<VectorType>().getElementType();
1128 }
1129 
1130 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1131 
1132 unsigned MatrixType::getNumRows() const {
1133  return getImpl()->columnType.cast<VectorType>().getShape()[0];
1134 }
1135 
1136 unsigned MatrixType::getNumElements() const {
1137  return (getImpl()->columnCount) * getNumRows();
1138 }
1139 
1141  Optional<StorageClass> storage) {
1142  getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1143 }
1144 
1147  Optional<StorageClass> storage) {
1148  {
1149  static const Capability caps[] = {Capability::Matrix};
1150  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1151  capabilities.push_back(ref);
1152  }
1153  // Add any capabilities associated with the underlying vectors (i.e., columns)
1154  getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1155 }
1156 
1157 //===----------------------------------------------------------------------===//
1158 // SPIR-V Dialect
1159 //===----------------------------------------------------------------------===//
1160 
1161 void SPIRVDialect::registerTypes() {
1164 }
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:268
Include the generated interface declarations.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:66
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:349
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:64
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:122
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:301
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:160
Base storage class appearing in a Type.
Definition: TypeSupport.h:121
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:506
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
This is a utility allocator used to allocate memory for instances of derived types.
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
static bool classof(Type type)
Definition: SPIRVTypes.cpp:638
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:372
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:902
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:240
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:304
bool isIdentified() const
Returns true if the StructType is identified.
Definition: SPIRVTypes.cpp:987
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
Type getElementType() const
Definition: SPIRVTypes.cpp:331
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:781
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:828
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:351
std::tuple< Type, uint32_t > KeyTy
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:412
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:491
static constexpr const bool value
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:111
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:708
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:875
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:273
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:948
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:732
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:437
T * allocate()
Allocate an instance of the provided type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:337
Optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:180
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:719
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Range class for element types.
Definition: SPIRVTypes.h:346
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:77
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< spirv::StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:746
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:380
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:222
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:374
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:35
U dyn_cast() const
Definition: Types.h:244
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:805
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
#define WIDTH_CASE(type, width)
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:345
Type getElementType() const
Definition: SPIRVTypes.cpp:62
Type getElementType() const
Returns the elements' type (i.e., single element type).
StructType::MemberDecorationInfo const * memberDecorationsInfo
Definition: SPIRVTypes.cpp:943
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:460
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< spirv::StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:740
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:431
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:774
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:401
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:335
unsigned getNumElements() const
Returns total number of elements (rows*columns).
StringRef getIdentifier() const
For literal structs, return an empty string.
Definition: SPIRVTypes.cpp:985
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:871
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent...
Definition: SPIRVTypes.cpp:142
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:28
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:297
uint64_t getMemberOffset(unsigned) const
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:971
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:30
Optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:621
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
Definition: SPIRVTypes.cpp:939
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:278
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:710
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
#define STORAGE_CASE(storage, cap8, cap16)
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:283
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:249
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:989
ElementTypeRange getElementTypes() const
Definition: SPIRVTypes.cpp:996
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:200
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:991
static bool classof(Type type)
Definition: SPIRVTypes.cpp:481
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:458
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:299
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:678
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:537
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:163
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:770
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
StructType::OffsetInfo const * offsetInfo
Definition: SPIRVTypes.cpp:940
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:466
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:203
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:882
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
MatrixTypeStorage(Type columnType, uint32_t columnCount)
SPIR-V struct type.
Definition: SPIRVTypes.h:278
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:724
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:262
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:397
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:812
void getCapabilities(CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Appends to capabilities the capabilities needed for this type to appear in the given storage class...
Definition: SPIRVTypes.cpp:672
void getExtensions(ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Appends to extensions the extensions needed for this type to appear in the given storage class...
Definition: SPIRVTypes.cpp:653
bool operator==(const KeyTy &key) const
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
static bool classof(Type type)
Definition: SPIRVTypes.cpp:89
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:232
unsigned getNumColumns() const
Returns the number of columns.
static MatrixType get(Type columnType, uint32_t columnCount)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:961
unsigned getNumRows() const
Returns the number of rows.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
bool isa() const
Definition: Types.h:234
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Type getColumnType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:146
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:341
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:263
Optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:692
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:71
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:356
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
bool isBF16() const
Definition: Types.cpp:21
U cast() const
Definition: Types.h:250
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:288
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:97