MLIR  20.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <cstdint>
22 #include <iterator>
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 //===----------------------------------------------------------------------===//
28 // ArrayType
29 //===----------------------------------------------------------------------===//
30 
32  using KeyTy = std::tuple<Type, unsigned, unsigned>;
33 
35  const KeyTy &key) {
36  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37  }
38 
39  bool operator==(const KeyTy &key) const {
40  return key == KeyTy(elementType, elementCount, stride);
41  }
42 
43  ArrayTypeStorage(const KeyTy &key)
44  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45  stride(std::get<2>(key)) {}
46 
48  unsigned elementCount;
49  unsigned stride;
50 };
51 
52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53  assert(elementCount && "ArrayType needs at least one element");
54  return Base::get(elementType.getContext(), elementType, elementCount,
55  /*stride=*/0);
56 }
57 
58 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59  unsigned stride) {
60  assert(elementCount && "ArrayType needs at least one element");
61  return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63 
64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65 
66 Type ArrayType::getElementType() const { return getImpl()->elementType; }
67 
68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69 
71  std::optional<StorageClass> storage) {
72  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
73 }
74 
77  std::optional<StorageClass> storage) {
78  llvm::cast<SPIRVType>(getElementType())
79  .getCapabilities(capabilities, storage);
80 }
81 
82 std::optional<int64_t> ArrayType::getSizeInBytes() {
83  auto elementType = llvm::cast<SPIRVType>(getElementType());
84  std::optional<int64_t> size = elementType.getSizeInBytes();
85  if (!size)
86  return std::nullopt;
87  return (*size + getArrayStride()) * getNumElements();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // CompositeType
92 //===----------------------------------------------------------------------===//
93 
95  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96  return isValid(vectorType);
99  spirv::StructType>(type);
100 }
101 
102 bool CompositeType::isValid(VectorType type) {
103  return type.getRank() == 1 &&
104  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105  llvm::isa<ScalarType>(type.getElementType());
106 }
107 
108 Type CompositeType::getElementType(unsigned index) const {
109  return TypeSwitch<Type, Type>(*this)
110  .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
111  [](auto type) { return type.getElementType(); })
112  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
113  .Case<StructType>(
114  [index](StructType type) { return type.getElementType(index); })
115  .Default(
116  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
117 }
118 
120  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
121  return arrayType.getNumElements();
122  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
123  return matrixType.getNumColumns();
124  if (auto structType = llvm::dyn_cast<StructType>(*this))
125  return structType.getNumElements();
126  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
127  return vectorType.getNumElements();
128  if (llvm::isa<CooperativeMatrixType>(*this)) {
129  llvm_unreachable(
130  "invalid to query number of elements of spirv Cooperative Matrix type");
131  }
132  if (llvm::isa<RuntimeArrayType>(*this)) {
133  llvm_unreachable(
134  "invalid to query number of elements of spirv::RuntimeArray type");
135  }
136  llvm_unreachable("invalid composite type");
137 }
138 
140  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142 
145  std::optional<StorageClass> storage) {
146  TypeSwitch<Type>(*this)
148  StructType>(
149  [&](auto type) { type.getExtensions(extensions, storage); })
150  .Case<VectorType>([&](VectorType type) {
151  return llvm::cast<ScalarType>(type.getElementType())
152  .getExtensions(extensions, storage);
153  })
154  .Default([](Type) { llvm_unreachable("invalid composite type"); });
155 }
156 
159  std::optional<StorageClass> storage) {
160  TypeSwitch<Type>(*this)
162  StructType>(
163  [&](auto type) { type.getCapabilities(capabilities, storage); })
164  .Case<VectorType>([&](VectorType type) {
165  auto vecSize = getNumElements();
166  if (vecSize == 8 || vecSize == 16) {
167  static const Capability caps[] = {Capability::Vector16};
168  ArrayRef<Capability> ref(caps, std::size(caps));
169  capabilities.push_back(ref);
170  }
171  return llvm::cast<ScalarType>(type.getElementType())
172  .getCapabilities(capabilities, storage);
173  })
174  .Default([](Type) { llvm_unreachable("invalid composite type"); });
175 }
176 
177 std::optional<int64_t> CompositeType::getSizeInBytes() {
178  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
179  return arrayType.getSizeInBytes();
180  if (auto structType = llvm::dyn_cast<StructType>(*this))
181  return structType.getSizeInBytes();
182  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
183  std::optional<int64_t> elementSize =
184  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
185  if (!elementSize)
186  return std::nullopt;
187  return *elementSize * vectorType.getNumElements();
188  }
189  return std::nullopt;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // CooperativeMatrixType
194 //===----------------------------------------------------------------------===//
195 
197  using KeyTy =
198  std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
199 
201  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
202  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
204  }
205 
206  bool operator==(const KeyTy &key) const {
207  return key == KeyTy(elementType, rows, columns, scope, use);
208  }
209 
211  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
212  columns(std::get<2>(key)), scope(std::get<3>(key)),
213  use(std::get<4>(key)) {}
214 
216  uint32_t rows;
217  uint32_t columns;
218  Scope scope;
219  CooperativeMatrixUseKHR use;
220 };
221 
223  uint32_t rows,
224  uint32_t columns, Scope scope,
225  CooperativeMatrixUseKHR use) {
226  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
227  use);
228 }
229 
231  return getImpl()->elementType;
232 }
233 
234 uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
235 
237  return getImpl()->columns;
238 }
239 
240 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
241 
242 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
243  return getImpl()->use;
244 }
245 
248  std::optional<StorageClass> storage) {
249  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
250  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
251  extensions.push_back(exts);
252 }
253 
256  std::optional<StorageClass> storage) {
257  llvm::cast<SPIRVType>(getElementType())
258  .getCapabilities(capabilities, storage);
259  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
260  capabilities.push_back(caps);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // ImageType
265 //===----------------------------------------------------------------------===//
266 
267 template <typename T>
268 static constexpr unsigned getNumBits() {
269  return 0;
270 }
271 template <>
272 constexpr unsigned getNumBits<Dim>() {
273  static_assert((1 << 3) > getMaxEnumValForDim(),
274  "Not enough bits to encode Dim value");
275  return 3;
276 }
277 template <>
278 constexpr unsigned getNumBits<ImageDepthInfo>() {
279  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
280  "Not enough bits to encode ImageDepthInfo value");
281  return 2;
282 }
283 template <>
284 constexpr unsigned getNumBits<ImageArrayedInfo>() {
285  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
286  "Not enough bits to encode ImageArrayedInfo value");
287  return 1;
288 }
289 template <>
290 constexpr unsigned getNumBits<ImageSamplingInfo>() {
291  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
292  "Not enough bits to encode ImageSamplingInfo value");
293  return 1;
294 }
295 template <>
296 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
297  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
298  "Not enough bits to encode ImageSamplerUseInfo value");
299  return 2;
300 }
301 template <>
302 constexpr unsigned getNumBits<ImageFormat>() {
303  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
304  "Not enough bits to encode ImageFormat value");
305  return 6;
306 }
307 
309 public:
310  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
311  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
312 
314  const KeyTy &key) {
315  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
316  }
317 
318  bool operator==(const KeyTy &key) const {
319  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
320  samplerUseInfo, format);
321  }
322 
324  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
325  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
326  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
327  format(std::get<6>(key)) {}
328 
336 };
337 
338 ImageType
339 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
340  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
341  value) {
342  return Base::get(std::get<0>(value).getContext(), value);
343 }
344 
345 Type ImageType::getElementType() const { return getImpl()->elementType; }
346 
347 Dim ImageType::getDim() const { return getImpl()->dim; }
348 
349 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
350 
351 ImageArrayedInfo ImageType::getArrayedInfo() const {
352  return getImpl()->arrayedInfo;
353 }
354 
355 ImageSamplingInfo ImageType::getSamplingInfo() const {
356  return getImpl()->samplingInfo;
357 }
358 
359 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
360  return getImpl()->samplerUseInfo;
361 }
362 
363 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
364 
366  std::optional<StorageClass>) {
367  // Image types do not require extra extensions thus far.
368 }
369 
372  std::optional<StorageClass>) {
373  if (auto dimCaps = spirv::getCapabilities(getDim()))
374  capabilities.push_back(*dimCaps);
375 
376  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
377  capabilities.push_back(*fmtCaps);
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // PointerType
382 //===----------------------------------------------------------------------===//
383 
385  // (Type, StorageClass) as the key: Type stored in this struct, and
386  // StorageClass stored as TypeStorage's subclass data.
387  using KeyTy = std::pair<Type, StorageClass>;
388 
390  const KeyTy &key) {
391  return new (allocator.allocate<PointerTypeStorage>())
392  PointerTypeStorage(key);
393  }
394 
395  bool operator==(const KeyTy &key) const {
396  return key == KeyTy(pointeeType, storageClass);
397  }
398 
400  : pointeeType(key.first), storageClass(key.second) {}
401 
403  StorageClass storageClass;
404 };
405 
406 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
407  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
408 }
409 
410 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
411 
412 StorageClass PointerType::getStorageClass() const {
413  return getImpl()->storageClass;
414 }
415 
417  std::optional<StorageClass> storage) {
418  // Use this pointer type's storage class because this pointer indicates we are
419  // using the pointee type in that specific storage class.
420  llvm::cast<SPIRVType>(getPointeeType())
421  .getExtensions(extensions, getStorageClass());
422 
423  if (auto scExts = spirv::getExtensions(getStorageClass()))
424  extensions.push_back(*scExts);
425 }
426 
429  std::optional<StorageClass> storage) {
430  // Use this pointer type's storage class because this pointer indicates we are
431  // using the pointee type in that specific storage class.
432  llvm::cast<SPIRVType>(getPointeeType())
433  .getCapabilities(capabilities, getStorageClass());
434 
435  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
436  capabilities.push_back(*scCaps);
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // RuntimeArrayType
441 //===----------------------------------------------------------------------===//
442 
444  using KeyTy = std::pair<Type, unsigned>;
445 
447  const KeyTy &key) {
448  return new (allocator.allocate<RuntimeArrayTypeStorage>())
450  }
451 
452  bool operator==(const KeyTy &key) const {
453  return key == KeyTy(elementType, stride);
454  }
455 
457  : elementType(key.first), stride(key.second) {}
458 
460  unsigned stride;
461 };
462 
464  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
465 }
466 
467 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
468  return Base::get(elementType.getContext(), elementType, stride);
469 }
470 
471 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
472 
473 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
474 
477  std::optional<StorageClass> storage) {
478  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
479 }
480 
483  std::optional<StorageClass> storage) {
484  {
485  static const Capability caps[] = {Capability::Shader};
486  ArrayRef<Capability> ref(caps, std::size(caps));
487  capabilities.push_back(ref);
488  }
489  llvm::cast<SPIRVType>(getElementType())
490  .getCapabilities(capabilities, storage);
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // ScalarType
495 //===----------------------------------------------------------------------===//
496 
498  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
499  return isValid(floatType);
500  }
501  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
502  return isValid(intType);
503  }
504  return false;
505 }
506 
508  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
509 }
510 
511 bool ScalarType::isValid(IntegerType type) {
512  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
513 }
514 
516  std::optional<StorageClass> storage) {
517  // 8- or 16-bit integer/floating-point numbers will require extra extensions
518  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
519  // SPV_KHR_8bit_storage for more details.
520  if (!storage)
521  return;
522 
523  switch (*storage) {
524  case StorageClass::PushConstant:
525  case StorageClass::StorageBuffer:
526  case StorageClass::Uniform:
527  if (getIntOrFloatBitWidth() == 8) {
528  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
529  ArrayRef<Extension> ref(exts, std::size(exts));
530  extensions.push_back(ref);
531  }
532  [[fallthrough]];
533  case StorageClass::Input:
534  case StorageClass::Output:
535  if (getIntOrFloatBitWidth() == 16) {
536  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
537  ArrayRef<Extension> ref(exts, std::size(exts));
538  extensions.push_back(ref);
539  }
540  break;
541  default:
542  break;
543  }
544 }
545 
548  std::optional<StorageClass> storage) {
549  unsigned bitwidth = getIntOrFloatBitWidth();
550 
551  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
552  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
553  // SPV_KHR_8bit_storage for more details.
554 
555 #define STORAGE_CASE(storage, cap8, cap16) \
556  case StorageClass::storage: { \
557  if (bitwidth == 8) { \
558  static const Capability caps[] = {Capability::cap8}; \
559  ArrayRef<Capability> ref(caps, std::size(caps)); \
560  capabilities.push_back(ref); \
561  return; \
562  } \
563  if (bitwidth == 16) { \
564  static const Capability caps[] = {Capability::cap16}; \
565  ArrayRef<Capability> ref(caps, std::size(caps)); \
566  capabilities.push_back(ref); \
567  return; \
568  } \
569  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
570  /* storage classes. Fall through to the next section. */ \
571  } break
572 
573  // This part only handles the cases where special bitwidths appearing in
574  // interface storage classes.
575  if (storage) {
576  switch (*storage) {
577  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
578  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
579  StorageBuffer16BitAccess);
580  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
581  StorageUniform16);
582  case StorageClass::Input:
583  case StorageClass::Output: {
584  if (bitwidth == 16) {
585  static const Capability caps[] = {Capability::StorageInputOutput16};
586  ArrayRef<Capability> ref(caps, std::size(caps));
587  capabilities.push_back(ref);
588  return;
589  }
590  break;
591  }
592  default:
593  break;
594  }
595  }
596 #undef STORAGE_CASE
597 
598  // For other non-interface storage classes, require a different set of
599  // capabilities for special bitwidths.
600 
601 #define WIDTH_CASE(type, width) \
602  case width: { \
603  static const Capability caps[] = {Capability::type##width}; \
604  ArrayRef<Capability> ref(caps, std::size(caps)); \
605  capabilities.push_back(ref); \
606  } break
607 
608  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
609  switch (bitwidth) {
610  WIDTH_CASE(Int, 8);
611  WIDTH_CASE(Int, 16);
612  WIDTH_CASE(Int, 64);
613  case 1:
614  case 32:
615  break;
616  default:
617  llvm_unreachable("invalid bitwidth to getCapabilities");
618  }
619  } else {
620  assert(llvm::isa<FloatType>(*this));
621  switch (bitwidth) {
622  WIDTH_CASE(Float, 16);
623  WIDTH_CASE(Float, 64);
624  case 32:
625  break;
626  default:
627  llvm_unreachable("invalid bitwidth to getCapabilities");
628  }
629  }
630 
631 #undef WIDTH_CASE
632 }
633 
634 std::optional<int64_t> ScalarType::getSizeInBytes() {
635  auto bitWidth = getIntOrFloatBitWidth();
636  // According to the SPIR-V spec:
637  // "There is no physical size or bit pattern defined for values with boolean
638  // type. If they are stored (in conjunction with OpVariable), they can only
639  // be used with logical addressing operations, not physical, and only with
640  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
641  // Private, Function, Input, and Output."
642  if (bitWidth == 1)
643  return std::nullopt;
644  return bitWidth / 8;
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // SPIRVType
649 //===----------------------------------------------------------------------===//
650 
652  // Allow SPIR-V dialect types
653  if (llvm::isa<SPIRVDialect>(type.getDialect()))
654  return true;
655  if (llvm::isa<ScalarType>(type))
656  return true;
657  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
658  return CompositeType::isValid(vectorType);
659  return false;
660 }
661 
663  return isIntOrFloat() || llvm::isa<VectorType>(*this);
664 }
665 
667  std::optional<StorageClass> storage) {
668  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
669  scalarType.getExtensions(extensions, storage);
670  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
671  compositeType.getExtensions(extensions, storage);
672  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
673  imageType.getExtensions(extensions, storage);
674  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
675  sampledImageType.getExtensions(extensions, storage);
676  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
677  matrixType.getExtensions(extensions, storage);
678  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
679  ptrType.getExtensions(extensions, storage);
680  } else {
681  llvm_unreachable("invalid SPIR-V Type to getExtensions");
682  }
683 }
684 
687  std::optional<StorageClass> storage) {
688  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
689  scalarType.getCapabilities(capabilities, storage);
690  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
691  compositeType.getCapabilities(capabilities, storage);
692  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
693  imageType.getCapabilities(capabilities, storage);
694  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
695  sampledImageType.getCapabilities(capabilities, storage);
696  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
697  matrixType.getCapabilities(capabilities, storage);
698  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
699  ptrType.getCapabilities(capabilities, storage);
700  } else {
701  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
702  }
703 }
704 
705 std::optional<int64_t> SPIRVType::getSizeInBytes() {
706  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
707  return scalarType.getSizeInBytes();
708  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
709  return compositeType.getSizeInBytes();
710  return std::nullopt;
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // SampledImageType
715 //===----------------------------------------------------------------------===//
717  using KeyTy = Type;
718 
719  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
720 
721  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
722 
724  const KeyTy &key) {
725  return new (allocator.allocate<SampledImageTypeStorage>())
727  }
728 
730 };
731 
733  return Base::get(imageType.getContext(), imageType);
734 }
735 
738  Type imageType) {
739  return Base::getChecked(emitError, imageType.getContext(), imageType);
740 }
741 
742 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
743 
744 LogicalResult
746  Type imageType) {
747  if (!llvm::isa<ImageType>(imageType))
748  return emitError() << "expected image type";
749 
750  return success();
751 }
752 
755  std::optional<StorageClass> storage) {
756  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
757 }
758 
761  std::optional<StorageClass> storage) {
762  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // StructType
767 //===----------------------------------------------------------------------===//
768 
769 /// Type storage for SPIR-V structure types:
770 ///
771 /// Structures are uniqued using:
772 /// - for identified structs:
773 /// - a string identifier;
774 /// - for literal structs:
775 /// - a list of member types;
776 /// - a list of member offset info;
777 /// - a list of member decoration info.
778 ///
779 /// Identified structures only have a mutable component consisting of:
780 /// - a list of member types;
781 /// - a list of member offset info;
782 /// - a list of member decoration info.
784  /// Construct a storage object for an identified struct type. A struct type
785  /// associated with such storage must call StructType::trySetBody(...) later
786  /// in order to mutate the storage object providing the actual content.
787  StructTypeStorage(StringRef identifier)
788  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
789  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
790  identifier(identifier) {}
791 
792  /// Construct a storage object for a literal struct type. A struct type
793  /// associated with such storage is immutable.
795  unsigned numMembers, Type const *memberTypes,
796  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
797  StructType::MemberDecorationInfo const *memberDecorationsInfo)
798  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
799  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
800  memberDecorationsInfo(memberDecorationsInfo) {}
801 
802  /// A storage key is divided into 2 parts:
803  /// - for identified structs:
804  /// - a StringRef representing the struct identifier;
805  /// - for literal structs:
806  /// - an ArrayRef<Type> for member types;
807  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
808  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
809  /// info.
810  ///
811  /// An identified struct type is uniqued only by the first part (field 0)
812  /// of the key.
813  ///
814  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
815  /// 3) of the key. The identifier field (field 0) must be empty.
816  using KeyTy =
817  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
819 
820  /// For identified structs, return true if the given key contains the same
821  /// identifier.
822  ///
823  /// For literal structs, return true if the given key contains a matching list
824  /// of member types + offset info + decoration info.
825  bool operator==(const KeyTy &key) const {
826  if (isIdentified()) {
827  // Identified types are uniqued by their identifier.
828  return getIdentifier() == std::get<0>(key);
829  }
830 
831  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
832  getMemberDecorationsInfo());
833  }
834 
835  /// If the given key contains a non-empty identifier, this method constructs
836  /// an identified struct and leaves the rest of the struct type data to be set
837  /// through a later call to StructType::trySetBody(...).
838  ///
839  /// If, on the other hand, the key contains an empty identifier, a literal
840  /// struct is constructed using the other fields of the key.
842  const KeyTy &key) {
843  StringRef keyIdentifier = std::get<0>(key);
844 
845  if (!keyIdentifier.empty()) {
846  StringRef identifier = allocator.copyInto(keyIdentifier);
847 
848  // Identified StructType body/members will be set through trySetBody(...)
849  // later.
850  return new (allocator.allocate<StructTypeStorage>())
851  StructTypeStorage(identifier);
852  }
853 
854  ArrayRef<Type> keyTypes = std::get<1>(key);
855 
856  // Copy the member type and layout information into the bump pointer
857  const Type *typesList = nullptr;
858  if (!keyTypes.empty()) {
859  typesList = allocator.copyInto(keyTypes).data();
860  }
861 
862  const StructType::OffsetInfo *offsetInfoList = nullptr;
863  if (!std::get<2>(key).empty()) {
864  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
865  assert(keyOffsetInfo.size() == keyTypes.size() &&
866  "size of offset information must be same as the size of number of "
867  "elements");
868  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
869  }
870 
871  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
872  unsigned numMemberDecorations = 0;
873  if (!std::get<3>(key).empty()) {
874  auto keyMemberDecorations = std::get<3>(key);
875  numMemberDecorations = keyMemberDecorations.size();
876  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
877  }
878 
879  return new (allocator.allocate<StructTypeStorage>())
880  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
881  numMemberDecorations, memberDecorationList);
882  }
883 
885  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
886  }
887 
889  if (offsetInfo) {
890  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
891  }
892  return {};
893  }
894 
896  if (memberDecorationsInfo) {
897  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
898  numMemberDecorations);
899  }
900  return {};
901  }
902 
903  StringRef getIdentifier() const { return identifier; }
904 
905  bool isIdentified() const { return !identifier.empty(); }
906 
907  /// Sets the struct type content for identified structs. Calling this method
908  /// is only valid for identified structs.
909  ///
910  /// Fails under the following conditions:
911  /// - If called for a literal struct;
912  /// - If called for an identified struct whose body was set before (through a
913  /// call to this method) but with different contents from the passed
914  /// arguments.
915  LogicalResult mutate(
916  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
917  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
918  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
919  if (!isIdentified())
920  return failure();
921 
922  if (memberTypesAndIsBodySet.getInt() &&
923  (getMemberTypes() != structMemberTypes ||
924  getOffsetInfo() != structOffsetInfo ||
925  getMemberDecorationsInfo() != structMemberDecorationInfo))
926  return failure();
927 
928  memberTypesAndIsBodySet.setInt(true);
929  numMembers = structMemberTypes.size();
930 
931  // Copy the member type and layout information into the bump pointer.
932  if (!structMemberTypes.empty())
933  memberTypesAndIsBodySet.setPointer(
934  allocator.copyInto(structMemberTypes).data());
935 
936  if (!structOffsetInfo.empty()) {
937  assert(structOffsetInfo.size() == structMemberTypes.size() &&
938  "size of offset information must be same as the size of number of "
939  "elements");
940  offsetInfo = allocator.copyInto(structOffsetInfo).data();
941  }
942 
943  if (!structMemberDecorationInfo.empty()) {
944  numMemberDecorations = structMemberDecorationInfo.size();
945  memberDecorationsInfo =
946  allocator.copyInto(structMemberDecorationInfo).data();
947  }
948 
949  return success();
950  }
951 
952  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
954  unsigned numMembers;
957  StringRef identifier;
958 };
959 
963  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
964  assert(!memberTypes.empty() && "Struct needs at least one member type");
965  // Sort the decorations.
967  memberDecorations);
968  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
// The sorted copy (not the caller's order) is what is passed to Base::get as
// part of the uniquing key, so the same decorations given in a different
// order map to the same literal struct type.
// NOTE(review): memberTypes.vec() materializes a full copy of the ArrayRef
// only to read front(); memberTypes.front().getContext() yields the same
// context without the allocation.
969  return Base::get(memberTypes.vec().front().getContext(),
970  /*identifier=*/StringRef(), memberTypes, offsetInfo,
971  sortedDecorations);
972 }
973 
975  StringRef identifier) {
// An identified struct is keyed only by its name and starts with empty
// member/offset/decoration lists; its body is presumably filled in later via
// trySetBody (which is rejected for literal structs).
976  assert(!identifier.empty() &&
977  "StructType identifier must be non-empty string");
978 
979  return Base::get(context, identifier, ArrayRef<Type>(),
982 }
983 
984 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
// Builds a struct with no members. An empty identifier yields a literal
// struct; a non-empty one yields an identified struct whose body is then
// explicitly set to empty. If setting that body fails (e.g. the identifier
// already has a different, non-empty body), a null StructType is returned.
985  StructType newStructType = Base::get(
986  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
988  // Set an empty body in case this is a identified struct.
989  if (newStructType.isIdentified() &&
990  failed(newStructType.trySetBody(
993  return StructType();
994 
995  return newStructType;
996 }
997 
998 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
999 
1000 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1001 
1002 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1003 
1004 Type StructType::getElementType(unsigned index) const {
1005  assert(getNumElements() > index && "member index out of range");
1006  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1007 }
1008 
// Exposes the stored member-type array as a TypeRange of getNumElements()
// entries; TypeRange is constructed over the existing storage pointer rather
// than a copied container.
1010  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1011  getNumElements());
1012 }
1013 
1014 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1015 
1016 uint64_t StructType::getMemberOffset(unsigned index) const {
1017  assert(getNumElements() > index && "member index out of range");
1018  return getImpl()->offsetInfo[index];
1019 }
1020 
1023  const {
// Replaces the contents of memberDecorations with every stored member
// decoration, in the order they are stored.
1024  memberDecorations.clear();
1025  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1026  memberDecorations.append(implMemberDecorations.begin(),
1027  implMemberDecorations.end());
1028 }
1029 
1031  unsigned index,
1032  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
// Clears decorationsInfo, then fills it with the decorations whose
// memberIndex equals `index`.
1033  assert(getNumElements() > index && "member index out of range");
1034  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1035  decorationsInfo.clear();
1036  for (const auto &memberDecoration : memberDecorations) {
1037  if (memberDecoration.memberIndex == index) {
1038  decorationsInfo.push_back(memberDecoration);
1039  }
// NOTE(review): this early exit relies on storage being sorted by
// memberIndex. StructType::get sorts before uniquing, but trySetBody
// forwards its decorations to StructTypeStorage::mutate, which copies them
// unsorted. For identified structs whose caller did not pre-sort, matching
// entries after a larger index would be missed. Confirm that callers sort.
1040  if (memberDecoration.memberIndex > index) {
1041  // Early exit since the decorations are stored sorted.
1042  return;
1043  }
1044  }
1045 }
1046 
1047 LogicalResult
1049  ArrayRef<OffsetInfo> offsetInfo,
1050  ArrayRef<MemberDecorationInfo> memberDecorations) {
// Delegates to StructTypeStorage::mutate. mutate fails for literal structs,
// and also when a body was already set with different members, offsets or
// decorations; re-setting an identical body succeeds.
// NOTE(review): unlike StructType::get, memberDecorations are not sorted
// here, yet getMemberDecorations(index, ...) assumes sorted storage.
1051  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1052 }
1053 
1055  std::optional<StorageClass> storage) {
// Appends, in member order, the extensions each member type needs for the
// given storage class. The struct itself contributes none.
1056  for (Type elementType : getElementTypes())
1057  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1058 }
1059 
1062  std::optional<StorageClass> storage) {
// Appends, in member order, the capabilities each member type needs for the
// given storage class. The struct itself contributes none.
1063  for (Type elementType : getElementTypes())
1064  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1065 }
1066 
1067 llvm::hash_code spirv::hash_value(
1068  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1069  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1070  memberDecorationInfo.decoration);
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // MatrixType
1075 //===----------------------------------------------------------------------===//
1076 
1078  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1079  : columnType(columnType), columnCount(columnCount) {}
1080 
// Uniquing key: (column vector type, number of columns).
1081  using KeyTy = std::tuple<Type, uint32_t>;
1082 
1084  const KeyTy &key) {
1085 
1086  // Initialize the memory using placement new.
1087  return new (allocator.allocate<MatrixTypeStorage>())
1088  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1089  }
1090 
// A storage matches a key only when both the column type and the column
// count are equal.
1091  bool operator==(const KeyTy &key) const {
1092  return key == KeyTy(columnType, columnCount);
1093  }
1094 
// Number of columns, i.e. key element 1.
1096  const uint32_t columnCount;
1097 };
1098 
1099 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1100  return Base::get(columnType.getContext(), columnType, columnCount);
1101 }
1102 
1104  Type columnType, uint32_t columnCount) {
// Checked counterpart of get(): presumably validates via verifyInvariants and
// reports failures through emitError instead of asserting.
1105  return Base::getChecked(emitError, columnType.getContext(), columnType,
1106  columnCount);
1107 }
1108 
1109 LogicalResult
1111  Type columnType, uint32_t columnCount) {
// Validates a matrix shape: 2-4 columns, each column a 1-D float vector of
// length 2-4. The first violated rule is reported.
1112  if (columnCount < 2 || columnCount > 4)
1113  return emitError() << "matrix can have 2, 3, or 4 columns only";
1114 
1115  if (!isValidColumnType(columnType))
1116  return emitError() << "matrix columns must be vectors of floats";
1117 
1118  // The underlying vectors (columns) must be of size 2, 3, or 4.
// isValidColumnType above guarantees columnType is a VectorType, so this
// cast cannot fail.
1119  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1120  if (columnShape.size() != 1)
1121  return emitError() << "matrix columns must be 1D vectors";
1122 
1123  if (columnShape[0] < 2 || columnShape[0] > 4)
1124  return emitError() << "matrix columns must be of size 2, 3, or 4";
1125 
1126  return success();
1127 }
1128 
1129 /// Returns true if the matrix elements are vectors of float elements.
1131  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1132  if (llvm::isa<FloatType>(vectorType.getElementType()))
1133  return true;
1134  }
1135  return false;
1136 }
1137 
1138 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1139 
// Returns the scalar element type of the column vectors. verifyInvariants
// requires that type to be a FloatType.
1141  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1142 }
1143 
1144 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1145 
1146 unsigned MatrixType::getNumRows() const {
1147  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1148 }
1149 
1150 unsigned MatrixType::getNumElements() const {
1151  return (getImpl()->columnCount) * getNumRows();
1152 }
1153 
1155  std::optional<StorageClass> storage) {
// A matrix adds no extensions of its own; forward to the column vector type.
1156  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1157 }
1158 
1161  std::optional<StorageClass> storage) {
// Every matrix requires the Matrix capability. The array is function-local
// static because the ArrayRef pushed into `capabilities` does not own its data
// and must stay valid after this function returns.
1162  {
1163  static const Capability caps[] = {Capability::Matrix};
1164  ArrayRef<Capability> ref(caps, std::size(caps));
1165  capabilities.push_back(ref);
1166  }
1167  // Add any capabilities associated with the underlying vectors (i.e., columns)
1168  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1169 }
1170 
1171 //===----------------------------------------------------------------------===//
1172 // SPIR-V Dialect
1173 //===----------------------------------------------------------------------===//
1174 
// Registration hook for the SPIR-V dialect's types.
// NOTE(review): the body (original lines 1176-1177) is absent from this
// listing; confirm the registered type list against the real source.
1175 void SPIRVDialect::registerTypes() {
1178 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:296
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:302
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:268
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:284
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:290
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:272
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:278
int64_t rows
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:121
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:127
bool isBF16() const
Definition: Types.cpp:50
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:82
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:70
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:75
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:157
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:177
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:139
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:143
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:119
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:108
static bool classof(Type type)
Definition: SPIRVTypes.cpp:94
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:240
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:222
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:246
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:254
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:242
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:349
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:351
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:363
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:359
Type getElementType() const
Definition: SPIRVTypes.cpp:345
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:355
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:370
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:365
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:416
Type getPointeeType() const
Definition: SPIRVTypes.cpp:410
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:412
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:427
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:481
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:473
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:475
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:705
static bool classof(Type type)
Definition: SPIRVTypes.cpp:651
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:685
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:666
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:745
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:759
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:737
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:753
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:732
static bool classof(Type type)
Definition: SPIRVTypes.cpp:497
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:515
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:507
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:546
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:634
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:974
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
Definition: SPIRVTypes.cpp:998
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:984
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:32
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:34
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:39
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:201
std::tuple< Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:198
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:318
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:313
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:311
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:395
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:387
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:389
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:446
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:452
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:721
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:723
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:783
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:888
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:794
StructType::OffsetInfo const * offsetInfo
Definition: SPIRVTypes.cpp:953
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:825
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:915
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:841
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:884
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:787
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:895
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:818
StructType::MemberDecorationInfo const * memberDecorationsInfo
Definition: SPIRVTypes.cpp:956
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
Definition: SPIRVTypes.cpp:952