MLIR  15.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::spirv;
22 
23 //===----------------------------------------------------------------------===//
24 // ArrayType
25 //===----------------------------------------------------------------------===//
26 
28  using KeyTy = std::tuple<Type, unsigned, unsigned>;
29 
31  const KeyTy &key) {
32  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
33  }
34 
35  bool operator==(const KeyTy &key) const {
36  return key == KeyTy(elementType, elementCount, stride);
37  }
38 
39  ArrayTypeStorage(const KeyTy &key)
40  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
41  stride(std::get<2>(key)) {}
42 
44  unsigned elementCount;
45  unsigned stride;
46 };
47 
48 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
49  assert(elementCount && "ArrayType needs at least one element");
50  return Base::get(elementType.getContext(), elementType, elementCount,
51  /*stride=*/0);
52 }
53 
54 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
55  unsigned stride) {
56  assert(elementCount && "ArrayType needs at least one element");
57  return Base::get(elementType.getContext(), elementType, elementCount, stride);
58 }
59 
60 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
61 
62 Type ArrayType::getElementType() const { return getImpl()->elementType; }
63 
64 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
65 
67  Optional<StorageClass> storage) {
68  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
69 }
70 
73  Optional<StorageClass> storage) {
74  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
75 }
76 
78  auto elementType = getElementType().cast<SPIRVType>();
79  Optional<int64_t> size = elementType.getSizeInBytes();
80  if (!size)
81  return llvm::None;
82  return (*size + getArrayStride()) * getNumElements();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // CompositeType
87 //===----------------------------------------------------------------------===//
88 
90  if (auto vectorType = type.dyn_cast<VectorType>())
91  return isValid(vectorType);
92  return type
95 }
96 
97 bool CompositeType::isValid(VectorType type) {
98  switch (type.getNumElements()) {
99  case 2:
100  case 3:
101  case 4:
102  case 8:
103  case 16:
104  break;
105  default:
106  return false;
107  }
108  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
109 }
110 
111 Type CompositeType::getElementType(unsigned index) const {
112  return TypeSwitch<Type, Type>(*this)
114  [](auto type) { return type.getElementType(); })
115  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
116  .Case<StructType>(
117  [index](StructType type) { return type.getElementType(index); })
118  .Default(
119  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
120 }
121 
123  if (auto arrayType = dyn_cast<ArrayType>())
124  return arrayType.getNumElements();
125  if (auto matrixType = dyn_cast<MatrixType>())
126  return matrixType.getNumColumns();
127  if (auto structType = dyn_cast<StructType>())
128  return structType.getNumElements();
129  if (auto vectorType = dyn_cast<VectorType>())
130  return vectorType.getNumElements();
131  if (isa<CooperativeMatrixNVType>()) {
132  llvm_unreachable(
133  "invalid to query number of elements of spirv::CooperativeMatrix type");
134  }
135  if (isa<RuntimeArrayType>()) {
136  llvm_unreachable(
137  "invalid to query number of elements of spirv::RuntimeArray type");
138  }
139  llvm_unreachable("invalid composite type");
140 }
141 
143  return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
144 }
145 
148  Optional<StorageClass> storage) {
149  TypeSwitch<Type>(*this)
151  StructType>(
152  [&](auto type) { type.getExtensions(extensions, storage); })
153  .Case<VectorType>([&](VectorType type) {
154  return type.getElementType().cast<ScalarType>().getExtensions(
155  extensions, storage);
156  })
157  .Default([](Type) { llvm_unreachable("invalid composite type"); });
158 }
159 
162  Optional<StorageClass> storage) {
163  TypeSwitch<Type>(*this)
165  StructType>(
166  [&](auto type) { type.getCapabilities(capabilities, storage); })
167  .Case<VectorType>([&](VectorType type) {
168  auto vecSize = getNumElements();
169  if (vecSize == 8 || vecSize == 16) {
170  static const Capability caps[] = {Capability::Vector16};
171  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
172  capabilities.push_back(ref);
173  }
174  return type.getElementType().cast<ScalarType>().getCapabilities(
175  capabilities, storage);
176  })
177  .Default([](Type) { llvm_unreachable("invalid composite type"); });
178 }
179 
181  if (auto arrayType = dyn_cast<ArrayType>())
182  return arrayType.getSizeInBytes();
183  if (auto structType = dyn_cast<StructType>())
184  return structType.getSizeInBytes();
185  if (auto vectorType = dyn_cast<VectorType>()) {
186  Optional<int64_t> elementSize =
187  vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
188  if (!elementSize)
189  return llvm::None;
190  return *elementSize * vectorType.getNumElements();
191  }
192  return llvm::None;
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // CooperativeMatrixType
197 //===----------------------------------------------------------------------===//
198 
200  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
201 
203  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
204  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
206  }
207 
208  bool operator==(const KeyTy &key) const {
209  return key == KeyTy(elementType, scope, rows, columns);
210  }
211 
213  : elementType(std::get<0>(key)), rows(std::get<2>(key)),
214  columns(std::get<3>(key)), scope(std::get<1>(key)) {}
215 
217  unsigned rows;
218  unsigned columns;
219  Scope scope;
220 };
221 
223  Scope scope, unsigned rows,
224  unsigned columns) {
225  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
226 }
227 
229  return getImpl()->elementType;
230 }
231 
232 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
233 
234 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
235 
237  return getImpl()->columns;
238 }
239 
242  Optional<StorageClass> storage) {
243  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
244  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
245  ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
246  extensions.push_back(ref);
247 }
248 
251  Optional<StorageClass> storage) {
252  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
253  static const Capability caps[] = {Capability::CooperativeMatrixNV};
254  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
255  capabilities.push_back(ref);
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // ImageType
260 //===----------------------------------------------------------------------===//
261 
/// Fallback for getNumBits: any type without an explicit specialization
/// reports 0 bits.
template <typename EnumT>
static constexpr unsigned getNumBits() {
  return 0;
}
263 template <> constexpr unsigned getNumBits<Dim>() {
264  static_assert((1 << 3) > getMaxEnumValForDim(),
265  "Not enough bits to encode Dim value");
266  return 3;
267 }
268 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
269  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
270  "Not enough bits to encode ImageDepthInfo value");
271  return 2;
272 }
273 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
274  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
275  "Not enough bits to encode ImageArrayedInfo value");
276  return 1;
277 }
278 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
279  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
280  "Not enough bits to encode ImageSamplingInfo value");
281  return 1;
282 }
283 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
284  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
285  "Not enough bits to encode ImageSamplerUseInfo value");
286  return 2;
287 }
288 template <> constexpr unsigned getNumBits<ImageFormat>() {
289  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
290  "Not enough bits to encode ImageFormat value");
291  return 6;
292 }
293 
295 public:
296  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
297  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
298 
300  const KeyTy &key) {
301  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
302  }
303 
304  bool operator==(const KeyTy &key) const {
305  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
306  samplerUseInfo, format);
307  }
308 
310  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
311  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
312  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
313  format(std::get<6>(key)) {}
314 
322 };
323 
324 ImageType
325 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
326  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
327  value) {
328  return Base::get(std::get<0>(value).getContext(), value);
329 }
330 
331 Type ImageType::getElementType() const { return getImpl()->elementType; }
332 
333 Dim ImageType::getDim() const { return getImpl()->dim; }
334 
335 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
336 
337 ImageArrayedInfo ImageType::getArrayedInfo() const {
338  return getImpl()->arrayedInfo;
339 }
340 
341 ImageSamplingInfo ImageType::getSamplingInfo() const {
342  return getImpl()->samplingInfo;
343 }
344 
345 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
346  return getImpl()->samplerUseInfo;
347 }
348 
349 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
350 
353  // Image types do not require extra extensions thus far.
354 }
355 
358  if (auto dimCaps = spirv::getCapabilities(getDim()))
359  capabilities.push_back(*dimCaps);
360 
361  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
362  capabilities.push_back(*fmtCaps);
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // PointerType
367 //===----------------------------------------------------------------------===//
368 
370  // (Type, StorageClass) as the key: Type stored in this struct, and
371  // StorageClass stored as TypeStorage's subclass data.
372  using KeyTy = std::pair<Type, StorageClass>;
373 
375  const KeyTy &key) {
376  return new (allocator.allocate<PointerTypeStorage>())
377  PointerTypeStorage(key);
378  }
379 
380  bool operator==(const KeyTy &key) const {
381  return key == KeyTy(pointeeType, storageClass);
382  }
383 
385  : pointeeType(key.first), storageClass(key.second) {}
386 
388  StorageClass storageClass;
389 };
390 
391 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
392  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
393 }
394 
395 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
396 
397 StorageClass PointerType::getStorageClass() const {
398  return getImpl()->storageClass;
399 }
400 
402  Optional<StorageClass> storage) {
403  // Use this pointer type's storage class because this pointer indicates we are
404  // using the pointee type in that specific storage class.
405  getPointeeType().cast<SPIRVType>().getExtensions(extensions,
406  getStorageClass());
407 
408  if (auto scExts = spirv::getExtensions(getStorageClass()))
409  extensions.push_back(*scExts);
410 }
411 
414  Optional<StorageClass> storage) {
415  // Use this pointer type's storage class because this pointer indicates we are
416  // using the pointee type in that specific storage class.
417  getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
418  getStorageClass());
419 
420  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
421  capabilities.push_back(*scCaps);
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // RuntimeArrayType
426 //===----------------------------------------------------------------------===//
427 
429  using KeyTy = std::pair<Type, unsigned>;
430 
432  const KeyTy &key) {
433  return new (allocator.allocate<RuntimeArrayTypeStorage>())
435  }
436 
437  bool operator==(const KeyTy &key) const {
438  return key == KeyTy(elementType, stride);
439  }
440 
442  : elementType(key.first), stride(key.second) {}
443 
445  unsigned stride;
446 };
447 
449  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
450 }
451 
452 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
453  return Base::get(elementType.getContext(), elementType, stride);
454 }
455 
456 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
457 
458 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
459 
462  Optional<StorageClass> storage) {
463  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
464 }
465 
468  Optional<StorageClass> storage) {
469  {
470  static const Capability caps[] = {Capability::Shader};
471  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
472  capabilities.push_back(ref);
473  }
474  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // ScalarType
479 //===----------------------------------------------------------------------===//
480 
482  if (auto floatType = type.dyn_cast<FloatType>()) {
483  return isValid(floatType);
484  }
485  if (auto intType = type.dyn_cast<IntegerType>()) {
486  return isValid(intType);
487  }
488  return false;
489 }
490 
491 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
492 
493 bool ScalarType::isValid(IntegerType type) {
494  switch (type.getWidth()) {
495  case 1:
496  case 8:
497  case 16:
498  case 32:
499  case 64:
500  return true;
501  default:
502  return false;
503  }
504 }
505 
507  Optional<StorageClass> storage) {
508  // 8- or 16-bit integer/floating-point numbers will require extra extensions
509  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
510  // SPV_KHR_8bit_storage for more details.
511  if (!storage)
512  return;
513 
514  switch (*storage) {
515  case StorageClass::PushConstant:
516  case StorageClass::StorageBuffer:
517  case StorageClass::Uniform:
518  if (getIntOrFloatBitWidth() == 8) {
519  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
520  ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
521  extensions.push_back(ref);
522  }
523  LLVM_FALLTHROUGH;
524  case StorageClass::Input:
525  case StorageClass::Output:
526  if (getIntOrFloatBitWidth() == 16) {
527  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
528  ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
529  extensions.push_back(ref);
530  }
531  break;
532  default:
533  break;
534  }
535 }
536 
539  Optional<StorageClass> storage) {
540  unsigned bitwidth = getIntOrFloatBitWidth();
541 
542  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
543  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
544  // SPV_KHR_8bit_storage for more details.
545 
546 #define STORAGE_CASE(storage, cap8, cap16) \
547  case StorageClass::storage: { \
548  if (bitwidth == 8) { \
549  static const Capability caps[] = {Capability::cap8}; \
550  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
551  capabilities.push_back(ref); \
552  return; \
553  } \
554  if (bitwidth == 16) { \
555  static const Capability caps[] = {Capability::cap16}; \
556  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
557  capabilities.push_back(ref); \
558  return; \
559  } \
560  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
561  /* storage classes. Fall through to the next section. */ \
562  } break
563 
564  // This part only handles the cases where special bitwidths appearing in
565  // interface storage classes.
566  if (storage) {
567  switch (*storage) {
568  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
569  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
570  StorageBuffer16BitAccess);
571  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
572  StorageUniform16);
573  case StorageClass::Input:
574  case StorageClass::Output: {
575  if (bitwidth == 16) {
576  static const Capability caps[] = {Capability::StorageInputOutput16};
577  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
578  capabilities.push_back(ref);
579  return;
580  }
581  break;
582  }
583  default:
584  break;
585  }
586  }
587 #undef STORAGE_CASE
588 
589  // For other non-interface storage classes, require a different set of
590  // capabilities for special bitwidths.
591 
592 #define WIDTH_CASE(type, width) \
593  case width: { \
594  static const Capability caps[] = {Capability::type##width}; \
595  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
596  capabilities.push_back(ref); \
597  } break
598 
599  if (auto intType = dyn_cast<IntegerType>()) {
600  switch (bitwidth) {
601  WIDTH_CASE(Int, 8);
602  WIDTH_CASE(Int, 16);
603  WIDTH_CASE(Int, 64);
604  case 1:
605  case 32:
606  break;
607  default:
608  llvm_unreachable("invalid bitwidth to getCapabilities");
609  }
610  } else {
611  assert(isa<FloatType>());
612  switch (bitwidth) {
613  WIDTH_CASE(Float, 16);
614  WIDTH_CASE(Float, 64);
615  case 32:
616  break;
617  default:
618  llvm_unreachable("invalid bitwidth to getCapabilities");
619  }
620  }
621 
622 #undef WIDTH_CASE
623 }
624 
626  auto bitWidth = getIntOrFloatBitWidth();
627  // According to the SPIR-V spec:
628  // "There is no physical size or bit pattern defined for values with boolean
629  // type. If they are stored (in conjunction with OpVariable), they can only
630  // be used with logical addressing operations, not physical, and only with
631  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
632  // Private, Function, Input, and Output."
633  if (bitWidth == 1)
634  return llvm::None;
635  return bitWidth / 8;
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // SPIRVType
640 //===----------------------------------------------------------------------===//
641 
643  // Allow SPIR-V dialect types
644  if (llvm::isa<SPIRVDialect>(type.getDialect()))
645  return true;
646  if (type.isa<ScalarType>())
647  return true;
648  if (auto vectorType = type.dyn_cast<VectorType>())
650  return false;
651 }
652 
654  return isIntOrFloat() || isa<VectorType>();
655 }
656 
658  Optional<StorageClass> storage) {
659  if (auto scalarType = dyn_cast<ScalarType>()) {
660  scalarType.getExtensions(extensions, storage);
661  } else if (auto compositeType = dyn_cast<CompositeType>()) {
662  compositeType.getExtensions(extensions, storage);
663  } else if (auto imageType = dyn_cast<ImageType>()) {
664  imageType.getExtensions(extensions, storage);
665  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
666  sampledImageType.getExtensions(extensions, storage);
667  } else if (auto matrixType = dyn_cast<MatrixType>()) {
668  matrixType.getExtensions(extensions, storage);
669  } else if (auto ptrType = dyn_cast<PointerType>()) {
670  ptrType.getExtensions(extensions, storage);
671  } else {
672  llvm_unreachable("invalid SPIR-V Type to getExtensions");
673  }
674 }
675 
678  Optional<StorageClass> storage) {
679  if (auto scalarType = dyn_cast<ScalarType>()) {
680  scalarType.getCapabilities(capabilities, storage);
681  } else if (auto compositeType = dyn_cast<CompositeType>()) {
682  compositeType.getCapabilities(capabilities, storage);
683  } else if (auto imageType = dyn_cast<ImageType>()) {
684  imageType.getCapabilities(capabilities, storage);
685  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
686  sampledImageType.getCapabilities(capabilities, storage);
687  } else if (auto matrixType = dyn_cast<MatrixType>()) {
688  matrixType.getCapabilities(capabilities, storage);
689  } else if (auto ptrType = dyn_cast<PointerType>()) {
690  ptrType.getCapabilities(capabilities, storage);
691  } else {
692  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
693  }
694 }
695 
697  if (auto scalarType = dyn_cast<ScalarType>())
698  return scalarType.getSizeInBytes();
699  if (auto compositeType = dyn_cast<CompositeType>())
700  return compositeType.getSizeInBytes();
701  return llvm::None;
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // SampledImageType
706 //===----------------------------------------------------------------------===//
708  using KeyTy = Type;
709 
710  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
711 
712  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
713 
715  const KeyTy &key) {
716  return new (allocator.allocate<SampledImageTypeStorage>())
718  }
719 
721 };
722 
724  return Base::get(imageType.getContext(), imageType);
725 }
726 
729  Type imageType) {
730  return Base::getChecked(emitError, imageType.getContext(), imageType);
731 }
732 
733 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
734 
737  Type imageType) {
738  if (!imageType.isa<ImageType>())
739  return emitError() << "expected image type";
740 
741  return success();
742 }
743 
746  Optional<StorageClass> storage) {
747  getImageType().cast<ImageType>().getExtensions(extensions, storage);
748 }
749 
752  Optional<StorageClass> storage) {
753  getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // StructType
758 //===----------------------------------------------------------------------===//
759 
760 /// Type storage for SPIR-V structure types:
761 ///
762 /// Structures are uniqued using:
763 /// - for identified structs:
764 /// - a string identifier;
765 /// - for literal structs:
766 /// - a list of member types;
767 /// - a list of member offset info;
768 /// - a list of member decoration info.
769 ///
770 /// Identified structures only have a mutable component consisting of:
771 /// - a list of member types;
772 /// - a list of member offset info;
773 /// - a list of member decoration info.
775  /// Construct a storage object for an identified struct type. A struct type
776  /// associated with such storage must call StructType::trySetBody(...) later
777  /// in order to mutate the storage object providing the actual content.
778  StructTypeStorage(StringRef identifier)
779  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
780  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
781  identifier(identifier) {}
782 
783  /// Construct a storage object for a literal struct type. A struct type
784  /// associated with such storage is immutable.
786  unsigned numMembers, Type const *memberTypes,
787  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
788  StructType::MemberDecorationInfo const *memberDecorationsInfo)
789  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
790  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
791  memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
792 
793  /// A storage key is divided into 2 parts:
794  /// - for identified structs:
795  /// - a StringRef representing the struct identifier;
796  /// - for literal structs:
797  /// - an ArrayRef<Type> for member types;
798  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
799  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
800  /// info.
801  ///
802  /// An identified struct type is uniqued only by the first part (field 0)
803  /// of the key.
804  ///
805  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
806  /// 3) of the key. The identifier field (field 0) must be empty.
807  using KeyTy =
808  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
810 
811  /// For identified structs, return true if the given key contains the same
812  /// identifier.
813  ///
814  /// For literal structs, return true if the given key contains a matching list
815  /// of member types + offset info + decoration info.
816  bool operator==(const KeyTy &key) const {
817  if (isIdentified()) {
818  // Identified types are uniqued by their identifier.
819  return getIdentifier() == std::get<0>(key);
820  }
821 
822  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
823  getMemberDecorationsInfo());
824  }
825 
826  /// If the given key contains a non-empty identifier, this method constructs
827  /// an identified struct and leaves the rest of the struct type data to be set
828  /// through a later call to StructType::trySetBody(...).
829  ///
830  /// If, on the other hand, the key contains an empty identifier, a literal
831  /// struct is constructed using the other fields of the key.
833  const KeyTy &key) {
834  StringRef keyIdentifier = std::get<0>(key);
835 
836  if (!keyIdentifier.empty()) {
837  StringRef identifier = allocator.copyInto(keyIdentifier);
838 
839  // Identified StructType body/members will be set through trySetBody(...)
840  // later.
841  return new (allocator.allocate<StructTypeStorage>())
842  StructTypeStorage(identifier);
843  }
844 
845  ArrayRef<Type> keyTypes = std::get<1>(key);
846 
847  // Copy the member type and layout information into the bump pointer
848  const Type *typesList = nullptr;
849  if (!keyTypes.empty()) {
850  typesList = allocator.copyInto(keyTypes).data();
851  }
852 
853  const StructType::OffsetInfo *offsetInfoList = nullptr;
854  if (!std::get<2>(key).empty()) {
855  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
856  assert(keyOffsetInfo.size() == keyTypes.size() &&
857  "size of offset information must be same as the size of number of "
858  "elements");
859  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
860  }
861 
862  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
863  unsigned numMemberDecorations = 0;
864  if (!std::get<3>(key).empty()) {
865  auto keyMemberDecorations = std::get<3>(key);
866  numMemberDecorations = keyMemberDecorations.size();
867  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
868  }
869 
870  return new (allocator.allocate<StructTypeStorage>())
871  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
872  numMemberDecorations, memberDecorationList);
873  }
874 
876  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
877  }
878 
879  ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
880  if (offsetInfo) {
881  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
882  }
883  return {};
884  }
885 
887  if (memberDecorationsInfo) {
888  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
889  numMemberDecorations);
890  }
891  return {};
892  }
893 
894  StringRef getIdentifier() const { return identifier; }
895 
896  bool isIdentified() const { return !identifier.empty(); }
897 
898  /// Sets the struct type content for identified structs. Calling this method
899  /// is only valid for identified structs.
900  ///
901  /// Fails under the following conditions:
902  /// - If called for a literal struct;
903  /// - If called for an identified struct whose body was set before (through a
904  /// call to this method) but with different contents from the passed
905  /// arguments.
907  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
908  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
909  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
910  if (!isIdentified())
911  return failure();
912 
913  if (memberTypesAndIsBodySet.getInt() &&
914  (getMemberTypes() != structMemberTypes ||
915  getOffsetInfo() != structOffsetInfo ||
916  getMemberDecorationsInfo() != structMemberDecorationInfo))
917  return failure();
918 
919  memberTypesAndIsBodySet.setInt(true);
920  numMembers = structMemberTypes.size();
921 
922  // Copy the member type and layout information into the bump pointer.
923  if (!structMemberTypes.empty())
924  memberTypesAndIsBodySet.setPointer(
925  allocator.copyInto(structMemberTypes).data());
926 
927  if (!structOffsetInfo.empty()) {
928  assert(structOffsetInfo.size() == structMemberTypes.size() &&
929  "size of offset information must be same as the size of number of "
930  "elements");
931  offsetInfo = allocator.copyInto(structOffsetInfo).data();
932  }
933 
934  if (!structMemberDecorationInfo.empty()) {
935  numMemberDecorations = structMemberDecorationInfo.size();
936  memberDecorationsInfo =
937  allocator.copyInto(structMemberDecorationInfo).data();
938  }
939 
940  return success();
941  }
942 
943  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
945  unsigned numMembers;
948  StringRef identifier;
949 };
950 
954  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
955  assert(!memberTypes.empty() && "Struct needs at least one member type");
956  // Sort the decorations.
958  memberDecorations.begin(), memberDecorations.end());
959  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
960  return Base::get(memberTypes.vec().front().getContext(),
961  /*identifier=*/StringRef(), memberTypes, offsetInfo,
962  sortedDecorations);
963 }
964 
966  StringRef identifier) {
967  assert(!identifier.empty() &&
968  "StructType identifier must be non-empty string");
969 
970  return Base::get(context, identifier, ArrayRef<Type>(),
973 }
974 
975 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
976  StructType newStructType = Base::get(
977  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
979  // Set an empty body in case this is a identified struct.
980  if (newStructType.isIdentified() &&
981  failed(newStructType.trySetBody(
984  return StructType();
985 
986  return newStructType;
987 }
988 
989 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
990 
991 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
992 
993 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
994 
995 Type StructType::getElementType(unsigned index) const {
996  assert(getNumElements() > index && "member index out of range");
997  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
998 }
999 
1001  return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1002  getNumElements());
1003 }
1004 
1005 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1006 
1007 uint64_t StructType::getMemberOffset(unsigned index) const {
1008  assert(getNumElements() > index && "member index out of range");
1009  return getImpl()->offsetInfo[index];
1010 }
1011 
1014  const {
1015  memberDecorations.clear();
1016  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1017  memberDecorations.append(implMemberDecorations.begin(),
1018  implMemberDecorations.end());
1019 }
1020 
1022  unsigned index,
1023  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1024  assert(getNumElements() > index && "member index out of range");
1025  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1026  decorationsInfo.clear();
1027  for (const auto &memberDecoration : memberDecorations) {
1028  if (memberDecoration.memberIndex == index) {
1029  decorationsInfo.push_back(memberDecoration);
1030  }
1031  if (memberDecoration.memberIndex > index) {
1032  // Early exit since the decorations are stored sorted.
1033  return;
1034  }
1035  }
1036 }
1037 
1040  ArrayRef<OffsetInfo> offsetInfo,
1041  ArrayRef<MemberDecorationInfo> memberDecorations) {
1042  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1043 }
1044 
1046  Optional<StorageClass> storage) {
1047  for (Type elementType : getElementTypes())
1048  elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1049 }
1050 
1053  Optional<StorageClass> storage) {
1054  for (Type elementType : getElementTypes())
1055  elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1056 }
1057 
1058 llvm::hash_code spirv::hash_value(
1059  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1060  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1061  memberDecorationInfo.decoration);
1062 }
1063 
1064 //===----------------------------------------------------------------------===//
1065 // MatrixType
1066 //===----------------------------------------------------------------------===//
1067 
1069  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1070  : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1071 
1072  using KeyTy = std::tuple<Type, uint32_t>;
1073 
1075  const KeyTy &key) {
1076 
1077  // Initialize the memory using placement new.
1078  return new (allocator.allocate<MatrixTypeStorage>())
1079  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1080  }
1081 
1082  bool operator==(const KeyTy &key) const {
1083  return key == KeyTy(columnType, columnCount);
1084  }
1085 
1087  const uint32_t columnCount;
1088 };
1089 
1090 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1091  return Base::get(columnType.getContext(), columnType, columnCount);
1092 }
1093 
1095  Type columnType, uint32_t columnCount) {
1096  return Base::getChecked(emitError, columnType.getContext(), columnType,
1097  columnCount);
1098 }
1099 
1101  Type columnType, uint32_t columnCount) {
1102  if (columnCount < 2 || columnCount > 4)
1103  return emitError() << "matrix can have 2, 3, or 4 columns only";
1104 
1105  if (!isValidColumnType(columnType))
1106  return emitError() << "matrix columns must be vectors of floats";
1107 
1108  /// The underlying vectors (columns) must be of size 2, 3, or 4
1109  ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1110  if (columnShape.size() != 1)
1111  return emitError() << "matrix columns must be 1D vectors";
1112 
1113  if (columnShape[0] < 2 || columnShape[0] > 4)
1114  return emitError() << "matrix columns must be of size 2, 3, or 4";
1115 
1116  return success();
1117 }
1118 
1119 /// Returns true if the matrix elements are vectors of float elements
1121  if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1122  if (vectorType.getElementType().isa<FloatType>())
1123  return true;
1124  }
1125  return false;
1126 }
1127 
1128 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1129 
1131  return getImpl()->columnType.cast<VectorType>().getElementType();
1132 }
1133 
1134 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1135 
1136 unsigned MatrixType::getNumRows() const {
1137  return getImpl()->columnType.cast<VectorType>().getShape()[0];
1138 }
1139 
1140 unsigned MatrixType::getNumElements() const {
1141  return (getImpl()->columnCount) * getNumRows();
1142 }
1143 
1145  Optional<StorageClass> storage) {
1146  getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1147 }
1148 
1151  Optional<StorageClass> storage) {
1152  {
1153  static const Capability caps[] = {Capability::Matrix};
1154  ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1155  capabilities.push_back(ref);
1156  }
1157  // Add any capabilities associated with the underlying vectors (i.e., columns)
1158  getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // SPIR-V Dialect
1163 //===----------------------------------------------------------------------===//
1164 
1165 void SPIRVDialect::registerTypes() {
1168 }
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:268
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:66
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:349
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:64
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:122
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:311
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:160
Base storage class appearing in a Type.
Definition: TypeSupport.h:121
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:506
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
This is a utility allocator used to allocate memory for instances of derived types.
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
static bool classof(Type type)
Definition: SPIRVTypes.cpp:642
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:372
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:906
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:240
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:304
bool isIdentified() const
Returns true if the StructType is identified.
Definition: SPIRVTypes.cpp:991
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
Type getElementType() const
Definition: SPIRVTypes.cpp:331
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:785
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:832
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:351
std::tuple< Type, uint32_t > KeyTy
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:412
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:491
static constexpr const bool value
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:111
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:712
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:879
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:273
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:952
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:736
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:437
T * allocate()
Allocate an instance of the provided type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:337
Optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:180
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:723
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Range class for element types.
Definition: SPIRVTypes.h:347
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:77
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< spirv::StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:750
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:380
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:222
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:374
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:35
U dyn_cast() const
Definition: Types.h:256
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:809
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
#define WIDTH_CASE(type, width)
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:345
Type getElementType() const
Definition: SPIRVTypes.cpp:62
Type getElementType() const
Returns the elements' type (i.e., single element type).
StructType::MemberDecorationInfo const * memberDecorationsInfo
Definition: SPIRVTypes.cpp:947
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:460
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< spirv::StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:744
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:431
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:778
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:401
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:335
unsigned getNumElements() const
Returns total number of elements (rows*columns).
StringRef getIdentifier() const
For literal structs, return an empty string.
Definition: SPIRVTypes.cpp:989
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:875
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent...
Definition: SPIRVTypes.cpp:142
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:28
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:297
uint64_t getMemberOffset(unsigned) const
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:975
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:30
Optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:625
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
Definition: SPIRVTypes.cpp:943
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:278
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:714
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
#define STORAGE_CASE(storage, cap8, cap16)
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:283
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:249
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:993
ElementTypeRange getElementTypes() const
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:200
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:995
static bool classof(Type type)
Definition: SPIRVTypes.cpp:481
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:458
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:299
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:753
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:537
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:163
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:774
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
StructType::OffsetInfo const * offsetInfo
Definition: SPIRVTypes.cpp:944
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:466
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:203
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:886
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
MatrixTypeStorage(Type columnType, uint32_t columnCount)
SPIR-V struct type.
Definition: SPIRVTypes.h:278
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:728
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:262
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:397
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:816
void getCapabilities(CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Appends to capabilities the capabilities needed for this type to appear in the given storage class...
Definition: SPIRVTypes.cpp:676
void getExtensions(ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Appends to extensions the extensions needed for this type to appear in the given storage class...
Definition: SPIRVTypes.cpp:657
bool operator==(const KeyTy &key) const
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
static bool classof(Type type)
Definition: SPIRVTypes.cpp:89
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:232
unsigned getNumColumns() const
Returns the number of columns.
static MatrixType get(Type columnType, uint32_t columnCount)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:965
unsigned getNumRows() const
Returns the number of rows.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
bool isa() const
Definition: Types.h:246
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Type getColumnType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:146
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:341
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:263
Optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:696
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:71
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:356
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
bool isBF16() const
Definition: Types.cpp:21
U cast() const
Definition: Types.h:262
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:288
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:97