MLIR  18.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <cstdint>
22 #include <iterator>
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 //===----------------------------------------------------------------------===//
28 // ArrayType
29 //===----------------------------------------------------------------------===//
30 
32  using KeyTy = std::tuple<Type, unsigned, unsigned>;
33 
35  const KeyTy &key) {
36  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37  }
38 
39  bool operator==(const KeyTy &key) const {
40  return key == KeyTy(elementType, elementCount, stride);
41  }
42 
43  ArrayTypeStorage(const KeyTy &key)
44  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45  stride(std::get<2>(key)) {}
46 
48  unsigned elementCount;
49  unsigned stride;
50 };
51 
52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53  assert(elementCount && "ArrayType needs at least one element");
54  return Base::get(elementType.getContext(), elementType, elementCount,
55  /*stride=*/0);
56 }
57 
58 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59  unsigned stride) {
60  assert(elementCount && "ArrayType needs at least one element");
61  return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63 
64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65 
66 Type ArrayType::getElementType() const { return getImpl()->elementType; }
67 
68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69 
71  std::optional<StorageClass> storage) {
72  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
73 }
74 
77  std::optional<StorageClass> storage) {
78  llvm::cast<SPIRVType>(getElementType())
79  .getCapabilities(capabilities, storage);
80 }
81 
82 std::optional<int64_t> ArrayType::getSizeInBytes() {
83  auto elementType = llvm::cast<SPIRVType>(getElementType());
84  std::optional<int64_t> size = elementType.getSizeInBytes();
85  if (!size)
86  return std::nullopt;
87  return (*size + getArrayStride()) * getNumElements();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // CompositeType
92 //===----------------------------------------------------------------------===//
93 
95  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96  return isValid(vectorType);
100  spirv::StructType>(type);
101 }
102 
103 bool CompositeType::isValid(VectorType type) {
104  return type.getRank() == 1 &&
105  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
106  llvm::isa<ScalarType>(type.getElementType());
107 }
108 
109 Type CompositeType::getElementType(unsigned index) const {
110  return TypeSwitch<Type, Type>(*this)
113  [](auto type) { return type.getElementType(); })
114  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
115  .Case<StructType>(
116  [index](StructType type) { return type.getElementType(index); })
117  .Default(
118  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
119 }
120 
122  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
123  return arrayType.getNumElements();
124  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
125  return matrixType.getNumColumns();
126  if (auto structType = llvm::dyn_cast<StructType>(*this))
127  return structType.getNumElements();
128  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
129  return vectorType.getNumElements();
130  if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*this)) {
131  llvm_unreachable(
132  "invalid to query number of elements of spirv Cooperative Matrix type");
133  }
134  if (llvm::isa<JointMatrixINTELType>(*this)) {
135  llvm_unreachable(
136  "invalid to query number of elements of spirv::JointMatrix type");
137  }
138  if (llvm::isa<RuntimeArrayType>(*this)) {
139  llvm_unreachable(
140  "invalid to query number of elements of spirv::RuntimeArray type");
141  }
142  llvm_unreachable("invalid composite type");
143 }
144 
148 }
149 
152  std::optional<StorageClass> storage) {
153  TypeSwitch<Type>(*this)
156  [&](auto type) { type.getExtensions(extensions, storage); })
157  .Case<VectorType>([&](VectorType type) {
158  return llvm::cast<ScalarType>(type.getElementType())
159  .getExtensions(extensions, storage);
160  })
161  .Default([](Type) { llvm_unreachable("invalid composite type"); });
162 }
163 
166  std::optional<StorageClass> storage) {
167  TypeSwitch<Type>(*this)
170  [&](auto type) { type.getCapabilities(capabilities, storage); })
171  .Case<VectorType>([&](VectorType type) {
172  auto vecSize = getNumElements();
173  if (vecSize == 8 || vecSize == 16) {
174  static const Capability caps[] = {Capability::Vector16};
175  ArrayRef<Capability> ref(caps, std::size(caps));
176  capabilities.push_back(ref);
177  }
178  return llvm::cast<ScalarType>(type.getElementType())
179  .getCapabilities(capabilities, storage);
180  })
181  .Default([](Type) { llvm_unreachable("invalid composite type"); });
182 }
183 
184 std::optional<int64_t> CompositeType::getSizeInBytes() {
185  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
186  return arrayType.getSizeInBytes();
187  if (auto structType = llvm::dyn_cast<StructType>(*this))
188  return structType.getSizeInBytes();
189  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
190  std::optional<int64_t> elementSize =
191  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
192  if (!elementSize)
193  return std::nullopt;
194  return *elementSize * vectorType.getNumElements();
195  }
196  return std::nullopt;
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // CooperativeMatrixType
201 //===----------------------------------------------------------------------===//
202 
204  using KeyTy =
205  std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
206 
208  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
209  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
211  }
212 
213  bool operator==(const KeyTy &key) const {
214  return key == KeyTy(elementType, rows, columns, scope, use);
215  }
216 
218  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
219  columns(std::get<2>(key)), scope(std::get<3>(key)),
220  use(std::get<4>(key)) {}
221 
223  uint32_t rows;
224  uint32_t columns;
225  Scope scope;
226  CooperativeMatrixUseKHR use;
227 };
228 
230  uint32_t rows,
231  uint32_t columns, Scope scope,
232  CooperativeMatrixUseKHR use) {
233  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
234  use);
235 }
236 
238  return getImpl()->elementType;
239 }
240 
241 uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
242 
244  return getImpl()->columns;
245 }
246 
247 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
248 
249 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
250  return getImpl()->use;
251 }
252 
255  std::optional<StorageClass> storage) {
256  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
257  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
258  extensions.push_back(exts);
259 }
260 
263  std::optional<StorageClass> storage) {
264  llvm::cast<SPIRVType>(getElementType())
265  .getCapabilities(capabilities, storage);
266  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
267  capabilities.push_back(caps);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // CooperativeMatrixNVType
272 //===----------------------------------------------------------------------===//
273 
275  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
276 
278  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
279  return new (allocator.allocate<CooperativeMatrixNVTypeStorage>())
281  }
282 
283  bool operator==(const KeyTy &key) const {
284  return key == KeyTy(elementType, scope, rows, columns);
285  }
286 
288  : elementType(std::get<0>(key)), rows(std::get<2>(key)),
289  columns(std::get<3>(key)), scope(std::get<1>(key)) {}
290 
292  unsigned rows;
293  unsigned columns;
294  Scope scope;
295 };
296 
298  Scope scope, unsigned rows,
299  unsigned columns) {
300  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
301 }
302 
304  return getImpl()->elementType;
305 }
306 
307 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
308 
309 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
310 
312  return getImpl()->columns;
313 }
314 
317  std::optional<StorageClass> storage) {
318  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
319  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
320  ArrayRef<Extension> ref(exts, std::size(exts));
321  extensions.push_back(ref);
322 }
323 
326  std::optional<StorageClass> storage) {
327  llvm::cast<SPIRVType>(getElementType())
328  .getCapabilities(capabilities, storage);
329  static const Capability caps[] = {Capability::CooperativeMatrixNV};
330  ArrayRef<Capability> ref(caps, std::size(caps));
331  capabilities.push_back(ref);
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // JointMatrixType
336 //===----------------------------------------------------------------------===//
337 
339  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
340 
342  const KeyTy &key) {
343  return new (allocator.allocate<JointMatrixTypeStorage>())
345  }
346 
347  bool operator==(const KeyTy &key) const {
348  return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
349  }
350 
352  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
353  columns(std::get<2>(key)), scope(std::get<4>(key)),
354  matrixLayout(std::get<3>(key)) {}
355 
357  unsigned rows;
358  unsigned columns;
359  Scope scope;
360  MatrixLayout matrixLayout;
361 };
362 
364  unsigned rows, unsigned columns,
365  MatrixLayout matrixLayout) {
366  return Base::get(elementType.getContext(), elementType, rows, columns,
367  matrixLayout, scope);
368 }
369 
371  return getImpl()->elementType;
372 }
373 
374 Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
375 
376 unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
377 
378 unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
379 
381  return getImpl()->matrixLayout;
382 }
383 
386  std::optional<StorageClass> storage) {
387  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
388  static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
389  ArrayRef<Extension> ref(exts, std::size(exts));
390  extensions.push_back(ref);
391 }
392 
395  std::optional<StorageClass> storage) {
396  llvm::cast<SPIRVType>(getElementType())
397  .getCapabilities(capabilities, storage);
398  static const Capability caps[] = {Capability::JointMatrixINTEL};
399  ArrayRef<Capability> ref(caps, std::size(caps));
400  capabilities.push_back(ref);
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // ImageType
405 //===----------------------------------------------------------------------===//
406 
/// Generic fallback for the per-enum bit-count helpers that follow. Any type
/// without a dedicated specialization reports zero bits.
/// NOTE(review): presumably only the specialized enums are ever queried.
template <typename T>
static constexpr unsigned getNumBits() {
  constexpr unsigned kFallbackBits = 0;
  return kFallbackBits;
}
411 template <>
412 constexpr unsigned getNumBits<Dim>() {
413  static_assert((1 << 3) > getMaxEnumValForDim(),
414  "Not enough bits to encode Dim value");
415  return 3;
416 }
417 template <>
418 constexpr unsigned getNumBits<ImageDepthInfo>() {
419  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
420  "Not enough bits to encode ImageDepthInfo value");
421  return 2;
422 }
423 template <>
424 constexpr unsigned getNumBits<ImageArrayedInfo>() {
425  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
426  "Not enough bits to encode ImageArrayedInfo value");
427  return 1;
428 }
429 template <>
430 constexpr unsigned getNumBits<ImageSamplingInfo>() {
431  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
432  "Not enough bits to encode ImageSamplingInfo value");
433  return 1;
434 }
435 template <>
436 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
437  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
438  "Not enough bits to encode ImageSamplerUseInfo value");
439  return 2;
440 }
441 template <>
442 constexpr unsigned getNumBits<ImageFormat>() {
443  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
444  "Not enough bits to encode ImageFormat value");
445  return 6;
446 }
447 
449 public:
450  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
451  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
452 
454  const KeyTy &key) {
455  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
456  }
457 
458  bool operator==(const KeyTy &key) const {
459  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
460  samplerUseInfo, format);
461  }
462 
464  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
465  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
466  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
467  format(std::get<6>(key)) {}
468 
476 };
477 
478 ImageType
479 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
480  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
481  value) {
482  return Base::get(std::get<0>(value).getContext(), value);
483 }
484 
485 Type ImageType::getElementType() const { return getImpl()->elementType; }
486 
487 Dim ImageType::getDim() const { return getImpl()->dim; }
488 
489 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
490 
491 ImageArrayedInfo ImageType::getArrayedInfo() const {
492  return getImpl()->arrayedInfo;
493 }
494 
495 ImageSamplingInfo ImageType::getSamplingInfo() const {
496  return getImpl()->samplingInfo;
497 }
498 
499 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
500  return getImpl()->samplerUseInfo;
501 }
502 
503 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
504 
506  std::optional<StorageClass>) {
507  // Image types do not require extra extensions thus far.
508 }
509 
512  std::optional<StorageClass>) {
513  if (auto dimCaps = spirv::getCapabilities(getDim()))
514  capabilities.push_back(*dimCaps);
515 
516  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
517  capabilities.push_back(*fmtCaps);
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // PointerType
522 //===----------------------------------------------------------------------===//
523 
525  // (Type, StorageClass) as the key: Type stored in this struct, and
526  // StorageClass stored as TypeStorage's subclass data.
527  using KeyTy = std::pair<Type, StorageClass>;
528 
530  const KeyTy &key) {
531  return new (allocator.allocate<PointerTypeStorage>())
532  PointerTypeStorage(key);
533  }
534 
535  bool operator==(const KeyTy &key) const {
536  return key == KeyTy(pointeeType, storageClass);
537  }
538 
540  : pointeeType(key.first), storageClass(key.second) {}
541 
543  StorageClass storageClass;
544 };
545 
546 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
547  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
548 }
549 
550 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
551 
552 StorageClass PointerType::getStorageClass() const {
553  return getImpl()->storageClass;
554 }
555 
557  std::optional<StorageClass> storage) {
558  // Use this pointer type's storage class because this pointer indicates we are
559  // using the pointee type in that specific storage class.
560  llvm::cast<SPIRVType>(getPointeeType())
561  .getExtensions(extensions, getStorageClass());
562 
563  if (auto scExts = spirv::getExtensions(getStorageClass()))
564  extensions.push_back(*scExts);
565 }
566 
569  std::optional<StorageClass> storage) {
570  // Use this pointer type's storage class because this pointer indicates we are
571  // using the pointee type in that specific storage class.
572  llvm::cast<SPIRVType>(getPointeeType())
573  .getCapabilities(capabilities, getStorageClass());
574 
575  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
576  capabilities.push_back(*scCaps);
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // RuntimeArrayType
581 //===----------------------------------------------------------------------===//
582 
584  using KeyTy = std::pair<Type, unsigned>;
585 
587  const KeyTy &key) {
588  return new (allocator.allocate<RuntimeArrayTypeStorage>())
590  }
591 
592  bool operator==(const KeyTy &key) const {
593  return key == KeyTy(elementType, stride);
594  }
595 
597  : elementType(key.first), stride(key.second) {}
598 
600  unsigned stride;
601 };
602 
604  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
605 }
606 
607 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
608  return Base::get(elementType.getContext(), elementType, stride);
609 }
610 
611 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
612 
613 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
614 
617  std::optional<StorageClass> storage) {
618  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
619 }
620 
623  std::optional<StorageClass> storage) {
624  {
625  static const Capability caps[] = {Capability::Shader};
626  ArrayRef<Capability> ref(caps, std::size(caps));
627  capabilities.push_back(ref);
628  }
629  llvm::cast<SPIRVType>(getElementType())
630  .getCapabilities(capabilities, storage);
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // ScalarType
635 //===----------------------------------------------------------------------===//
636 
638  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
639  return isValid(floatType);
640  }
641  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
642  return isValid(intType);
643  }
644  return false;
645 }
646 
648  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
649 }
650 
651 bool ScalarType::isValid(IntegerType type) {
652  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
653 }
654 
656  std::optional<StorageClass> storage) {
657  // 8- or 16-bit integer/floating-point numbers will require extra extensions
658  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
659  // SPV_KHR_8bit_storage for more details.
660  if (!storage)
661  return;
662 
663  switch (*storage) {
664  case StorageClass::PushConstant:
665  case StorageClass::StorageBuffer:
666  case StorageClass::Uniform:
667  if (getIntOrFloatBitWidth() == 8) {
668  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
669  ArrayRef<Extension> ref(exts, std::size(exts));
670  extensions.push_back(ref);
671  }
672  [[fallthrough]];
673  case StorageClass::Input:
674  case StorageClass::Output:
675  if (getIntOrFloatBitWidth() == 16) {
676  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
677  ArrayRef<Extension> ref(exts, std::size(exts));
678  extensions.push_back(ref);
679  }
680  break;
681  default:
682  break;
683  }
684 }
685 
688  std::optional<StorageClass> storage) {
689  unsigned bitwidth = getIntOrFloatBitWidth();
690 
691  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
692  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
693  // SPV_KHR_8bit_storage for more details.
694 
695 #define STORAGE_CASE(storage, cap8, cap16) \
696  case StorageClass::storage: { \
697  if (bitwidth == 8) { \
698  static const Capability caps[] = {Capability::cap8}; \
699  ArrayRef<Capability> ref(caps, std::size(caps)); \
700  capabilities.push_back(ref); \
701  return; \
702  } \
703  if (bitwidth == 16) { \
704  static const Capability caps[] = {Capability::cap16}; \
705  ArrayRef<Capability> ref(caps, std::size(caps)); \
706  capabilities.push_back(ref); \
707  return; \
708  } \
709  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
710  /* storage classes. Fall through to the next section. */ \
711  } break
712 
713  // This part only handles the cases where special bitwidths appearing in
714  // interface storage classes.
715  if (storage) {
716  switch (*storage) {
717  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
718  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
719  StorageBuffer16BitAccess);
720  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
721  StorageUniform16);
722  case StorageClass::Input:
723  case StorageClass::Output: {
724  if (bitwidth == 16) {
725  static const Capability caps[] = {Capability::StorageInputOutput16};
726  ArrayRef<Capability> ref(caps, std::size(caps));
727  capabilities.push_back(ref);
728  return;
729  }
730  break;
731  }
732  default:
733  break;
734  }
735  }
736 #undef STORAGE_CASE
737 
738  // For other non-interface storage classes, require a different set of
739  // capabilities for special bitwidths.
740 
741 #define WIDTH_CASE(type, width) \
742  case width: { \
743  static const Capability caps[] = {Capability::type##width}; \
744  ArrayRef<Capability> ref(caps, std::size(caps)); \
745  capabilities.push_back(ref); \
746  } break
747 
748  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
749  switch (bitwidth) {
750  WIDTH_CASE(Int, 8);
751  WIDTH_CASE(Int, 16);
752  WIDTH_CASE(Int, 64);
753  case 1:
754  case 32:
755  break;
756  default:
757  llvm_unreachable("invalid bitwidth to getCapabilities");
758  }
759  } else {
760  assert(llvm::isa<FloatType>(*this));
761  switch (bitwidth) {
762  WIDTH_CASE(Float, 16);
763  WIDTH_CASE(Float, 64);
764  case 32:
765  break;
766  default:
767  llvm_unreachable("invalid bitwidth to getCapabilities");
768  }
769  }
770 
771 #undef WIDTH_CASE
772 }
773 
774 std::optional<int64_t> ScalarType::getSizeInBytes() {
775  auto bitWidth = getIntOrFloatBitWidth();
776  // According to the SPIR-V spec:
777  // "There is no physical size or bit pattern defined for values with boolean
778  // type. If they are stored (in conjunction with OpVariable), they can only
779  // be used with logical addressing operations, not physical, and only with
780  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
781  // Private, Function, Input, and Output."
782  if (bitWidth == 1)
783  return std::nullopt;
784  return bitWidth / 8;
785 }
786 
787 //===----------------------------------------------------------------------===//
788 // SPIRVType
789 //===----------------------------------------------------------------------===//
790 
792  // Allow SPIR-V dialect types
793  if (llvm::isa<SPIRVDialect>(type.getDialect()))
794  return true;
795  if (llvm::isa<ScalarType>(type))
796  return true;
797  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
798  return CompositeType::isValid(vectorType);
799  return false;
800 }
801 
803  return isIntOrFloat() || llvm::isa<VectorType>(*this);
804 }
805 
807  std::optional<StorageClass> storage) {
808  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
809  scalarType.getExtensions(extensions, storage);
810  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
811  compositeType.getExtensions(extensions, storage);
812  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
813  imageType.getExtensions(extensions, storage);
814  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
815  sampledImageType.getExtensions(extensions, storage);
816  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
817  matrixType.getExtensions(extensions, storage);
818  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
819  ptrType.getExtensions(extensions, storage);
820  } else {
821  llvm_unreachable("invalid SPIR-V Type to getExtensions");
822  }
823 }
824 
827  std::optional<StorageClass> storage) {
828  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
829  scalarType.getCapabilities(capabilities, storage);
830  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
831  compositeType.getCapabilities(capabilities, storage);
832  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
833  imageType.getCapabilities(capabilities, storage);
834  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
835  sampledImageType.getCapabilities(capabilities, storage);
836  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
837  matrixType.getCapabilities(capabilities, storage);
838  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
839  ptrType.getCapabilities(capabilities, storage);
840  } else {
841  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
842  }
843 }
844 
845 std::optional<int64_t> SPIRVType::getSizeInBytes() {
846  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
847  return scalarType.getSizeInBytes();
848  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
849  return compositeType.getSizeInBytes();
850  return std::nullopt;
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // SampledImageType
855 //===----------------------------------------------------------------------===//
857  using KeyTy = Type;
858 
859  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
860 
861  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
862 
864  const KeyTy &key) {
865  return new (allocator.allocate<SampledImageTypeStorage>())
867  }
868 
870 };
871 
873  return Base::get(imageType.getContext(), imageType);
874 }
875 
878  Type imageType) {
879  return Base::getChecked(emitError, imageType.getContext(), imageType);
880 }
881 
882 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
883 
886  Type imageType) {
887  if (!llvm::isa<ImageType>(imageType))
888  return emitError() << "expected image type";
889 
890  return success();
891 }
892 
895  std::optional<StorageClass> storage) {
896  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
897 }
898 
901  std::optional<StorageClass> storage) {
902  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
903 }
904 
905 //===----------------------------------------------------------------------===//
906 // StructType
907 //===----------------------------------------------------------------------===//
908 
909 /// Type storage for SPIR-V structure types:
910 ///
911 /// Structures are uniqued using:
912 /// - for identified structs:
913 /// - a string identifier;
914 /// - for literal structs:
915 /// - a list of member types;
916 /// - a list of member offset info;
917 /// - a list of member decoration info.
918 ///
919 /// Identified structures only have a mutable component consisting of:
920 /// - a list of member types;
921 /// - a list of member offset info;
922 /// - a list of member decoration info.
924  /// Construct a storage object for an identified struct type. A struct type
925  /// associated with such storage must call StructType::trySetBody(...) later
926  /// in order to mutate the storage object providing the actual content.
927  StructTypeStorage(StringRef identifier)
928  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
929  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
930  identifier(identifier) {}
931 
932  /// Construct a storage object for a literal struct type. A struct type
933  /// associated with such storage is immutable.
935  unsigned numMembers, Type const *memberTypes,
936  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
937  StructType::MemberDecorationInfo const *memberDecorationsInfo)
938  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
939  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
940  memberDecorationsInfo(memberDecorationsInfo) {}
941 
942  /// A storage key is divided into 2 parts:
943  /// - for identified structs:
944  /// - a StringRef representing the struct identifier;
945  /// - for literal structs:
946  /// - an ArrayRef<Type> for member types;
947  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
948  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
949  /// info.
950  ///
951  /// An identified struct type is uniqued only by the first part (field 0)
952  /// of the key.
953  ///
954  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
955  /// 3) of the key. The identifier field (field 0) must be empty.
956  using KeyTy =
957  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
959 
960  /// For identified structs, return true if the given key contains the same
961  /// identifier.
962  ///
963  /// For literal structs, return true if the given key contains a matching list
964  /// of member types + offset info + decoration info.
965  bool operator==(const KeyTy &key) const {
966  if (isIdentified()) {
967  // Identified types are uniqued by their identifier.
968  return getIdentifier() == std::get<0>(key);
969  }
970 
971  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
972  getMemberDecorationsInfo());
973  }
974 
975  /// If the given key contains a non-empty identifier, this method constructs
976  /// an identified struct and leaves the rest of the struct type data to be set
977  /// through a later call to StructType::trySetBody(...).
978  ///
979  /// If, on the other hand, the key contains an empty identifier, a literal
980  /// struct is constructed using the other fields of the key.
982  const KeyTy &key) {
983  StringRef keyIdentifier = std::get<0>(key);
984 
985  if (!keyIdentifier.empty()) {
986  StringRef identifier = allocator.copyInto(keyIdentifier);
987 
988  // Identified StructType body/members will be set through trySetBody(...)
989  // later.
990  return new (allocator.allocate<StructTypeStorage>())
991  StructTypeStorage(identifier);
992  }
993 
994  ArrayRef<Type> keyTypes = std::get<1>(key);
995 
996  // Copy the member type and layout information into the bump pointer
997  const Type *typesList = nullptr;
998  if (!keyTypes.empty()) {
999  typesList = allocator.copyInto(keyTypes).data();
1000  }
1001 
1002  const StructType::OffsetInfo *offsetInfoList = nullptr;
1003  if (!std::get<2>(key).empty()) {
1004  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
1005  assert(keyOffsetInfo.size() == keyTypes.size() &&
1006  "size of offset information must be same as the size of number of "
1007  "elements");
1008  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
1009  }
1010 
1011  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
1012  unsigned numMemberDecorations = 0;
1013  if (!std::get<3>(key).empty()) {
1014  auto keyMemberDecorations = std::get<3>(key);
1015  numMemberDecorations = keyMemberDecorations.size();
1016  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
1017  }
1018 
1019  return new (allocator.allocate<StructTypeStorage>())
1020  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
1021  numMemberDecorations, memberDecorationList);
1022  }
1023 
1025  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
1026  }
1027 
1029  if (offsetInfo) {
1030  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
1031  }
1032  return {};
1033  }
1034 
1036  if (memberDecorationsInfo) {
1037  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
1038  numMemberDecorations);
1039  }
1040  return {};
1041  }
1042 
1043  StringRef getIdentifier() const { return identifier; }
1044 
1045  bool isIdentified() const { return !identifier.empty(); }
1046 
1047  /// Sets the struct type content for identified structs. Calling this method
1048  /// is only valid for identified structs.
1049  ///
1050  /// Fails under the following conditions:
1051  /// - If called for a literal struct;
1052  /// - If called for an identified struct whose body was set before (through a
1053  /// call to this method) but with different contents from the passed
1054  /// arguments.
1056  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
1057  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
1058  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
1059  if (!isIdentified())
1060  return failure();
1061 
1062  if (memberTypesAndIsBodySet.getInt() &&
1063  (getMemberTypes() != structMemberTypes ||
1064  getOffsetInfo() != structOffsetInfo ||
1065  getMemberDecorationsInfo() != structMemberDecorationInfo))
1066  return failure();
1067 
1068  memberTypesAndIsBodySet.setInt(true);
1069  numMembers = structMemberTypes.size();
1070 
1071  // Copy the member type and layout information into the bump pointer.
1072  if (!structMemberTypes.empty())
1073  memberTypesAndIsBodySet.setPointer(
1074  allocator.copyInto(structMemberTypes).data());
1075 
1076  if (!structOffsetInfo.empty()) {
1077  assert(structOffsetInfo.size() == structMemberTypes.size() &&
1078  "size of offset information must be same as the size of number of "
1079  "elements");
1080  offsetInfo = allocator.copyInto(structOffsetInfo).data();
1081  }
1082 
1083  if (!structMemberDecorationInfo.empty()) {
1084  numMemberDecorations = structMemberDecorationInfo.size();
1085  memberDecorationsInfo =
1086  allocator.copyInto(structMemberDecorationInfo).data();
1087  }
1088 
1089  return success();
1090  }
1091 
1092  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1094  unsigned numMembers;
1097  StringRef identifier;
1098 };
1099 
1100 StructType
1103  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
1104  assert(!memberTypes.empty() && "Struct needs at least one member type");
1105  // Sort the decorations.
1107  memberDecorations.begin(), memberDecorations.end());
1108  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1109  return Base::get(memberTypes.vec().front().getContext(),
1110  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1111  sortedDecorations);
1112 }
1113 
1115  StringRef identifier) {
1116  assert(!identifier.empty() &&
1117  "StructType identifier must be non-empty string");
1118 
1119  return Base::get(context, identifier, ArrayRef<Type>(),
1122 }
1123 
1124 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1125  StructType newStructType = Base::get(
1126  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1128  // Set an empty body in case this is a identified struct.
1129  if (newStructType.isIdentified() &&
1130  failed(newStructType.trySetBody(
1133  return StructType();
1134 
1135  return newStructType;
1136 }
1137 
1138 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1139 
1140 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1141 
1142 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1143 
1144 Type StructType::getElementType(unsigned index) const {
1145  assert(getNumElements() > index && "member index out of range");
1146  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1147 }
1148 
1150  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1151  getNumElements());
1152 }
1153 
1154 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1155 
1156 uint64_t StructType::getMemberOffset(unsigned index) const {
1157  assert(getNumElements() > index && "member index out of range");
1158  return getImpl()->offsetInfo[index];
1159 }
1160 
1163  const {
1164  memberDecorations.clear();
1165  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1166  memberDecorations.append(implMemberDecorations.begin(),
1167  implMemberDecorations.end());
1168 }
1169 
1171  unsigned index,
1172  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1173  assert(getNumElements() > index && "member index out of range");
1174  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1175  decorationsInfo.clear();
1176  for (const auto &memberDecoration : memberDecorations) {
1177  if (memberDecoration.memberIndex == index) {
1178  decorationsInfo.push_back(memberDecoration);
1179  }
1180  if (memberDecoration.memberIndex > index) {
1181  // Early exit since the decorations are stored sorted.
1182  return;
1183  }
1184  }
1185 }
1186 
1189  ArrayRef<OffsetInfo> offsetInfo,
1190  ArrayRef<MemberDecorationInfo> memberDecorations) {
1191  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1192 }
1193 
1195  std::optional<StorageClass> storage) {
1196  for (Type elementType : getElementTypes())
1197  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1198 }
1199 
1202  std::optional<StorageClass> storage) {
1203  for (Type elementType : getElementTypes())
1204  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1205 }
1206 
1207 llvm::hash_code spirv::hash_value(
1208  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1209  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1210  memberDecorationInfo.decoration);
1211 }
1212 
1213 //===----------------------------------------------------------------------===//
1214 // MatrixType
1215 //===----------------------------------------------------------------------===//
1216 
1218  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1219  : columnType(columnType), columnCount(columnCount) {}
1220 
1221  using KeyTy = std::tuple<Type, uint32_t>;
1222 
1224  const KeyTy &key) {
1225 
1226  // Initialize the memory using placement new.
1227  return new (allocator.allocate<MatrixTypeStorage>())
1228  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1229  }
1230 
1231  bool operator==(const KeyTy &key) const {
1232  return key == KeyTy(columnType, columnCount);
1233  }
1234 
1236  const uint32_t columnCount;
1237 };
1238 
1239 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1240  return Base::get(columnType.getContext(), columnType, columnCount);
1241 }
1242 
1244  Type columnType, uint32_t columnCount) {
1245  return Base::getChecked(emitError, columnType.getContext(), columnType,
1246  columnCount);
1247 }
1248 
1250  Type columnType, uint32_t columnCount) {
1251  if (columnCount < 2 || columnCount > 4)
1252  return emitError() << "matrix can have 2, 3, or 4 columns only";
1253 
1254  if (!isValidColumnType(columnType))
1255  return emitError() << "matrix columns must be vectors of floats";
1256 
1257  /// The underlying vectors (columns) must be of size 2, 3, or 4
1258  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1259  if (columnShape.size() != 1)
1260  return emitError() << "matrix columns must be 1D vectors";
1261 
1262  if (columnShape[0] < 2 || columnShape[0] > 4)
1263  return emitError() << "matrix columns must be of size 2, 3, or 4";
1264 
1265  return success();
1266 }
1267 
1268 /// Returns true if the matrix elements are vectors of float elements
1270  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1271  if (llvm::isa<FloatType>(vectorType.getElementType()))
1272  return true;
1273  }
1274  return false;
1275 }
1276 
1277 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1278 
1280  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1281 }
1282 
1283 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1284 
1285 unsigned MatrixType::getNumRows() const {
1286  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1287 }
1288 
1289 unsigned MatrixType::getNumElements() const {
1290  return (getImpl()->columnCount) * getNumRows();
1291 }
1292 
1294  std::optional<StorageClass> storage) {
1295  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1296 }
1297 
1300  std::optional<StorageClass> storage) {
1301  {
1302  static const Capability caps[] = {Capability::Matrix};
1303  ArrayRef<Capability> ref(caps, std::size(caps));
1304  capabilities.push_back(ref);
1305  }
1306  // Add any capabilities associated with the underlying vectors (i.e., columns)
1307  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1308 }
1309 
1310 //===----------------------------------------------------------------------===//
1311 // SPIR-V Dialect
1312 //===----------------------------------------------------------------------===//
1313 
1314 void SPIRVDialect::registerTypes() {
1318 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:436
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:442
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:408
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:424
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:430
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:412
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:418
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
bool isBF16() const
Definition: Types.cpp:48
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:82
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:70
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:75
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:164
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:184
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:145
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:150
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:121
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:103
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:109
static bool classof(Type type)
Definition: SPIRVTypes.cpp:94
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:309
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:311
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:297
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:324
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:307
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:315
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:247
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:241
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:243
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:229
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:253
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:261
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:249
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:489
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:491
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:503
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:499
Type getElementType() const
Definition: SPIRVTypes.cpp:485
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:495
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:510
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:505
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:374
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:393
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:384
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:378
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:363
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:376
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Definition: SPIRVTypes.cpp:380
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the element type (i.e., the type of a single scalar element).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:556
Type getPointeeType() const
Definition: SPIRVTypes.cpp:550
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:552
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:567
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:546
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:621
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:613
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:615
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:603
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:845
static bool classof(Type type)
Definition: SPIRVTypes.cpp:791
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:825
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:806
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:885
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:899
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:877
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:893
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:872
static bool classof(Type type)
Definition: SPIRVTypes.cpp:637
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:655
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:647
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:686
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:774
SPIR-V struct type.
Definition: SPIRVTypes.h:294
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
Returns the struct identifier; for literal structs, returns an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:32
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:34
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:39
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:275
static CooperativeMatrixNVTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:278
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:208
std::tuple< Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:205
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:458
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:453
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:451
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
Definition: SPIRVTypes.cpp:339
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:341
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:347
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:535
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:527
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:529
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:586
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:592
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:861
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:863
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:923
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:934
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:965
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and leaves the rest of the struct type data to be set through a later call to StructType::trySetBody(...).
Definition: SPIRVTypes.cpp:981
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:927
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:958
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet