MLIR  17.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <iterator>
22 
23 using namespace mlir;
24 using namespace mlir::spirv;
25 
26 //===----------------------------------------------------------------------===//
27 // ArrayType
28 //===----------------------------------------------------------------------===//
29 
31  using KeyTy = std::tuple<Type, unsigned, unsigned>;
32 
34  const KeyTy &key) {
35  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
36  }
37 
38  bool operator==(const KeyTy &key) const {
39  return key == KeyTy(elementType, elementCount, stride);
40  }
41 
42  ArrayTypeStorage(const KeyTy &key)
43  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
44  stride(std::get<2>(key)) {}
45 
47  unsigned elementCount;
48  unsigned stride;
49 };
50 
51 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
52  assert(elementCount && "ArrayType needs at least one element");
53  return Base::get(elementType.getContext(), elementType, elementCount,
54  /*stride=*/0);
55 }
56 
57 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
58  unsigned stride) {
59  assert(elementCount && "ArrayType needs at least one element");
60  return Base::get(elementType.getContext(), elementType, elementCount, stride);
61 }
62 
63 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
64 
65 Type ArrayType::getElementType() const { return getImpl()->elementType; }
66 
67 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
68 
70  std::optional<StorageClass> storage) {
71  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
72 }
73 
76  std::optional<StorageClass> storage) {
77  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
78 }
79 
80 std::optional<int64_t> ArrayType::getSizeInBytes() {
81  auto elementType = getElementType().cast<SPIRVType>();
82  std::optional<int64_t> size = elementType.getSizeInBytes();
83  if (!size)
84  return std::nullopt;
85  return (*size + getArrayStride()) * getNumElements();
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // CompositeType
90 //===----------------------------------------------------------------------===//
91 
93  if (auto vectorType = type.dyn_cast<VectorType>())
94  return isValid(vectorType);
98 }
99 
100 bool CompositeType::isValid(VectorType type) {
101  switch (type.getNumElements()) {
102  case 2:
103  case 3:
104  case 4:
105  case 8:
106  case 16:
107  break;
108  default:
109  return false;
110  }
111  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
112 }
113 
114 Type CompositeType::getElementType(unsigned index) const {
115  return TypeSwitch<Type, Type>(*this)
117  RuntimeArrayType, VectorType>(
118  [](auto type) { return type.getElementType(); })
119  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
120  .Case<StructType>(
121  [index](StructType type) { return type.getElementType(index); })
122  .Default(
123  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
124 }
125 
127  if (auto arrayType = dyn_cast<ArrayType>())
128  return arrayType.getNumElements();
129  if (auto matrixType = dyn_cast<MatrixType>())
130  return matrixType.getNumColumns();
131  if (auto structType = dyn_cast<StructType>())
132  return structType.getNumElements();
133  if (auto vectorType = dyn_cast<VectorType>())
134  return vectorType.getNumElements();
135  if (isa<CooperativeMatrixNVType>()) {
136  llvm_unreachable(
137  "invalid to query number of elements of spirv::CooperativeMatrix type");
138  }
139  if (isa<JointMatrixINTELType>()) {
140  llvm_unreachable(
141  "invalid to query number of elements of spirv::JointMatrix type");
142  }
143  if (isa<RuntimeArrayType>()) {
144  llvm_unreachable(
145  "invalid to query number of elements of spirv::RuntimeArray type");
146  }
147  llvm_unreachable("invalid composite type");
148 }
149 
152  RuntimeArrayType>();
153 }
154 
157  std::optional<StorageClass> storage) {
158  TypeSwitch<Type>(*this)
161  [&](auto type) { type.getExtensions(extensions, storage); })
162  .Case<VectorType>([&](VectorType type) {
163  return type.getElementType().cast<ScalarType>().getExtensions(
164  extensions, storage);
165  })
166  .Default([](Type) { llvm_unreachable("invalid composite type"); });
167 }
168 
171  std::optional<StorageClass> storage) {
172  TypeSwitch<Type>(*this)
175  [&](auto type) { type.getCapabilities(capabilities, storage); })
176  .Case<VectorType>([&](VectorType type) {
177  auto vecSize = getNumElements();
178  if (vecSize == 8 || vecSize == 16) {
179  static const Capability caps[] = {Capability::Vector16};
180  ArrayRef<Capability> ref(caps, std::size(caps));
181  capabilities.push_back(ref);
182  }
183  return type.getElementType().cast<ScalarType>().getCapabilities(
184  capabilities, storage);
185  })
186  .Default([](Type) { llvm_unreachable("invalid composite type"); });
187 }
188 
189 std::optional<int64_t> CompositeType::getSizeInBytes() {
190  if (auto arrayType = dyn_cast<ArrayType>())
191  return arrayType.getSizeInBytes();
192  if (auto structType = dyn_cast<StructType>())
193  return structType.getSizeInBytes();
194  if (auto vectorType = dyn_cast<VectorType>()) {
195  std::optional<int64_t> elementSize =
196  vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
197  if (!elementSize)
198  return std::nullopt;
199  return *elementSize * vectorType.getNumElements();
200  }
201  return std::nullopt;
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // CooperativeMatrixType
206 //===----------------------------------------------------------------------===//
207 
209  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
210 
212  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
213  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
215  }
216 
217  bool operator==(const KeyTy &key) const {
218  return key == KeyTy(elementType, scope, rows, columns);
219  }
220 
222  : elementType(std::get<0>(key)), rows(std::get<2>(key)),
223  columns(std::get<3>(key)), scope(std::get<1>(key)) {}
224 
226  unsigned rows;
227  unsigned columns;
228  Scope scope;
229 };
230 
232  Scope scope, unsigned rows,
233  unsigned columns) {
234  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
235 }
236 
238  return getImpl()->elementType;
239 }
240 
241 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
242 
243 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
244 
246  return getImpl()->columns;
247 }
248 
251  std::optional<StorageClass> storage) {
252  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
253  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
254  ArrayRef<Extension> ref(exts, std::size(exts));
255  extensions.push_back(ref);
256 }
257 
260  std::optional<StorageClass> storage) {
261  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
262  static const Capability caps[] = {Capability::CooperativeMatrixNV};
263  ArrayRef<Capability> ref(caps, std::size(caps));
264  capabilities.push_back(ref);
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // JointMatrixType
269 //===----------------------------------------------------------------------===//
270 
272  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
273 
275  const KeyTy &key) {
276  return new (allocator.allocate<JointMatrixTypeStorage>())
278  }
279 
280  bool operator==(const KeyTy &key) const {
281  return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
282  }
283 
285  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
286  columns(std::get<2>(key)), scope(std::get<4>(key)),
287  matrixLayout(std::get<3>(key)) {}
288 
290  unsigned rows;
291  unsigned columns;
292  Scope scope;
293  MatrixLayout matrixLayout;
294 };
295 
297  unsigned rows, unsigned columns,
298  MatrixLayout matrixLayout) {
299  return Base::get(elementType.getContext(), elementType, rows, columns,
300  matrixLayout, scope);
301 }
302 
304  return getImpl()->elementType;
305 }
306 
307 Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
308 
309 unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
310 
311 unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
312 
314  return getImpl()->matrixLayout;
315 }
316 
319  std::optional<StorageClass> storage) {
320  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
321  static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
322  ArrayRef<Extension> ref(exts, std::size(exts));
323  extensions.push_back(ref);
324 }
325 
328  std::optional<StorageClass> storage) {
329  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
330  static const Capability caps[] = {Capability::JointMatrixINTEL};
331  ArrayRef<Capability> ref(caps, std::size(caps));
332  capabilities.push_back(ref);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // ImageType
337 //===----------------------------------------------------------------------===//
338 
339 template <typename T>
340 static constexpr unsigned getNumBits() {
341  return 0;
342 }
343 template <>
344 constexpr unsigned getNumBits<Dim>() {
345  static_assert((1 << 3) > getMaxEnumValForDim(),
346  "Not enough bits to encode Dim value");
347  return 3;
348 }
349 template <>
350 constexpr unsigned getNumBits<ImageDepthInfo>() {
351  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
352  "Not enough bits to encode ImageDepthInfo value");
353  return 2;
354 }
355 template <>
356 constexpr unsigned getNumBits<ImageArrayedInfo>() {
357  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
358  "Not enough bits to encode ImageArrayedInfo value");
359  return 1;
360 }
361 template <>
362 constexpr unsigned getNumBits<ImageSamplingInfo>() {
363  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
364  "Not enough bits to encode ImageSamplingInfo value");
365  return 1;
366 }
367 template <>
368 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
369  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
370  "Not enough bits to encode ImageSamplerUseInfo value");
371  return 2;
372 }
373 template <>
374 constexpr unsigned getNumBits<ImageFormat>() {
375  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
376  "Not enough bits to encode ImageFormat value");
377  return 6;
378 }
379 
381 public:
382  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
383  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
384 
386  const KeyTy &key) {
387  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
388  }
389 
390  bool operator==(const KeyTy &key) const {
391  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
392  samplerUseInfo, format);
393  }
394 
396  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
397  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
398  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
399  format(std::get<6>(key)) {}
400 
408 };
409 
410 ImageType
411 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
412  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
413  value) {
414  return Base::get(std::get<0>(value).getContext(), value);
415 }
416 
417 Type ImageType::getElementType() const { return getImpl()->elementType; }
418 
419 Dim ImageType::getDim() const { return getImpl()->dim; }
420 
421 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
422 
423 ImageArrayedInfo ImageType::getArrayedInfo() const {
424  return getImpl()->arrayedInfo;
425 }
426 
427 ImageSamplingInfo ImageType::getSamplingInfo() const {
428  return getImpl()->samplingInfo;
429 }
430 
431 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
432  return getImpl()->samplerUseInfo;
433 }
434 
435 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
436 
438  std::optional<StorageClass>) {
439  // Image types do not require extra extensions thus far.
440 }
441 
444  std::optional<StorageClass>) {
445  if (auto dimCaps = spirv::getCapabilities(getDim()))
446  capabilities.push_back(*dimCaps);
447 
448  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
449  capabilities.push_back(*fmtCaps);
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // PointerType
454 //===----------------------------------------------------------------------===//
455 
457  // (Type, StorageClass) as the key: Type stored in this struct, and
458  // StorageClass stored as TypeStorage's subclass data.
459  using KeyTy = std::pair<Type, StorageClass>;
460 
462  const KeyTy &key) {
463  return new (allocator.allocate<PointerTypeStorage>())
464  PointerTypeStorage(key);
465  }
466 
467  bool operator==(const KeyTy &key) const {
468  return key == KeyTy(pointeeType, storageClass);
469  }
470 
472  : pointeeType(key.first), storageClass(key.second) {}
473 
475  StorageClass storageClass;
476 };
477 
478 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
479  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
480 }
481 
482 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
483 
484 StorageClass PointerType::getStorageClass() const {
485  return getImpl()->storageClass;
486 }
487 
489  std::optional<StorageClass> storage) {
490  // Use this pointer type's storage class because this pointer indicates we are
491  // using the pointee type in that specific storage class.
492  getPointeeType().cast<SPIRVType>().getExtensions(extensions,
493  getStorageClass());
494 
495  if (auto scExts = spirv::getExtensions(getStorageClass()))
496  extensions.push_back(*scExts);
497 }
498 
501  std::optional<StorageClass> storage) {
502  // Use this pointer type's storage class because this pointer indicates we are
503  // using the pointee type in that specific storage class.
504  getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
505  getStorageClass());
506 
507  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
508  capabilities.push_back(*scCaps);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // RuntimeArrayType
513 //===----------------------------------------------------------------------===//
514 
516  using KeyTy = std::pair<Type, unsigned>;
517 
519  const KeyTy &key) {
520  return new (allocator.allocate<RuntimeArrayTypeStorage>())
522  }
523 
524  bool operator==(const KeyTy &key) const {
525  return key == KeyTy(elementType, stride);
526  }
527 
529  : elementType(key.first), stride(key.second) {}
530 
532  unsigned stride;
533 };
534 
536  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
537 }
538 
539 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
540  return Base::get(elementType.getContext(), elementType, stride);
541 }
542 
543 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
544 
545 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
546 
549  std::optional<StorageClass> storage) {
550  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
551 }
552 
555  std::optional<StorageClass> storage) {
556  {
557  static const Capability caps[] = {Capability::Shader};
558  ArrayRef<Capability> ref(caps, std::size(caps));
559  capabilities.push_back(ref);
560  }
561  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // ScalarType
566 //===----------------------------------------------------------------------===//
567 
569  if (auto floatType = type.dyn_cast<FloatType>()) {
570  return isValid(floatType);
571  }
572  if (auto intType = type.dyn_cast<IntegerType>()) {
573  return isValid(intType);
574  }
575  return false;
576 }
577 
579  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
580 }
581 
582 bool ScalarType::isValid(IntegerType type) {
583  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
584 }
585 
587  std::optional<StorageClass> storage) {
588  // 8- or 16-bit integer/floating-point numbers will require extra extensions
589  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
590  // SPV_KHR_8bit_storage for more details.
591  if (!storage)
592  return;
593 
594  switch (*storage) {
595  case StorageClass::PushConstant:
596  case StorageClass::StorageBuffer:
597  case StorageClass::Uniform:
598  if (getIntOrFloatBitWidth() == 8) {
599  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
600  ArrayRef<Extension> ref(exts, std::size(exts));
601  extensions.push_back(ref);
602  }
603  [[fallthrough]];
604  case StorageClass::Input:
605  case StorageClass::Output:
606  if (getIntOrFloatBitWidth() == 16) {
607  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
608  ArrayRef<Extension> ref(exts, std::size(exts));
609  extensions.push_back(ref);
610  }
611  break;
612  default:
613  break;
614  }
615 }
616 
619  std::optional<StorageClass> storage) {
620  unsigned bitwidth = getIntOrFloatBitWidth();
621 
622  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
623  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
624  // SPV_KHR_8bit_storage for more details.
625 
626 #define STORAGE_CASE(storage, cap8, cap16) \
627  case StorageClass::storage: { \
628  if (bitwidth == 8) { \
629  static const Capability caps[] = {Capability::cap8}; \
630  ArrayRef<Capability> ref(caps, std::size(caps)); \
631  capabilities.push_back(ref); \
632  return; \
633  } \
634  if (bitwidth == 16) { \
635  static const Capability caps[] = {Capability::cap16}; \
636  ArrayRef<Capability> ref(caps, std::size(caps)); \
637  capabilities.push_back(ref); \
638  return; \
639  } \
640  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
641  /* storage classes. Fall through to the next section. */ \
642  } break
643 
644  // This part only handles the cases where special bitwidths appearing in
645  // interface storage classes.
646  if (storage) {
647  switch (*storage) {
648  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
649  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
650  StorageBuffer16BitAccess);
651  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
652  StorageUniform16);
653  case StorageClass::Input:
654  case StorageClass::Output: {
655  if (bitwidth == 16) {
656  static const Capability caps[] = {Capability::StorageInputOutput16};
657  ArrayRef<Capability> ref(caps, std::size(caps));
658  capabilities.push_back(ref);
659  return;
660  }
661  break;
662  }
663  default:
664  break;
665  }
666  }
667 #undef STORAGE_CASE
668 
669  // For other non-interface storage classes, require a different set of
670  // capabilities for special bitwidths.
671 
672 #define WIDTH_CASE(type, width) \
673  case width: { \
674  static const Capability caps[] = {Capability::type##width}; \
675  ArrayRef<Capability> ref(caps, std::size(caps)); \
676  capabilities.push_back(ref); \
677  } break
678 
679  if (auto intType = dyn_cast<IntegerType>()) {
680  switch (bitwidth) {
681  WIDTH_CASE(Int, 8);
682  WIDTH_CASE(Int, 16);
683  WIDTH_CASE(Int, 64);
684  case 1:
685  case 32:
686  break;
687  default:
688  llvm_unreachable("invalid bitwidth to getCapabilities");
689  }
690  } else {
691  assert(isa<FloatType>());
692  switch (bitwidth) {
693  WIDTH_CASE(Float, 16);
694  WIDTH_CASE(Float, 64);
695  case 32:
696  break;
697  default:
698  llvm_unreachable("invalid bitwidth to getCapabilities");
699  }
700  }
701 
702 #undef WIDTH_CASE
703 }
704 
705 std::optional<int64_t> ScalarType::getSizeInBytes() {
706  auto bitWidth = getIntOrFloatBitWidth();
707  // According to the SPIR-V spec:
708  // "There is no physical size or bit pattern defined for values with boolean
709  // type. If they are stored (in conjunction with OpVariable), they can only
710  // be used with logical addressing operations, not physical, and only with
711  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
712  // Private, Function, Input, and Output."
713  if (bitWidth == 1)
714  return std::nullopt;
715  return bitWidth / 8;
716 }
717 
718 //===----------------------------------------------------------------------===//
719 // SPIRVType
720 //===----------------------------------------------------------------------===//
721 
723  // Allow SPIR-V dialect types
724  if (llvm::isa<SPIRVDialect>(type.getDialect()))
725  return true;
726  if (type.isa<ScalarType>())
727  return true;
728  if (auto vectorType = type.dyn_cast<VectorType>())
729  return CompositeType::isValid(vectorType);
730  return false;
731 }
732 
734  return isIntOrFloat() || isa<VectorType>();
735 }
736 
738  std::optional<StorageClass> storage) {
739  if (auto scalarType = dyn_cast<ScalarType>()) {
740  scalarType.getExtensions(extensions, storage);
741  } else if (auto compositeType = dyn_cast<CompositeType>()) {
742  compositeType.getExtensions(extensions, storage);
743  } else if (auto imageType = dyn_cast<ImageType>()) {
744  imageType.getExtensions(extensions, storage);
745  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
746  sampledImageType.getExtensions(extensions, storage);
747  } else if (auto matrixType = dyn_cast<MatrixType>()) {
748  matrixType.getExtensions(extensions, storage);
749  } else if (auto ptrType = dyn_cast<PointerType>()) {
750  ptrType.getExtensions(extensions, storage);
751  } else {
752  llvm_unreachable("invalid SPIR-V Type to getExtensions");
753  }
754 }
755 
758  std::optional<StorageClass> storage) {
759  if (auto scalarType = dyn_cast<ScalarType>()) {
760  scalarType.getCapabilities(capabilities, storage);
761  } else if (auto compositeType = dyn_cast<CompositeType>()) {
762  compositeType.getCapabilities(capabilities, storage);
763  } else if (auto imageType = dyn_cast<ImageType>()) {
764  imageType.getCapabilities(capabilities, storage);
765  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
766  sampledImageType.getCapabilities(capabilities, storage);
767  } else if (auto matrixType = dyn_cast<MatrixType>()) {
768  matrixType.getCapabilities(capabilities, storage);
769  } else if (auto ptrType = dyn_cast<PointerType>()) {
770  ptrType.getCapabilities(capabilities, storage);
771  } else {
772  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
773  }
774 }
775 
776 std::optional<int64_t> SPIRVType::getSizeInBytes() {
777  if (auto scalarType = dyn_cast<ScalarType>())
778  return scalarType.getSizeInBytes();
779  if (auto compositeType = dyn_cast<CompositeType>())
780  return compositeType.getSizeInBytes();
781  return std::nullopt;
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // SampledImageType
786 //===----------------------------------------------------------------------===//
788  using KeyTy = Type;
789 
790  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
791 
792  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
793 
795  const KeyTy &key) {
796  return new (allocator.allocate<SampledImageTypeStorage>())
798  }
799 
801 };
802 
804  return Base::get(imageType.getContext(), imageType);
805 }
806 
809  Type imageType) {
810  return Base::getChecked(emitError, imageType.getContext(), imageType);
811 }
812 
813 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
814 
817  Type imageType) {
818  if (!imageType.isa<ImageType>())
819  return emitError() << "expected image type";
820 
821  return success();
822 }
823 
826  std::optional<StorageClass> storage) {
827  getImageType().cast<ImageType>().getExtensions(extensions, storage);
828 }
829 
832  std::optional<StorageClass> storage) {
833  getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // StructType
838 //===----------------------------------------------------------------------===//
839 
840 /// Type storage for SPIR-V structure types:
841 ///
842 /// Structures are uniqued using:
843 /// - for identified structs:
844 /// - a string identifier;
845 /// - for literal structs:
846 /// - a list of member types;
847 /// - a list of member offset info;
848 /// - a list of member decoration info.
849 ///
850 /// Identified structures only have a mutable component consisting of:
851 /// - a list of member types;
852 /// - a list of member offset info;
853 /// - a list of member decoration info.
855  /// Construct a storage object for an identified struct type. A struct type
856  /// associated with such storage must call StructType::trySetBody(...) later
857  /// in order to mutate the storage object providing the actual content.
858  StructTypeStorage(StringRef identifier)
859  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
860  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
861  identifier(identifier) {}
862 
863  /// Construct a storage object for a literal struct type. A struct type
864  /// associated with such storage is immutable.
866  unsigned numMembers, Type const *memberTypes,
867  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
868  StructType::MemberDecorationInfo const *memberDecorationsInfo)
869  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
870  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
871  memberDecorationsInfo(memberDecorationsInfo) {}
872 
873  /// A storage key is divided into 2 parts:
874  /// - for identified structs:
875  /// - a StringRef representing the struct identifier;
876  /// - for literal structs:
877  /// - an ArrayRef<Type> for member types;
878  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
879  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
880  /// info.
881  ///
882  /// An identified struct type is uniqued only by the first part (field 0)
883  /// of the key.
884  ///
885  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
886  /// 3) of the key. The identifier field (field 0) must be empty.
887  using KeyTy =
888  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
890 
891  /// For identified structs, return true if the given key contains the same
892  /// identifier.
893  ///
894  /// For literal structs, return true if the given key contains a matching list
895  /// of member types + offset info + decoration info.
896  bool operator==(const KeyTy &key) const {
897  if (isIdentified()) {
898  // Identified types are uniqued by their identifier.
899  return getIdentifier() == std::get<0>(key);
900  }
901 
902  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
903  getMemberDecorationsInfo());
904  }
905 
906  /// If the given key contains a non-empty identifier, this method constructs
907  /// an identified struct and leaves the rest of the struct type data to be set
908  /// through a later call to StructType::trySetBody(...).
909  ///
910  /// If, on the other hand, the key contains an empty identifier, a literal
911  /// struct is constructed using the other fields of the key.
913  const KeyTy &key) {
914  StringRef keyIdentifier = std::get<0>(key);
915 
916  if (!keyIdentifier.empty()) {
917  StringRef identifier = allocator.copyInto(keyIdentifier);
918 
919  // Identified StructType body/members will be set through trySetBody(...)
920  // later.
921  return new (allocator.allocate<StructTypeStorage>())
922  StructTypeStorage(identifier);
923  }
924 
925  ArrayRef<Type> keyTypes = std::get<1>(key);
926 
927  // Copy the member type and layout information into the bump pointer
928  const Type *typesList = nullptr;
929  if (!keyTypes.empty()) {
930  typesList = allocator.copyInto(keyTypes).data();
931  }
932 
933  const StructType::OffsetInfo *offsetInfoList = nullptr;
934  if (!std::get<2>(key).empty()) {
935  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
936  assert(keyOffsetInfo.size() == keyTypes.size() &&
937  "size of offset information must be same as the size of number of "
938  "elements");
939  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
940  }
941 
942  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
943  unsigned numMemberDecorations = 0;
944  if (!std::get<3>(key).empty()) {
945  auto keyMemberDecorations = std::get<3>(key);
946  numMemberDecorations = keyMemberDecorations.size();
947  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
948  }
949 
950  return new (allocator.allocate<StructTypeStorage>())
951  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
952  numMemberDecorations, memberDecorationList);
953  }
954 
956  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
957  }
958 
960  if (offsetInfo) {
961  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
962  }
963  return {};
964  }
965 
967  if (memberDecorationsInfo) {
968  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
969  numMemberDecorations);
970  }
971  return {};
972  }
973 
974  StringRef getIdentifier() const { return identifier; }
975 
976  bool isIdentified() const { return !identifier.empty(); }
977 
978  /// Sets the struct type content for identified structs. Calling this method
979  /// is only valid for identified structs.
980  ///
981  /// Fails under the following conditions:
982  /// - If called for a literal struct;
983  /// - If called for an identified struct whose body was set before (through a
984  /// call to this method) but with different contents from the passed
985  /// arguments.
987  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
988  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
989  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
990  if (!isIdentified())
991  return failure();
992 
993  if (memberTypesAndIsBodySet.getInt() &&
994  (getMemberTypes() != structMemberTypes ||
995  getOffsetInfo() != structOffsetInfo ||
996  getMemberDecorationsInfo() != structMemberDecorationInfo))
997  return failure();
998 
999  memberTypesAndIsBodySet.setInt(true);
1000  numMembers = structMemberTypes.size();
1001 
1002  // Copy the member type and layout information into the bump pointer.
1003  if (!structMemberTypes.empty())
1004  memberTypesAndIsBodySet.setPointer(
1005  allocator.copyInto(structMemberTypes).data());
1006 
1007  if (!structOffsetInfo.empty()) {
1008  assert(structOffsetInfo.size() == structMemberTypes.size() &&
1009  "size of offset information must be same as the size of number of "
1010  "elements");
1011  offsetInfo = allocator.copyInto(structOffsetInfo).data();
1012  }
1013 
1014  if (!structMemberDecorationInfo.empty()) {
1015  numMemberDecorations = structMemberDecorationInfo.size();
1016  memberDecorationsInfo =
1017  allocator.copyInto(structMemberDecorationInfo).data();
1018  }
1019 
1020  return success();
1021  }
1022 
1023  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1025  unsigned numMembers;
1028  StringRef identifier;
1029 };
1030 
1031 StructType
1034  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
1035  assert(!memberTypes.empty() && "Struct needs at least one member type");
1036  // Sort the decorations.
1038  memberDecorations.begin(), memberDecorations.end());
1039  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1040  return Base::get(memberTypes.vec().front().getContext(),
1041  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1042  sortedDecorations);
1043 }
1044 
1046  StringRef identifier) {
1047  assert(!identifier.empty() &&
1048  "StructType identifier must be non-empty string");
1049 
1050  return Base::get(context, identifier, ArrayRef<Type>(),
1053 }
1054 
1055 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1056  StructType newStructType = Base::get(
1057  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1059  // Set an empty body in case this is a identified struct.
1060  if (newStructType.isIdentified() &&
1061  failed(newStructType.trySetBody(
1064  return StructType();
1065 
1066  return newStructType;
1067 }
1068 
1069 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1070 
1071 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1072 
1073 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1074 
1075 Type StructType::getElementType(unsigned index) const {
1076  assert(getNumElements() > index && "member index out of range");
1077  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1078 }
1079 
1081  return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1082  getNumElements());
1083 }
1084 
1085 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1086 
1087 uint64_t StructType::getMemberOffset(unsigned index) const {
1088  assert(getNumElements() > index && "member index out of range");
1089  return getImpl()->offsetInfo[index];
1090 }
1091 
1094  const {
1095  memberDecorations.clear();
1096  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1097  memberDecorations.append(implMemberDecorations.begin(),
1098  implMemberDecorations.end());
1099 }
1100 
1102  unsigned index,
1103  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1104  assert(getNumElements() > index && "member index out of range");
1105  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1106  decorationsInfo.clear();
1107  for (const auto &memberDecoration : memberDecorations) {
1108  if (memberDecoration.memberIndex == index) {
1109  decorationsInfo.push_back(memberDecoration);
1110  }
1111  if (memberDecoration.memberIndex > index) {
1112  // Early exit since the decorations are stored sorted.
1113  return;
1114  }
1115  }
1116 }
1117 
1120  ArrayRef<OffsetInfo> offsetInfo,
1121  ArrayRef<MemberDecorationInfo> memberDecorations) {
1122  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1123 }
1124 
1126  std::optional<StorageClass> storage) {
1127  for (Type elementType : getElementTypes())
1128  elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1129 }
1130 
1133  std::optional<StorageClass> storage) {
1134  for (Type elementType : getElementTypes())
1135  elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1136 }
1137 
1138 llvm::hash_code spirv::hash_value(
1139  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1140  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1141  memberDecorationInfo.decoration);
1142 }
1143 
1144 //===----------------------------------------------------------------------===//
1145 // MatrixType
1146 //===----------------------------------------------------------------------===//
1147 
1149  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1150  : columnType(columnType), columnCount(columnCount) {}
1151 
1152  using KeyTy = std::tuple<Type, uint32_t>;
1153 
1155  const KeyTy &key) {
1156 
1157  // Initialize the memory using placement new.
1158  return new (allocator.allocate<MatrixTypeStorage>())
1159  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1160  }
1161 
1162  bool operator==(const KeyTy &key) const {
1163  return key == KeyTy(columnType, columnCount);
1164  }
1165 
1167  const uint32_t columnCount;
1168 };
1169 
1170 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1171  return Base::get(columnType.getContext(), columnType, columnCount);
1172 }
1173 
1175  Type columnType, uint32_t columnCount) {
1176  return Base::getChecked(emitError, columnType.getContext(), columnType,
1177  columnCount);
1178 }
1179 
1181  Type columnType, uint32_t columnCount) {
1182  if (columnCount < 2 || columnCount > 4)
1183  return emitError() << "matrix can have 2, 3, or 4 columns only";
1184 
1185  if (!isValidColumnType(columnType))
1186  return emitError() << "matrix columns must be vectors of floats";
1187 
1188  /// The underlying vectors (columns) must be of size 2, 3, or 4
1189  ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1190  if (columnShape.size() != 1)
1191  return emitError() << "matrix columns must be 1D vectors";
1192 
1193  if (columnShape[0] < 2 || columnShape[0] > 4)
1194  return emitError() << "matrix columns must be of size 2, 3, or 4";
1195 
1196  return success();
1197 }
1198 
1199 /// Returns true if the matrix elements are vectors of float elements
1201  if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1202  if (vectorType.getElementType().isa<FloatType>())
1203  return true;
1204  }
1205  return false;
1206 }
1207 
1208 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1209 
1211  return getImpl()->columnType.cast<VectorType>().getElementType();
1212 }
1213 
1214 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1215 
1216 unsigned MatrixType::getNumRows() const {
1217  return getImpl()->columnType.cast<VectorType>().getShape()[0];
1218 }
1219 
1220 unsigned MatrixType::getNumElements() const {
1221  return (getImpl()->columnCount) * getNumRows();
1222 }
1223 
1225  std::optional<StorageClass> storage) {
1226  getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1227 }
1228 
1231  std::optional<StorageClass> storage) {
1232  {
1233  static const Capability caps[] = {Capability::Matrix};
1234  ArrayRef<Capability> ref(caps, std::size(caps));
1235  capabilities.push_back(ref);
1236  }
1237  // Add any capabilities associated with the underlying vectors (i.e., columns)
1238  getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1239 }
1240 
1241 //===----------------------------------------------------------------------===//
1242 // SPIR-V Dialect
1243 //===----------------------------------------------------------------------===//
1244 
1245 void SPIRVDialect::registerTypes() {
1248  StructType>();
1249 }
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:368
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:374
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:340
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:356
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:362
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:344
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:350
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
Base storage class appearing in a Type.
Definition: TypeSupport.h:153
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
U dyn_cast() const
Definition: Types.h:311
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:108
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:112
bool isBF16() const
Definition: Types.cpp:42
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:65
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:67
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:63
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:51
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:80
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:69
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:74
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:169
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:189
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:150
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:155
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:126
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:100
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:114
static bool classof(Type type)
Definition: SPIRVTypes.cpp:92
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:243
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:245
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:231
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:258
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:241
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:249
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:164
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:421
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:423
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:435
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:431
Type getElementType() const
Definition: SPIRVTypes.cpp:417
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:427
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:442
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:437
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:307
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:326
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:317
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:311
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:296
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:309
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Definition: SPIRVTypes.cpp:313
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:488
Type getPointeeType() const
Definition: SPIRVTypes.cpp:482
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:484
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:499
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:478
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:553
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:545
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:547
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:535
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:776
static bool classof(Type type)
Definition: SPIRVTypes.cpp:722
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:756
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:737
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:816
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:830
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:808
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:824
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:803
static bool classof(Type type)
Definition: SPIRVTypes.cpp:568
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:586
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:578
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:617
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:705
Range class for element types.
Definition: SPIRVTypes.h:351
SPIR-V struct type.
Definition: SPIRVTypes.h:282
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
ElementTypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Hashes a struct member decoration by combining its member index and decoration.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:31
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:33
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:38
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:209
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:212
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:390
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:385
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:383
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
Definition: SPIRVTypes.cpp:272
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:274
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:280
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:467
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:459
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:461
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:518
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:524
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:792
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:794
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:854
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:959
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:865
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:896
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:986
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:912
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:955
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:858
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:966
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:889
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet