MLIR  21.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <cstdint>
22 #include <iterator>
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 //===----------------------------------------------------------------------===//
28 // ArrayType
29 //===----------------------------------------------------------------------===//
30 
32  using KeyTy = std::tuple<Type, unsigned, unsigned>;
33 
35  const KeyTy &key) {
36  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37  }
38 
39  bool operator==(const KeyTy &key) const {
40  return key == KeyTy(elementType, elementCount, stride);
41  }
42 
43  ArrayTypeStorage(const KeyTy &key)
44  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45  stride(std::get<2>(key)) {}
46 
48  unsigned elementCount;
49  unsigned stride;
50 };
51 
52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53  assert(elementCount && "ArrayType needs at least one element");
54  return Base::get(elementType.getContext(), elementType, elementCount,
55  /*stride=*/0);
56 }
57 
58 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59  unsigned stride) {
60  assert(elementCount && "ArrayType needs at least one element");
61  return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63 
64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65 
66 Type ArrayType::getElementType() const { return getImpl()->elementType; }
67 
68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69 
71  std::optional<StorageClass> storage) {
72  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
73 }
74 
77  std::optional<StorageClass> storage) {
78  llvm::cast<SPIRVType>(getElementType())
79  .getCapabilities(capabilities, storage);
80 }
81 
82 std::optional<int64_t> ArrayType::getSizeInBytes() {
83  auto elementType = llvm::cast<SPIRVType>(getElementType());
84  std::optional<int64_t> size = elementType.getSizeInBytes();
85  if (!size)
86  return std::nullopt;
87  return (*size + getArrayStride()) * getNumElements();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // CompositeType
92 //===----------------------------------------------------------------------===//
93 
95  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96  return isValid(vectorType);
99  spirv::StructType>(type);
100 }
101 
102 bool CompositeType::isValid(VectorType type) {
103  return type.getRank() == 1 &&
104  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105  llvm::isa<ScalarType>(type.getElementType());
106 }
107 
108 Type CompositeType::getElementType(unsigned index) const {
109  return TypeSwitch<Type, Type>(*this)
110  .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
111  [](auto type) { return type.getElementType(); })
112  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
113  .Case<StructType>(
114  [index](StructType type) { return type.getElementType(index); })
115  .Default(
116  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
117 }
118 
120  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
121  return arrayType.getNumElements();
122  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
123  return matrixType.getNumColumns();
124  if (auto structType = llvm::dyn_cast<StructType>(*this))
125  return structType.getNumElements();
126  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
127  return vectorType.getNumElements();
128  if (llvm::isa<CooperativeMatrixType>(*this)) {
129  llvm_unreachable(
130  "invalid to query number of elements of spirv Cooperative Matrix type");
131  }
132  if (llvm::isa<RuntimeArrayType>(*this)) {
133  llvm_unreachable(
134  "invalid to query number of elements of spirv::RuntimeArray type");
135  }
136  llvm_unreachable("invalid composite type");
137 }
138 
140  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142 
145  std::optional<StorageClass> storage) {
146  TypeSwitch<Type>(*this)
148  StructType>(
149  [&](auto type) { type.getExtensions(extensions, storage); })
150  .Case<VectorType>([&](VectorType type) {
151  return llvm::cast<ScalarType>(type.getElementType())
152  .getExtensions(extensions, storage);
153  })
154  .Default([](Type) { llvm_unreachable("invalid composite type"); });
155 }
156 
159  std::optional<StorageClass> storage) {
160  TypeSwitch<Type>(*this)
162  StructType>(
163  [&](auto type) { type.getCapabilities(capabilities, storage); })
164  .Case<VectorType>([&](VectorType type) {
165  auto vecSize = getNumElements();
166  if (vecSize == 8 || vecSize == 16) {
167  static const Capability caps[] = {Capability::Vector16};
168  ArrayRef<Capability> ref(caps, std::size(caps));
169  capabilities.push_back(ref);
170  }
171  return llvm::cast<ScalarType>(type.getElementType())
172  .getCapabilities(capabilities, storage);
173  })
174  .Default([](Type) { llvm_unreachable("invalid composite type"); });
175 }
176 
177 std::optional<int64_t> CompositeType::getSizeInBytes() {
178  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
179  return arrayType.getSizeInBytes();
180  if (auto structType = llvm::dyn_cast<StructType>(*this))
181  return structType.getSizeInBytes();
182  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
183  std::optional<int64_t> elementSize =
184  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
185  if (!elementSize)
186  return std::nullopt;
187  return *elementSize * vectorType.getNumElements();
188  }
189  return std::nullopt;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // CooperativeMatrixType
194 //===----------------------------------------------------------------------===//
195 
197  // In the specification dimensions of the Cooperative Matrix are 32-bit
198  // integers --- the initial implementation kept those values as such. However,
199  // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
200  // as 32-bits and expose it as int64_t through `getShape`, however, this
201  // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
  // 32-bit integers would require extra logic and storage. So, we diverge
203  // from the spec and internally represent the dimensions as 64-bit integers,
204  // so we can easily return an `ArrayRef` from `getShape` without any extra
205  // logic. Alternatively, we could store both rows and columns (both 32-bits)
206  // and shape (64-bits), assigning rows and columns to shape whenever
207  // `getShape` is called. This would be at the cost of extra logic and storage.
208  // Note: Because `ArrayRef` is returned we cannot construct an object in
209  // `getShape` on the fly.
210  using KeyTy =
211  std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
212 
214  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
215  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
217  }
218 
219  bool operator==(const KeyTy &key) const {
220  return key == KeyTy(elementType, shape[0], shape[1], scope, use);
221  }
222 
224  : elementType(std::get<0>(key)),
225  shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
226  use(std::get<4>(key)) {}
227 
229  // [#rows, #columns]
230  std::array<int64_t, 2> shape;
231  Scope scope;
232  CooperativeMatrixUseKHR use;
233 };
234 
236  uint32_t rows,
237  uint32_t columns, Scope scope,
238  CooperativeMatrixUseKHR use) {
239  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
240  use);
241 }
242 
244  return getImpl()->elementType;
245 }
246 
248  assert(getImpl()->shape[0] != ShapedType::kDynamic);
249  return static_cast<uint32_t>(getImpl()->shape[0]);
250 }
251 
253  assert(getImpl()->shape[1] != ShapedType::kDynamic);
254  return static_cast<uint32_t>(getImpl()->shape[1]);
255 }
256 
258  return getImpl()->shape;
259 }
260 
261 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
262 
263 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
264  return getImpl()->use;
265 }
266 
269  std::optional<StorageClass> storage) {
270  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
271  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
272  extensions.push_back(exts);
273 }
274 
277  std::optional<StorageClass> storage) {
278  llvm::cast<SPIRVType>(getElementType())
279  .getCapabilities(capabilities, storage);
280  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
281  capabilities.push_back(caps);
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // ImageType
286 //===----------------------------------------------------------------------===//
287 
288 template <typename T>
289 static constexpr unsigned getNumBits() {
290  return 0;
291 }
292 template <>
293 constexpr unsigned getNumBits<Dim>() {
294  static_assert((1 << 3) > getMaxEnumValForDim(),
295  "Not enough bits to encode Dim value");
296  return 3;
297 }
298 template <>
299 constexpr unsigned getNumBits<ImageDepthInfo>() {
300  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
301  "Not enough bits to encode ImageDepthInfo value");
302  return 2;
303 }
304 template <>
305 constexpr unsigned getNumBits<ImageArrayedInfo>() {
306  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
307  "Not enough bits to encode ImageArrayedInfo value");
308  return 1;
309 }
310 template <>
311 constexpr unsigned getNumBits<ImageSamplingInfo>() {
312  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
313  "Not enough bits to encode ImageSamplingInfo value");
314  return 1;
315 }
316 template <>
317 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
318  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
319  "Not enough bits to encode ImageSamplerUseInfo value");
320  return 2;
321 }
322 template <>
323 constexpr unsigned getNumBits<ImageFormat>() {
324  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
325  "Not enough bits to encode ImageFormat value");
326  return 6;
327 }
328 
330 public:
331  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
332  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
333 
335  const KeyTy &key) {
336  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
337  }
338 
339  bool operator==(const KeyTy &key) const {
340  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
341  samplerUseInfo, format);
342  }
343 
345  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
346  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
347  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
348  format(std::get<6>(key)) {}
349 
357 };
358 
359 ImageType
360 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
361  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
362  value) {
363  return Base::get(std::get<0>(value).getContext(), value);
364 }
365 
366 Type ImageType::getElementType() const { return getImpl()->elementType; }
367 
368 Dim ImageType::getDim() const { return getImpl()->dim; }
369 
370 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
371 
372 ImageArrayedInfo ImageType::getArrayedInfo() const {
373  return getImpl()->arrayedInfo;
374 }
375 
376 ImageSamplingInfo ImageType::getSamplingInfo() const {
377  return getImpl()->samplingInfo;
378 }
379 
380 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
381  return getImpl()->samplerUseInfo;
382 }
383 
384 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
385 
387  std::optional<StorageClass>) {
388  // Image types do not require extra extensions thus far.
389 }
390 
393  std::optional<StorageClass>) {
394  if (auto dimCaps = spirv::getCapabilities(getDim()))
395  capabilities.push_back(*dimCaps);
396 
397  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
398  capabilities.push_back(*fmtCaps);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // PointerType
403 //===----------------------------------------------------------------------===//
404 
  // (Type, StorageClass) as the key: both the pointee Type and the
  // StorageClass are stored as members of this struct.
408  using KeyTy = std::pair<Type, StorageClass>;
409 
411  const KeyTy &key) {
412  return new (allocator.allocate<PointerTypeStorage>())
413  PointerTypeStorage(key);
414  }
415 
416  bool operator==(const KeyTy &key) const {
417  return key == KeyTy(pointeeType, storageClass);
418  }
419 
421  : pointeeType(key.first), storageClass(key.second) {}
422 
424  StorageClass storageClass;
425 };
426 
427 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
428  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
429 }
430 
431 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
432 
433 StorageClass PointerType::getStorageClass() const {
434  return getImpl()->storageClass;
435 }
436 
438  std::optional<StorageClass> storage) {
439  // Use this pointer type's storage class because this pointer indicates we are
440  // using the pointee type in that specific storage class.
441  llvm::cast<SPIRVType>(getPointeeType())
442  .getExtensions(extensions, getStorageClass());
443 
444  if (auto scExts = spirv::getExtensions(getStorageClass()))
445  extensions.push_back(*scExts);
446 }
447 
450  std::optional<StorageClass> storage) {
451  // Use this pointer type's storage class because this pointer indicates we are
452  // using the pointee type in that specific storage class.
453  llvm::cast<SPIRVType>(getPointeeType())
454  .getCapabilities(capabilities, getStorageClass());
455 
456  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
457  capabilities.push_back(*scCaps);
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // RuntimeArrayType
462 //===----------------------------------------------------------------------===//
463 
465  using KeyTy = std::pair<Type, unsigned>;
466 
468  const KeyTy &key) {
469  return new (allocator.allocate<RuntimeArrayTypeStorage>())
471  }
472 
473  bool operator==(const KeyTy &key) const {
474  return key == KeyTy(elementType, stride);
475  }
476 
478  : elementType(key.first), stride(key.second) {}
479 
481  unsigned stride;
482 };
483 
485  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
486 }
487 
488 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
489  return Base::get(elementType.getContext(), elementType, stride);
490 }
491 
492 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
493 
494 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
495 
498  std::optional<StorageClass> storage) {
499  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
500 }
501 
504  std::optional<StorageClass> storage) {
505  {
506  static const Capability caps[] = {Capability::Shader};
507  ArrayRef<Capability> ref(caps, std::size(caps));
508  capabilities.push_back(ref);
509  }
510  llvm::cast<SPIRVType>(getElementType())
511  .getCapabilities(capabilities, storage);
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // ScalarType
516 //===----------------------------------------------------------------------===//
517 
519  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
520  return isValid(floatType);
521  }
522  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
523  return isValid(intType);
524  }
525  return false;
526 }
527 
528 bool ScalarType::isValid(FloatType type) {
529  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
530 }
531 
532 bool ScalarType::isValid(IntegerType type) {
533  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
534 }
535 
537  std::optional<StorageClass> storage) {
538  // 8- or 16-bit integer/floating-point numbers will require extra extensions
539  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
540  // SPV_KHR_8bit_storage for more details.
541  if (!storage)
542  return;
543 
544  switch (*storage) {
545  case StorageClass::PushConstant:
546  case StorageClass::StorageBuffer:
547  case StorageClass::Uniform:
548  if (getIntOrFloatBitWidth() == 8) {
549  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
550  ArrayRef<Extension> ref(exts, std::size(exts));
551  extensions.push_back(ref);
552  }
553  [[fallthrough]];
554  case StorageClass::Input:
555  case StorageClass::Output:
556  if (getIntOrFloatBitWidth() == 16) {
557  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
558  ArrayRef<Extension> ref(exts, std::size(exts));
559  extensions.push_back(ref);
560  }
561  break;
562  default:
563  break;
564  }
565 }
566 
569  std::optional<StorageClass> storage) {
570  unsigned bitwidth = getIntOrFloatBitWidth();
571 
572  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
573  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
574  // SPV_KHR_8bit_storage for more details.
575 
576 #define STORAGE_CASE(storage, cap8, cap16) \
577  case StorageClass::storage: { \
578  if (bitwidth == 8) { \
579  static const Capability caps[] = {Capability::cap8}; \
580  ArrayRef<Capability> ref(caps, std::size(caps)); \
581  capabilities.push_back(ref); \
582  return; \
583  } \
584  if (bitwidth == 16) { \
585  static const Capability caps[] = {Capability::cap16}; \
586  ArrayRef<Capability> ref(caps, std::size(caps)); \
587  capabilities.push_back(ref); \
588  return; \
589  } \
590  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
591  /* storage classes. Fall through to the next section. */ \
592  } break
593 
  // This part only handles the cases where special bitwidths appear in
  // interface storage classes.
596  if (storage) {
597  switch (*storage) {
598  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
599  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
600  StorageBuffer16BitAccess);
601  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
602  StorageUniform16);
603  case StorageClass::Input:
604  case StorageClass::Output: {
605  if (bitwidth == 16) {
606  static const Capability caps[] = {Capability::StorageInputOutput16};
607  ArrayRef<Capability> ref(caps, std::size(caps));
608  capabilities.push_back(ref);
609  return;
610  }
611  break;
612  }
613  default:
614  break;
615  }
616  }
617 #undef STORAGE_CASE
618 
619  // For other non-interface storage classes, require a different set of
620  // capabilities for special bitwidths.
621 
622 #define WIDTH_CASE(type, width) \
623  case width: { \
624  static const Capability caps[] = {Capability::type##width}; \
625  ArrayRef<Capability> ref(caps, std::size(caps)); \
626  capabilities.push_back(ref); \
627  } break
628 
629  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
630  switch (bitwidth) {
631  WIDTH_CASE(Int, 8);
632  WIDTH_CASE(Int, 16);
633  WIDTH_CASE(Int, 64);
634  case 1:
635  case 32:
636  break;
637  default:
638  llvm_unreachable("invalid bitwidth to getCapabilities");
639  }
640  } else {
641  assert(llvm::isa<FloatType>(*this));
642  switch (bitwidth) {
643  WIDTH_CASE(Float, 16);
644  WIDTH_CASE(Float, 64);
645  case 32:
646  break;
647  default:
648  llvm_unreachable("invalid bitwidth to getCapabilities");
649  }
650  }
651 
652 #undef WIDTH_CASE
653 }
654 
655 std::optional<int64_t> ScalarType::getSizeInBytes() {
656  auto bitWidth = getIntOrFloatBitWidth();
657  // According to the SPIR-V spec:
658  // "There is no physical size or bit pattern defined for values with boolean
659  // type. If they are stored (in conjunction with OpVariable), they can only
660  // be used with logical addressing operations, not physical, and only with
661  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
662  // Private, Function, Input, and Output."
663  if (bitWidth == 1)
664  return std::nullopt;
665  return bitWidth / 8;
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // SPIRVType
670 //===----------------------------------------------------------------------===//
671 
673  // Allow SPIR-V dialect types
674  if (llvm::isa<SPIRVDialect>(type.getDialect()))
675  return true;
676  if (llvm::isa<ScalarType>(type))
677  return true;
678  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
679  return CompositeType::isValid(vectorType);
680  return false;
681 }
682 
684  return isIntOrFloat() || llvm::isa<VectorType>(*this);
685 }
686 
688  std::optional<StorageClass> storage) {
689  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
690  scalarType.getExtensions(extensions, storage);
691  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
692  compositeType.getExtensions(extensions, storage);
693  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
694  imageType.getExtensions(extensions, storage);
695  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
696  sampledImageType.getExtensions(extensions, storage);
697  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
698  matrixType.getExtensions(extensions, storage);
699  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
700  ptrType.getExtensions(extensions, storage);
701  } else {
702  llvm_unreachable("invalid SPIR-V Type to getExtensions");
703  }
704 }
705 
708  std::optional<StorageClass> storage) {
709  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
710  scalarType.getCapabilities(capabilities, storage);
711  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
712  compositeType.getCapabilities(capabilities, storage);
713  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
714  imageType.getCapabilities(capabilities, storage);
715  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
716  sampledImageType.getCapabilities(capabilities, storage);
717  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
718  matrixType.getCapabilities(capabilities, storage);
719  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
720  ptrType.getCapabilities(capabilities, storage);
721  } else {
722  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
723  }
724 }
725 
726 std::optional<int64_t> SPIRVType::getSizeInBytes() {
727  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
728  return scalarType.getSizeInBytes();
729  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
730  return compositeType.getSizeInBytes();
731  return std::nullopt;
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // SampledImageType
736 //===----------------------------------------------------------------------===//
738  using KeyTy = Type;
739 
740  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
741 
742  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
743 
745  const KeyTy &key) {
746  return new (allocator.allocate<SampledImageTypeStorage>())
748  }
749 
751 };
752 
754  return Base::get(imageType.getContext(), imageType);
755 }
756 
759  Type imageType) {
760  return Base::getChecked(emitError, imageType.getContext(), imageType);
761 }
762 
763 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
764 
765 LogicalResult
767  Type imageType) {
768  if (!llvm::isa<ImageType>(imageType))
769  return emitError() << "expected image type";
770 
771  return success();
772 }
773 
776  std::optional<StorageClass> storage) {
777  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
778 }
779 
782  std::optional<StorageClass> storage) {
783  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // StructType
788 //===----------------------------------------------------------------------===//
789 
790 /// Type storage for SPIR-V structure types:
791 ///
792 /// Structures are uniqued using:
793 /// - for identified structs:
794 /// - a string identifier;
795 /// - for literal structs:
796 /// - a list of member types;
797 /// - a list of member offset info;
798 /// - a list of member decoration info.
799 ///
800 /// Identified structures only have a mutable component consisting of:
801 /// - a list of member types;
802 /// - a list of member offset info;
803 /// - a list of member decoration info.
805  /// Construct a storage object for an identified struct type. A struct type
806  /// associated with such storage must call StructType::trySetBody(...) later
807  /// in order to mutate the storage object providing the actual content.
808  StructTypeStorage(StringRef identifier)
809  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
810  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
811  identifier(identifier) {}
812 
813  /// Construct a storage object for a literal struct type. A struct type
814  /// associated with such storage is immutable.
816  unsigned numMembers, Type const *memberTypes,
817  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
818  StructType::MemberDecorationInfo const *memberDecorationsInfo)
819  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
820  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
821  memberDecorationsInfo(memberDecorationsInfo) {}
822 
823  /// A storage key is divided into 2 parts:
824  /// - for identified structs:
825  /// - a StringRef representing the struct identifier;
826  /// - for literal structs:
827  /// - an ArrayRef<Type> for member types;
828  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
829  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
830  /// info.
831  ///
832  /// An identified struct type is uniqued only by the first part (field 0)
833  /// of the key.
834  ///
835  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
836  /// 3) of the key. The identifier field (field 0) must be empty.
837  using KeyTy =
838  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
840 
841  /// For identified structs, return true if the given key contains the same
842  /// identifier.
843  ///
844  /// For literal structs, return true if the given key contains a matching list
845  /// of member types + offset info + decoration info.
846  bool operator==(const KeyTy &key) const {
847  if (isIdentified()) {
848  // Identified types are uniqued by their identifier.
849  return getIdentifier() == std::get<0>(key);
850  }
851 
852  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
853  getMemberDecorationsInfo());
854  }
855 
856  /// If the given key contains a non-empty identifier, this method constructs
857  /// an identified struct and leaves the rest of the struct type data to be set
858  /// through a later call to StructType::trySetBody(...).
859  ///
860  /// If, on the other hand, the key contains an empty identifier, a literal
861  /// struct is constructed using the other fields of the key.
863  const KeyTy &key) {
864  StringRef keyIdentifier = std::get<0>(key);
865 
866  if (!keyIdentifier.empty()) {
867  StringRef identifier = allocator.copyInto(keyIdentifier);
868 
869  // Identified StructType body/members will be set through trySetBody(...)
870  // later.
871  return new (allocator.allocate<StructTypeStorage>())
872  StructTypeStorage(identifier);
873  }
874 
875  ArrayRef<Type> keyTypes = std::get<1>(key);
876 
877  // Copy the member type and layout information into the bump pointer
878  const Type *typesList = nullptr;
879  if (!keyTypes.empty()) {
880  typesList = allocator.copyInto(keyTypes).data();
881  }
882 
883  const StructType::OffsetInfo *offsetInfoList = nullptr;
884  if (!std::get<2>(key).empty()) {
885  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
886  assert(keyOffsetInfo.size() == keyTypes.size() &&
887  "size of offset information must be same as the size of number of "
888  "elements");
889  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
890  }
891 
892  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
893  unsigned numMemberDecorations = 0;
894  if (!std::get<3>(key).empty()) {
895  auto keyMemberDecorations = std::get<3>(key);
896  numMemberDecorations = keyMemberDecorations.size();
897  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
898  }
899 
900  return new (allocator.allocate<StructTypeStorage>())
901  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
902  numMemberDecorations, memberDecorationList);
903  }
904 
906  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
907  }
908 
910  if (offsetInfo) {
911  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
912  }
913  return {};
914  }
915 
917  if (memberDecorationsInfo) {
918  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
919  numMemberDecorations);
920  }
921  return {};
922  }
923 
924  StringRef getIdentifier() const { return identifier; }
925 
926  bool isIdentified() const { return !identifier.empty(); }
927 
928  /// Sets the struct type content for identified structs. Calling this method
929  /// is only valid for identified structs.
930  ///
931  /// Fails under the following conditions:
932  /// - If called for a literal struct;
933  /// - If called for an identified struct whose body was set before (through a
934  /// call to this method) but with different contents from the passed
935  /// arguments.
936  LogicalResult mutate(
937  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
938  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
939  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
940  if (!isIdentified())
941  return failure();
942 
943  if (memberTypesAndIsBodySet.getInt() &&
944  (getMemberTypes() != structMemberTypes ||
945  getOffsetInfo() != structOffsetInfo ||
946  getMemberDecorationsInfo() != structMemberDecorationInfo))
947  return failure();
948 
949  memberTypesAndIsBodySet.setInt(true);
950  numMembers = structMemberTypes.size();
951 
952  // Copy the member type and layout information into the bump pointer.
953  if (!structMemberTypes.empty())
954  memberTypesAndIsBodySet.setPointer(
955  allocator.copyInto(structMemberTypes).data());
956 
957  if (!structOffsetInfo.empty()) {
958  assert(structOffsetInfo.size() == structMemberTypes.size() &&
959  "size of offset information must be same as the size of number of "
960  "elements");
961  offsetInfo = allocator.copyInto(structOffsetInfo).data();
962  }
963 
964  if (!structMemberDecorationInfo.empty()) {
965  numMemberDecorations = structMemberDecorationInfo.size();
966  memberDecorationsInfo =
967  allocator.copyInto(structMemberDecorationInfo).data();
968  }
969 
970  return success();
971  }
972 
973  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
975  unsigned numMembers;
978  StringRef identifier;
979 };
980 
984  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
985  assert(!memberTypes.empty() && "Struct needs at least one member type");
986  // Sort the decorations.
988  memberDecorations);
989  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
990  return Base::get(memberTypes.vec().front().getContext(),
991  /*identifier=*/StringRef(), memberTypes, offsetInfo,
992  sortedDecorations);
993 }
994 
996  StringRef identifier) {
997  assert(!identifier.empty() &&
998  "StructType identifier must be non-empty string");
999 
1000  return Base::get(context, identifier, ArrayRef<Type>(),
1003 }
1004 
1005 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1006  StructType newStructType = Base::get(
1007  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1009  // Set an empty body in case this is a identified struct.
1010  if (newStructType.isIdentified() &&
1011  failed(newStructType.trySetBody(
1014  return StructType();
1015 
1016  return newStructType;
1017 }
1018 
1019 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1020 
1021 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1022 
1023 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1024 
1025 Type StructType::getElementType(unsigned index) const {
1026  assert(getNumElements() > index && "member index out of range");
1027  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1028 }
1029 
1031  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1032  getNumElements());
1033 }
1034 
1035 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1036 
1037 uint64_t StructType::getMemberOffset(unsigned index) const {
1038  assert(getNumElements() > index && "member index out of range");
1039  return getImpl()->offsetInfo[index];
1040 }
1041 
1044  const {
1045  memberDecorations.clear();
1046  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1047  memberDecorations.append(implMemberDecorations.begin(),
1048  implMemberDecorations.end());
1049 }
1050 
1052  unsigned index,
1053  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1054  assert(getNumElements() > index && "member index out of range");
1055  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1056  decorationsInfo.clear();
1057  for (const auto &memberDecoration : memberDecorations) {
1058  if (memberDecoration.memberIndex == index) {
1059  decorationsInfo.push_back(memberDecoration);
1060  }
1061  if (memberDecoration.memberIndex > index) {
1062  // Early exit since the decorations are stored sorted.
1063  return;
1064  }
1065  }
1066 }
1067 
1068 LogicalResult
1070  ArrayRef<OffsetInfo> offsetInfo,
1071  ArrayRef<MemberDecorationInfo> memberDecorations) {
1072  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1073 }
1074 
1076  std::optional<StorageClass> storage) {
1077  for (Type elementType : getElementTypes())
1078  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1079 }
1080 
1083  std::optional<StorageClass> storage) {
1084  for (Type elementType : getElementTypes())
1085  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1086 }
1087 
1088 llvm::hash_code spirv::hash_value(
1089  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1090  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1091  memberDecorationInfo.decoration);
1092 }
1093 
1094 //===----------------------------------------------------------------------===//
1095 // MatrixType
1096 //===----------------------------------------------------------------------===//
1097 
1099  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1100  : columnType(columnType), columnCount(columnCount) {}
1101 
1102  using KeyTy = std::tuple<Type, uint32_t>;
1103 
1105  const KeyTy &key) {
1106 
1107  // Initialize the memory using placement new.
1108  return new (allocator.allocate<MatrixTypeStorage>())
1109  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1110  }
1111 
1112  bool operator==(const KeyTy &key) const {
1113  return key == KeyTy(columnType, columnCount);
1114  }
1115 
1117  const uint32_t columnCount;
1118 };
1119 
1120 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1121  return Base::get(columnType.getContext(), columnType, columnCount);
1122 }
1123 
1125  Type columnType, uint32_t columnCount) {
1126  return Base::getChecked(emitError, columnType.getContext(), columnType,
1127  columnCount);
1128 }
1129 
1130 LogicalResult
1132  Type columnType, uint32_t columnCount) {
1133  if (columnCount < 2 || columnCount > 4)
1134  return emitError() << "matrix can have 2, 3, or 4 columns only";
1135 
1136  if (!isValidColumnType(columnType))
1137  return emitError() << "matrix columns must be vectors of floats";
1138 
1139  /// The underlying vectors (columns) must be of size 2, 3, or 4
1140  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1141  if (columnShape.size() != 1)
1142  return emitError() << "matrix columns must be 1D vectors";
1143 
1144  if (columnShape[0] < 2 || columnShape[0] > 4)
1145  return emitError() << "matrix columns must be of size 2, 3, or 4";
1146 
1147  return success();
1148 }
1149 
1150 /// Returns true if the matrix elements are vectors of float elements
1152  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1153  if (llvm::isa<FloatType>(vectorType.getElementType()))
1154  return true;
1155  }
1156  return false;
1157 }
1158 
1159 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1160 
1162  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1163 }
1164 
1165 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1166 
1167 unsigned MatrixType::getNumRows() const {
1168  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1169 }
1170 
1171 unsigned MatrixType::getNumElements() const {
1172  return (getImpl()->columnCount) * getNumRows();
1173 }
1174 
1176  std::optional<StorageClass> storage) {
1177  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1178 }
1179 
1182  std::optional<StorageClass> storage) {
1183  {
1184  static const Capability caps[] = {Capability::Matrix};
1185  ArrayRef<Capability> ref(caps, std::size(caps));
1186  capabilities.push_back(ref);
1187  }
1188  // Add any capabilities associated with the underlying vectors (i.e., columns)
1189  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1190 }
1191 
1192 //===----------------------------------------------------------------------===//
1193 // SPIR-V Dialect
1194 //===----------------------------------------------------------------------===//
1195 
1196 void SPIRVDialect::registerTypes() {
1199 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:317
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:323
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:289
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:305
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:311
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:293
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:299
int64_t rows
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:82
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:70
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:75
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:157
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:177
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:139
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:143
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:119
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:108
static bool classof(Type type)
Definition: SPIRVTypes.cpp:94
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:261
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:247
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:252
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:235
ArrayRef< int64_t > getShape() const
Definition: SPIRVTypes.cpp:257
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:267
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:275
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:263
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:370
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:372
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:384
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:380
Type getElementType() const
Definition: SPIRVTypes.cpp:366
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:376
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:391
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:386
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:437
Type getPointeeType() const
Definition: SPIRVTypes.cpp:431
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:433
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:448
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:502
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:494
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:496
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:484
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:726
static bool classof(Type type)
Definition: SPIRVTypes.cpp:672
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:706
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:687
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:766
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:780
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:758
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:774
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:753
static bool classof(Type type)
Definition: SPIRVTypes.cpp:518
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:536
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:528
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:567
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:655
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:995
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:982
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:32
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:34
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:39
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:211
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:214
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:339
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:334
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:332
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:416
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:408
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:410
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:467
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:473
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:742
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:744
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:804
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:909
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:815
StructType::OffsetInfo const * offsetInfo
Definition: SPIRVTypes.cpp:974
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:846
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:936
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:862
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:905
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:808
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:916
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:839
StructType::MemberDecorationInfo const * memberDecorationsInfo
Definition: SPIRVTypes.cpp:977
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
Definition: SPIRVTypes.cpp:973