MLIR  21.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <cstdint>
22 #include <iterator>
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 //===----------------------------------------------------------------------===//
28 // ArrayType
29 //===----------------------------------------------------------------------===//
30 
32  using KeyTy = std::tuple<Type, unsigned, unsigned>;
33 
35  const KeyTy &key) {
36  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37  }
38 
39  bool operator==(const KeyTy &key) const {
40  return key == KeyTy(elementType, elementCount, stride);
41  }
42 
43  ArrayTypeStorage(const KeyTy &key)
44  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45  stride(std::get<2>(key)) {}
46 
48  unsigned elementCount;
49  unsigned stride;
50 };
51 
52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53  assert(elementCount && "ArrayType needs at least one element");
54  return Base::get(elementType.getContext(), elementType, elementCount,
55  /*stride=*/0);
56 }
57 
58 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59  unsigned stride) {
60  assert(elementCount && "ArrayType needs at least one element");
61  return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63 
64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65 
66 Type ArrayType::getElementType() const { return getImpl()->elementType; }
67 
68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69 
71  std::optional<StorageClass> storage) {
72  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
73 }
74 
77  std::optional<StorageClass> storage) {
78  llvm::cast<SPIRVType>(getElementType())
79  .getCapabilities(capabilities, storage);
80 }
81 
82 std::optional<int64_t> ArrayType::getSizeInBytes() {
83  auto elementType = llvm::cast<SPIRVType>(getElementType());
84  std::optional<int64_t> size = elementType.getSizeInBytes();
85  if (!size)
86  return std::nullopt;
87  return (*size + getArrayStride()) * getNumElements();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // CompositeType
92 //===----------------------------------------------------------------------===//
93 
95  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96  return isValid(vectorType);
99  spirv::StructType>(type);
100 }
101 
102 bool CompositeType::isValid(VectorType type) {
103  return type.getRank() == 1 &&
104  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105  llvm::isa<ScalarType>(type.getElementType());
106 }
107 
108 Type CompositeType::getElementType(unsigned index) const {
109  return TypeSwitch<Type, Type>(*this)
110  .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
111  [](auto type) { return type.getElementType(); })
112  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
113  .Case<StructType>(
114  [index](StructType type) { return type.getElementType(index); })
115  .Default(
116  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
117 }
118 
120  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
121  return arrayType.getNumElements();
122  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
123  return matrixType.getNumColumns();
124  if (auto structType = llvm::dyn_cast<StructType>(*this))
125  return structType.getNumElements();
126  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
127  return vectorType.getNumElements();
128  if (llvm::isa<CooperativeMatrixType>(*this)) {
129  llvm_unreachable(
130  "invalid to query number of elements of spirv Cooperative Matrix type");
131  }
132  if (llvm::isa<RuntimeArrayType>(*this)) {
133  llvm_unreachable(
134  "invalid to query number of elements of spirv::RuntimeArray type");
135  }
136  llvm_unreachable("invalid composite type");
137 }
138 
140  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142 
145  std::optional<StorageClass> storage) {
146  TypeSwitch<Type>(*this)
148  StructType>(
149  [&](auto type) { type.getExtensions(extensions, storage); })
150  .Case<VectorType>([&](VectorType type) {
151  return llvm::cast<ScalarType>(type.getElementType())
152  .getExtensions(extensions, storage);
153  })
154  .Default([](Type) { llvm_unreachable("invalid composite type"); });
155 }
156 
159  std::optional<StorageClass> storage) {
160  TypeSwitch<Type>(*this)
162  StructType>(
163  [&](auto type) { type.getCapabilities(capabilities, storage); })
164  .Case<VectorType>([&](VectorType type) {
165  auto vecSize = getNumElements();
166  if (vecSize == 8 || vecSize == 16) {
167  static const Capability caps[] = {Capability::Vector16};
168  ArrayRef<Capability> ref(caps, std::size(caps));
169  capabilities.push_back(ref);
170  }
171  return llvm::cast<ScalarType>(type.getElementType())
172  .getCapabilities(capabilities, storage);
173  })
174  .Default([](Type) { llvm_unreachable("invalid composite type"); });
175 }
176 
177 std::optional<int64_t> CompositeType::getSizeInBytes() {
178  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
179  return arrayType.getSizeInBytes();
180  if (auto structType = llvm::dyn_cast<StructType>(*this))
181  return structType.getSizeInBytes();
182  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
183  std::optional<int64_t> elementSize =
184  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
185  if (!elementSize)
186  return std::nullopt;
187  return *elementSize * vectorType.getNumElements();
188  }
189  return std::nullopt;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // CooperativeMatrixType
194 //===----------------------------------------------------------------------===//
195 
197  // In the specification dimensions of the Cooperative Matrix are 32-bit
198  // integers --- the initial implementation kept those values as such. However,
199  // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
200  // as 32-bits and expose it as int64_t through `getShape`, however, this
201  // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
202  // 32-bits integers would require an extra logic and storage. So, we diverge
203  // from the spec and internally represent the dimensions as 64-bit integers,
204  // so we can easily return an `ArrayRef` from `getShape` without any extra
205  // logic. Alternatively, we could store both rows and columns (both 32-bits)
206  // and shape (64-bits), assigning rows and columns to shape whenever
207  // `getShape` is called. This would be at the cost of extra logic and storage.
208  // Note: Because `ArrayRef` is returned we cannot construct an object in
209  // `getShape` on the fly.
210  using KeyTy =
211  std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
212 
214  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
215  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
217  }
218 
219  bool operator==(const KeyTy &key) const {
220  return key == KeyTy(elementType, shape[0], shape[1], scope, use);
221  }
222 
224  : elementType(std::get<0>(key)),
225  shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
226  use(std::get<4>(key)) {}
227 
229  // [#rows, #columns]
230  std::array<int64_t, 2> shape;
231  Scope scope;
232  CooperativeMatrixUseKHR use;
233 };
234 
236  uint32_t rows,
237  uint32_t columns, Scope scope,
238  CooperativeMatrixUseKHR use) {
239  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
240  use);
241 }
242 
244  return getImpl()->elementType;
245 }
246 
248  assert(getImpl()->shape[0] != ShapedType::kDynamic);
249  return static_cast<uint32_t>(getImpl()->shape[0]);
250 }
251 
253  assert(getImpl()->shape[1] != ShapedType::kDynamic);
254  return static_cast<uint32_t>(getImpl()->shape[1]);
255 }
256 
258  return getImpl()->shape;
259 }
260 
261 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
262 
263 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
264  return getImpl()->use;
265 }
266 
269  std::optional<StorageClass> storage) {
270  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
271  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
272  extensions.push_back(exts);
273 }
274 
277  std::optional<StorageClass> storage) {
278  llvm::cast<SPIRVType>(getElementType())
279  .getCapabilities(capabilities, storage);
280  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
281  capabilities.push_back(caps);
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // ImageType
286 //===----------------------------------------------------------------------===//
287 
288 template <typename T>
289 static constexpr unsigned getNumBits() {
290  return 0;
291 }
292 template <>
293 constexpr unsigned getNumBits<Dim>() {
294  static_assert((1 << 3) > getMaxEnumValForDim(),
295  "Not enough bits to encode Dim value");
296  return 3;
297 }
298 template <>
299 constexpr unsigned getNumBits<ImageDepthInfo>() {
300  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
301  "Not enough bits to encode ImageDepthInfo value");
302  return 2;
303 }
304 template <>
305 constexpr unsigned getNumBits<ImageArrayedInfo>() {
306  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
307  "Not enough bits to encode ImageArrayedInfo value");
308  return 1;
309 }
310 template <>
311 constexpr unsigned getNumBits<ImageSamplingInfo>() {
312  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
313  "Not enough bits to encode ImageSamplingInfo value");
314  return 1;
315 }
316 template <>
317 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
318  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
319  "Not enough bits to encode ImageSamplerUseInfo value");
320  return 2;
321 }
322 template <>
323 constexpr unsigned getNumBits<ImageFormat>() {
324  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
325  "Not enough bits to encode ImageFormat value");
326  return 6;
327 }
328 
330 public:
331  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
332  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
333 
335  const KeyTy &key) {
336  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
337  }
338 
339  bool operator==(const KeyTy &key) const {
340  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
341  samplerUseInfo, format);
342  }
343 
345  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
346  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
347  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
348  format(std::get<6>(key)) {}
349 
357 };
358 
359 ImageType
360 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
361  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
362  value) {
363  return Base::get(std::get<0>(value).getContext(), value);
364 }
365 
366 Type ImageType::getElementType() const { return getImpl()->elementType; }
367 
368 Dim ImageType::getDim() const { return getImpl()->dim; }
369 
370 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
371 
372 ImageArrayedInfo ImageType::getArrayedInfo() const {
373  return getImpl()->arrayedInfo;
374 }
375 
376 ImageSamplingInfo ImageType::getSamplingInfo() const {
377  return getImpl()->samplingInfo;
378 }
379 
380 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
381  return getImpl()->samplerUseInfo;
382 }
383 
384 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
385 
387  std::optional<StorageClass>) {
388  // Image types do not require extra extensions thus far.
389 }
390 
393  std::optional<StorageClass>) {
394  if (auto dimCaps = spirv::getCapabilities(getDim()))
395  capabilities.push_back(*dimCaps);
396 
397  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
398  capabilities.push_back(*fmtCaps);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // PointerType
403 //===----------------------------------------------------------------------===//
404 
406  // (Type, StorageClass) as the key: Type stored in this struct, and
407  // StorageClass stored as TypeStorage's subclass data.
408  using KeyTy = std::pair<Type, StorageClass>;
409 
411  const KeyTy &key) {
412  return new (allocator.allocate<PointerTypeStorage>())
413  PointerTypeStorage(key);
414  }
415 
416  bool operator==(const KeyTy &key) const {
417  return key == KeyTy(pointeeType, storageClass);
418  }
419 
421  : pointeeType(key.first), storageClass(key.second) {}
422 
424  StorageClass storageClass;
425 };
426 
427 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
428  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
429 }
430 
431 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
432 
433 StorageClass PointerType::getStorageClass() const {
434  return getImpl()->storageClass;
435 }
436 
438  std::optional<StorageClass> storage) {
439  // Use this pointer type's storage class because this pointer indicates we are
440  // using the pointee type in that specific storage class.
441  llvm::cast<SPIRVType>(getPointeeType())
442  .getExtensions(extensions, getStorageClass());
443 
444  if (auto scExts = spirv::getExtensions(getStorageClass()))
445  extensions.push_back(*scExts);
446 }
447 
450  std::optional<StorageClass> storage) {
451  // Use this pointer type's storage class because this pointer indicates we are
452  // using the pointee type in that specific storage class.
453  llvm::cast<SPIRVType>(getPointeeType())
454  .getCapabilities(capabilities, getStorageClass());
455 
456  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
457  capabilities.push_back(*scCaps);
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // RuntimeArrayType
462 //===----------------------------------------------------------------------===//
463 
465  using KeyTy = std::pair<Type, unsigned>;
466 
468  const KeyTy &key) {
469  return new (allocator.allocate<RuntimeArrayTypeStorage>())
471  }
472 
473  bool operator==(const KeyTy &key) const {
474  return key == KeyTy(elementType, stride);
475  }
476 
478  : elementType(key.first), stride(key.second) {}
479 
481  unsigned stride;
482 };
483 
485  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
486 }
487 
488 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
489  return Base::get(elementType.getContext(), elementType, stride);
490 }
491 
492 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
493 
494 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
495 
498  std::optional<StorageClass> storage) {
499  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
500 }
501 
504  std::optional<StorageClass> storage) {
505  {
506  static const Capability caps[] = {Capability::Shader};
507  ArrayRef<Capability> ref(caps, std::size(caps));
508  capabilities.push_back(ref);
509  }
510  llvm::cast<SPIRVType>(getElementType())
511  .getCapabilities(capabilities, storage);
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // ScalarType
516 //===----------------------------------------------------------------------===//
517 
519  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
520  return isValid(floatType);
521  }
522  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
523  return isValid(intType);
524  }
525  return false;
526 }
527 
528 bool ScalarType::isValid(FloatType type) {
529  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
530 }
531 
532 bool ScalarType::isValid(IntegerType type) {
533  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
534 }
535 
537  std::optional<StorageClass> storage) {
538  if (isa<BFloat16Type>(*this)) {
539  static const Extension ext = Extension::SPV_KHR_bfloat16;
540  extensions.push_back(ext);
541  }
542 
543  // 8- or 16-bit integer/floating-point numbers will require extra extensions
544  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
545  // SPV_KHR_8bit_storage for more details.
546  if (!storage)
547  return;
548 
549  switch (*storage) {
550  case StorageClass::PushConstant:
551  case StorageClass::StorageBuffer:
552  case StorageClass::Uniform:
553  if (getIntOrFloatBitWidth() == 8) {
554  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
555  ArrayRef<Extension> ref(exts, std::size(exts));
556  extensions.push_back(ref);
557  }
558  [[fallthrough]];
559  case StorageClass::Input:
560  case StorageClass::Output:
561  if (getIntOrFloatBitWidth() == 16) {
562  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
563  ArrayRef<Extension> ref(exts, std::size(exts));
564  extensions.push_back(ref);
565  }
566  break;
567  default:
568  break;
569  }
570 }
571 
574  std::optional<StorageClass> storage) {
575  unsigned bitwidth = getIntOrFloatBitWidth();
576 
577  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
578  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
579  // SPV_KHR_8bit_storage for more details.
580 
581 #define STORAGE_CASE(storage, cap8, cap16) \
582  case StorageClass::storage: { \
583  if (bitwidth == 8) { \
584  static const Capability caps[] = {Capability::cap8}; \
585  ArrayRef<Capability> ref(caps, std::size(caps)); \
586  capabilities.push_back(ref); \
587  return; \
588  } \
589  if (bitwidth == 16) { \
590  static const Capability caps[] = {Capability::cap16}; \
591  ArrayRef<Capability> ref(caps, std::size(caps)); \
592  capabilities.push_back(ref); \
593  return; \
594  } \
595  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
596  /* storage classes. Fall through to the next section. */ \
597  } break
598 
599  // This part only handles the cases where special bitwidths appearing in
600  // interface storage classes.
601  if (storage) {
602  switch (*storage) {
603  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
604  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
605  StorageBuffer16BitAccess);
606  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
607  StorageUniform16);
608  case StorageClass::Input:
609  case StorageClass::Output: {
610  if (bitwidth == 16) {
611  static const Capability caps[] = {Capability::StorageInputOutput16};
612  ArrayRef<Capability> ref(caps, std::size(caps));
613  capabilities.push_back(ref);
614  return;
615  }
616  break;
617  }
618  default:
619  break;
620  }
621  }
622 #undef STORAGE_CASE
623 
624  // For other non-interface storage classes, require a different set of
625  // capabilities for special bitwidths.
626 
627 #define WIDTH_CASE(type, width) \
628  case width: { \
629  static const Capability caps[] = {Capability::type##width}; \
630  ArrayRef<Capability> ref(caps, std::size(caps)); \
631  capabilities.push_back(ref); \
632  } break
633 
634  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
635  switch (bitwidth) {
636  WIDTH_CASE(Int, 8);
637  WIDTH_CASE(Int, 16);
638  WIDTH_CASE(Int, 64);
639  case 1:
640  case 32:
641  break;
642  default:
643  llvm_unreachable("invalid bitwidth to getCapabilities");
644  }
645  } else {
646  assert(llvm::isa<FloatType>(*this));
647  switch (bitwidth) {
648  case 16: {
649  if (isa<BFloat16Type>(*this)) {
650  static const Capability cap = Capability::BFloat16TypeKHR;
651  capabilities.push_back(cap);
652  } else {
653  static const Capability cap = Capability::Float16;
654  capabilities.push_back(cap);
655  }
656  break;
657  }
658  WIDTH_CASE(Float, 64);
659  case 32:
660  break;
661  default:
662  llvm_unreachable("invalid bitwidth to getCapabilities");
663  }
664  }
665 
666 #undef WIDTH_CASE
667 }
668 
669 std::optional<int64_t> ScalarType::getSizeInBytes() {
670  auto bitWidth = getIntOrFloatBitWidth();
671  // According to the SPIR-V spec:
672  // "There is no physical size or bit pattern defined for values with boolean
673  // type. If they are stored (in conjunction with OpVariable), they can only
674  // be used with logical addressing operations, not physical, and only with
675  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
676  // Private, Function, Input, and Output."
677  if (bitWidth == 1)
678  return std::nullopt;
679  return bitWidth / 8;
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // SPIRVType
684 //===----------------------------------------------------------------------===//
685 
687  // Allow SPIR-V dialect types
688  if (llvm::isa<SPIRVDialect>(type.getDialect()))
689  return true;
690  if (llvm::isa<ScalarType>(type))
691  return true;
692  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
693  return CompositeType::isValid(vectorType);
694  return false;
695 }
696 
698  return isIntOrFloat() || llvm::isa<VectorType>(*this);
699 }
700 
702  std::optional<StorageClass> storage) {
703  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
704  scalarType.getExtensions(extensions, storage);
705  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
706  compositeType.getExtensions(extensions, storage);
707  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
708  imageType.getExtensions(extensions, storage);
709  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
710  sampledImageType.getExtensions(extensions, storage);
711  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
712  matrixType.getExtensions(extensions, storage);
713  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
714  ptrType.getExtensions(extensions, storage);
715  } else {
716  llvm_unreachable("invalid SPIR-V Type to getExtensions");
717  }
718 }
719 
722  std::optional<StorageClass> storage) {
723  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
724  scalarType.getCapabilities(capabilities, storage);
725  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
726  compositeType.getCapabilities(capabilities, storage);
727  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
728  imageType.getCapabilities(capabilities, storage);
729  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
730  sampledImageType.getCapabilities(capabilities, storage);
731  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
732  matrixType.getCapabilities(capabilities, storage);
733  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
734  ptrType.getCapabilities(capabilities, storage);
735  } else {
736  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
737  }
738 }
739 
740 std::optional<int64_t> SPIRVType::getSizeInBytes() {
741  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
742  return scalarType.getSizeInBytes();
743  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
744  return compositeType.getSizeInBytes();
745  return std::nullopt;
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // SampledImageType
750 //===----------------------------------------------------------------------===//
752  using KeyTy = Type;
753 
754  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
755 
756  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
757 
759  const KeyTy &key) {
760  return new (allocator.allocate<SampledImageTypeStorage>())
762  }
763 
765 };
766 
768  return Base::get(imageType.getContext(), imageType);
769 }
770 
773  Type imageType) {
774  return Base::getChecked(emitError, imageType.getContext(), imageType);
775 }
776 
777 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
778 
779 LogicalResult
781  Type imageType) {
782  if (!llvm::isa<ImageType>(imageType))
783  return emitError() << "expected image type";
784 
785  return success();
786 }
787 
790  std::optional<StorageClass> storage) {
791  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
792 }
793 
796  std::optional<StorageClass> storage) {
797  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // StructType
802 //===----------------------------------------------------------------------===//
803 
804 /// Type storage for SPIR-V structure types:
805 ///
806 /// Structures are uniqued using:
807 /// - for identified structs:
808 /// - a string identifier;
809 /// - for literal structs:
810 /// - a list of member types;
811 /// - a list of member offset info;
812 /// - a list of member decoration info.
813 ///
814 /// Identified structures only have a mutable component consisting of:
815 /// - a list of member types;
816 /// - a list of member offset info;
817 /// - a list of member decoration info.
819  /// Construct a storage object for an identified struct type. A struct type
820  /// associated with such storage must call StructType::trySetBody(...) later
821  /// in order to mutate the storage object providing the actual content.
822  StructTypeStorage(StringRef identifier)
823  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
824  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
825  identifier(identifier) {}
826 
827  /// Construct a storage object for a literal struct type. A struct type
828  /// associated with such storage is immutable.
830  unsigned numMembers, Type const *memberTypes,
831  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
832  StructType::MemberDecorationInfo const *memberDecorationsInfo)
833  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
834  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
835  memberDecorationsInfo(memberDecorationsInfo) {}
836 
837  /// A storage key is divided into 2 parts:
838  /// - for identified structs:
839  /// - a StringRef representing the struct identifier;
840  /// - for literal structs:
841  /// - an ArrayRef<Type> for member types;
842  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
843  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
844  /// info.
845  ///
846  /// An identified struct type is uniqued only by the first part (field 0)
847  /// of the key.
848  ///
849  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
850  /// 3) of the key. The identifier field (field 0) must be empty.
851  using KeyTy =
852  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
854 
855  /// For identified structs, return true if the given key contains the same
856  /// identifier.
857  ///
858  /// For literal structs, return true if the given key contains a matching list
859  /// of member types + offset info + decoration info.
860  bool operator==(const KeyTy &key) const {
861  if (isIdentified()) {
862  // Identified types are uniqued by their identifier.
863  return getIdentifier() == std::get<0>(key);
864  }
865 
866  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
867  getMemberDecorationsInfo());
868  }
869 
870  /// If the given key contains a non-empty identifier, this method constructs
871  /// an identified struct and leaves the rest of the struct type data to be set
872  /// through a later call to StructType::trySetBody(...).
873  ///
874  /// If, on the other hand, the key contains an empty identifier, a literal
875  /// struct is constructed using the other fields of the key.
877  const KeyTy &key) {
878  StringRef keyIdentifier = std::get<0>(key);
879 
880  if (!keyIdentifier.empty()) {
881  StringRef identifier = allocator.copyInto(keyIdentifier);
882 
883  // Identified StructType body/members will be set through trySetBody(...)
884  // later.
885  return new (allocator.allocate<StructTypeStorage>())
886  StructTypeStorage(identifier);
887  }
888 
889  ArrayRef<Type> keyTypes = std::get<1>(key);
890 
891  // Copy the member type and layout information into the bump pointer
892  const Type *typesList = nullptr;
893  if (!keyTypes.empty()) {
894  typesList = allocator.copyInto(keyTypes).data();
895  }
896 
897  const StructType::OffsetInfo *offsetInfoList = nullptr;
898  if (!std::get<2>(key).empty()) {
899  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
900  assert(keyOffsetInfo.size() == keyTypes.size() &&
901  "size of offset information must be same as the size of number of "
902  "elements");
903  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
904  }
905 
906  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
907  unsigned numMemberDecorations = 0;
908  if (!std::get<3>(key).empty()) {
909  auto keyMemberDecorations = std::get<3>(key);
910  numMemberDecorations = keyMemberDecorations.size();
911  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
912  }
913 
914  return new (allocator.allocate<StructTypeStorage>())
915  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
916  numMemberDecorations, memberDecorationList);
917  }
918 
920  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
921  }
922 
924  if (offsetInfo) {
925  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
926  }
927  return {};
928  }
929 
931  if (memberDecorationsInfo) {
932  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
933  numMemberDecorations);
934  }
935  return {};
936  }
937 
938  StringRef getIdentifier() const { return identifier; }
939 
940  bool isIdentified() const { return !identifier.empty(); }
941 
942  /// Sets the struct type content for identified structs. Calling this method
943  /// is only valid for identified structs.
944  ///
945  /// Fails under the following conditions:
946  /// - If called for a literal struct;
947  /// - If called for an identified struct whose body was set before (through a
948  /// call to this method) but with different contents from the passed
949  /// arguments.
950  LogicalResult mutate(
951  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
952  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
953  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
954  if (!isIdentified())
955  return failure();
956 
957  if (memberTypesAndIsBodySet.getInt() &&
958  (getMemberTypes() != structMemberTypes ||
959  getOffsetInfo() != structOffsetInfo ||
960  getMemberDecorationsInfo() != structMemberDecorationInfo))
961  return failure();
962 
963  memberTypesAndIsBodySet.setInt(true);
964  numMembers = structMemberTypes.size();
965 
966  // Copy the member type and layout information into the bump pointer.
967  if (!structMemberTypes.empty())
968  memberTypesAndIsBodySet.setPointer(
969  allocator.copyInto(structMemberTypes).data());
970 
971  if (!structOffsetInfo.empty()) {
972  assert(structOffsetInfo.size() == structMemberTypes.size() &&
973  "size of offset information must be same as the size of number of "
974  "elements");
975  offsetInfo = allocator.copyInto(structOffsetInfo).data();
976  }
977 
978  if (!structMemberDecorationInfo.empty()) {
979  numMemberDecorations = structMemberDecorationInfo.size();
980  memberDecorationsInfo =
981  allocator.copyInto(structMemberDecorationInfo).data();
982  }
983 
984  return success();
985  }
986 
987  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
989  unsigned numMembers;
992  StringRef identifier;
993 };
994 
998  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
999  assert(!memberTypes.empty() && "Struct needs at least one member type");
1000  // Sort the decorations.
1002  memberDecorations);
1003  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1004  return Base::get(memberTypes.vec().front().getContext(),
1005  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1006  sortedDecorations);
1007 }
1008 
1010  StringRef identifier) {
1011  assert(!identifier.empty() &&
1012  "StructType identifier must be non-empty string");
1013 
1014  return Base::get(context, identifier, ArrayRef<Type>(),
1017 }
1018 
1019 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1020  StructType newStructType = Base::get(
1021  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1023  // Set an empty body in case this is a identified struct.
1024  if (newStructType.isIdentified() &&
1025  failed(newStructType.trySetBody(
1028  return StructType();
1029 
1030  return newStructType;
1031 }
1032 
1033 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1034 
1035 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1036 
1037 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1038 
1039 Type StructType::getElementType(unsigned index) const {
1040  assert(getNumElements() > index && "member index out of range");
1041  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1042 }
1043 
1045  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1046  getNumElements());
1047 }
1048 
1049 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1050 
1051 uint64_t StructType::getMemberOffset(unsigned index) const {
1052  assert(getNumElements() > index && "member index out of range");
1053  return getImpl()->offsetInfo[index];
1054 }
1055 
1058  const {
1059  memberDecorations.clear();
1060  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1061  memberDecorations.append(implMemberDecorations.begin(),
1062  implMemberDecorations.end());
1063 }
1064 
1066  unsigned index,
1067  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1068  assert(getNumElements() > index && "member index out of range");
1069  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1070  decorationsInfo.clear();
1071  for (const auto &memberDecoration : memberDecorations) {
1072  if (memberDecoration.memberIndex == index) {
1073  decorationsInfo.push_back(memberDecoration);
1074  }
1075  if (memberDecoration.memberIndex > index) {
1076  // Early exit since the decorations are stored sorted.
1077  return;
1078  }
1079  }
1080 }
1081 
1082 LogicalResult
1084  ArrayRef<OffsetInfo> offsetInfo,
1085  ArrayRef<MemberDecorationInfo> memberDecorations) {
1086  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1087 }
1088 
1090  std::optional<StorageClass> storage) {
1091  for (Type elementType : getElementTypes())
1092  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1093 }
1094 
1097  std::optional<StorageClass> storage) {
1098  for (Type elementType : getElementTypes())
1099  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1100 }
1101 
1102 llvm::hash_code spirv::hash_value(
1103  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1104  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1105  memberDecorationInfo.decoration);
1106 }
1107 
1108 //===----------------------------------------------------------------------===//
1109 // MatrixType
1110 //===----------------------------------------------------------------------===//
1111 
1113  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1114  : columnType(columnType), columnCount(columnCount) {}
1115 
1116  using KeyTy = std::tuple<Type, uint32_t>;
1117 
1119  const KeyTy &key) {
1120 
1121  // Initialize the memory using placement new.
1122  return new (allocator.allocate<MatrixTypeStorage>())
1123  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1124  }
1125 
1126  bool operator==(const KeyTy &key) const {
1127  return key == KeyTy(columnType, columnCount);
1128  }
1129 
1131  const uint32_t columnCount;
1132 };
1133 
1134 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1135  return Base::get(columnType.getContext(), columnType, columnCount);
1136 }
1137 
1139  Type columnType, uint32_t columnCount) {
1140  return Base::getChecked(emitError, columnType.getContext(), columnType,
1141  columnCount);
1142 }
1143 
1144 LogicalResult
1146  Type columnType, uint32_t columnCount) {
1147  if (columnCount < 2 || columnCount > 4)
1148  return emitError() << "matrix can have 2, 3, or 4 columns only";
1149 
1150  if (!isValidColumnType(columnType))
1151  return emitError() << "matrix columns must be vectors of floats";
1152 
1153  /// The underlying vectors (columns) must be of size 2, 3, or 4
1154  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1155  if (columnShape.size() != 1)
1156  return emitError() << "matrix columns must be 1D vectors";
1157 
1158  if (columnShape[0] < 2 || columnShape[0] > 4)
1159  return emitError() << "matrix columns must be of size 2, 3, or 4";
1160 
1161  return success();
1162 }
1163 
1164 /// Returns true if the matrix elements are vectors of float elements
1166  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1167  if (llvm::isa<FloatType>(vectorType.getElementType()))
1168  return true;
1169  }
1170  return false;
1171 }
1172 
1173 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1174 
1176  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1177 }
1178 
1179 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1180 
1181 unsigned MatrixType::getNumRows() const {
1182  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1183 }
1184 
1185 unsigned MatrixType::getNumElements() const {
1186  return (getImpl()->columnCount) * getNumRows();
1187 }
1188 
1190  std::optional<StorageClass> storage) {
1191  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1192 }
1193 
1196  std::optional<StorageClass> storage) {
1197  {
1198  static const Capability caps[] = {Capability::Matrix};
1199  ArrayRef<Capability> ref(caps, std::size(caps));
1200  capabilities.push_back(ref);
1201  }
1202  // Add any capabilities associated with the underlying vectors (i.e., columns)
1203  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1204 }
1205 
1206 //===----------------------------------------------------------------------===//
1207 // SPIR-V Dialect
1208 //===----------------------------------------------------------------------===//
1209 
1210 void SPIRVDialect::registerTypes() {
1213 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:317
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:323
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:289
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:305
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:311
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:293
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:299
int64_t rows
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:82
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:70
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:75
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:157
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:177
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:139
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:143
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:119
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:108
static bool classof(Type type)
Definition: SPIRVTypes.cpp:94
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:261
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:247
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:252
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:235
ArrayRef< int64_t > getShape() const
Definition: SPIRVTypes.cpp:257
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:267
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:275
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:263
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:370
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:372
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:384
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:380
Type getElementType() const
Definition: SPIRVTypes.cpp:366
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:376
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:391
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:386
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:437
Type getPointeeType() const
Definition: SPIRVTypes.cpp:431
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:433
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:448
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:502
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:494
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:496
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:484
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:740
static bool classof(Type type)
Definition: SPIRVTypes.cpp:686
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:720
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:701
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:780
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:794
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:772
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:788
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:767
static bool classof(Type type)
Definition: SPIRVTypes.cpp:518
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:536
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:528
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:572
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:669
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:996
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:32
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:34
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:39
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:211
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:214
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:339
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:334
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:332
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:416
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:408
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:410
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:467
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:473
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:756
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:758
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:818
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:923
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:829
StructType::OffsetInfo const * offsetInfo
Definition: SPIRVTypes.cpp:988
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:860
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:950
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:876
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:919
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:822
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:930
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:853
StructType::MemberDecorationInfo const * memberDecorationsInfo
Definition: SPIRVTypes.cpp:991
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
Definition: SPIRVTypes.cpp:987