MLIR  19.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 #include <cstdint>
22 #include <iterator>
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 //===----------------------------------------------------------------------===//
28 // ArrayType
29 //===----------------------------------------------------------------------===//
30 
32  using KeyTy = std::tuple<Type, unsigned, unsigned>;
33 
35  const KeyTy &key) {
36  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37  }
38 
39  bool operator==(const KeyTy &key) const {
40  return key == KeyTy(elementType, elementCount, stride);
41  }
42 
43  ArrayTypeStorage(const KeyTy &key)
44  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45  stride(std::get<2>(key)) {}
46 
48  unsigned elementCount;
49  unsigned stride;
50 };
51 
52 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53  assert(elementCount && "ArrayType needs at least one element");
54  return Base::get(elementType.getContext(), elementType, elementCount,
55  /*stride=*/0);
56 }
57 
58 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59  unsigned stride) {
60  assert(elementCount && "ArrayType needs at least one element");
61  return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63 
64 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65 
66 Type ArrayType::getElementType() const { return getImpl()->elementType; }
67 
68 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69 
71  std::optional<StorageClass> storage) {
72  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
73 }
74 
77  std::optional<StorageClass> storage) {
78  llvm::cast<SPIRVType>(getElementType())
79  .getCapabilities(capabilities, storage);
80 }
81 
82 std::optional<int64_t> ArrayType::getSizeInBytes() {
83  auto elementType = llvm::cast<SPIRVType>(getElementType());
84  std::optional<int64_t> size = elementType.getSizeInBytes();
85  if (!size)
86  return std::nullopt;
87  return (*size + getArrayStride()) * getNumElements();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // CompositeType
92 //===----------------------------------------------------------------------===//
93 
95  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96  return isValid(vectorType);
100 }
101 
102 bool CompositeType::isValid(VectorType type) {
103  return type.getRank() == 1 &&
104  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105  llvm::isa<ScalarType>(type.getElementType());
106 }
107 
108 Type CompositeType::getElementType(unsigned index) const {
109  return TypeSwitch<Type, Type>(*this)
111  RuntimeArrayType, VectorType>(
112  [](auto type) { return type.getElementType(); })
113  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
114  .Case<StructType>(
115  [index](StructType type) { return type.getElementType(index); })
116  .Default(
117  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
118 }
119 
121  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
122  return arrayType.getNumElements();
123  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
124  return matrixType.getNumColumns();
125  if (auto structType = llvm::dyn_cast<StructType>(*this))
126  return structType.getNumElements();
127  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
128  return vectorType.getNumElements();
129  if (llvm::isa<CooperativeMatrixType>(*this)) {
130  llvm_unreachable(
131  "invalid to query number of elements of spirv Cooperative Matrix type");
132  }
133  if (llvm::isa<JointMatrixINTELType>(*this)) {
134  llvm_unreachable(
135  "invalid to query number of elements of spirv::JointMatrix type");
136  }
137  if (llvm::isa<RuntimeArrayType>(*this)) {
138  llvm_unreachable(
139  "invalid to query number of elements of spirv::RuntimeArray type");
140  }
141  llvm_unreachable("invalid composite type");
142 }
143 
145  return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
146  RuntimeArrayType>(*this);
147 }
148 
151  std::optional<StorageClass> storage) {
152  TypeSwitch<Type>(*this)
155  [&](auto type) { type.getExtensions(extensions, storage); })
156  .Case<VectorType>([&](VectorType type) {
157  return llvm::cast<ScalarType>(type.getElementType())
158  .getExtensions(extensions, storage);
159  })
160  .Default([](Type) { llvm_unreachable("invalid composite type"); });
161 }
162 
165  std::optional<StorageClass> storage) {
166  TypeSwitch<Type>(*this)
169  [&](auto type) { type.getCapabilities(capabilities, storage); })
170  .Case<VectorType>([&](VectorType type) {
171  auto vecSize = getNumElements();
172  if (vecSize == 8 || vecSize == 16) {
173  static const Capability caps[] = {Capability::Vector16};
174  ArrayRef<Capability> ref(caps, std::size(caps));
175  capabilities.push_back(ref);
176  }
177  return llvm::cast<ScalarType>(type.getElementType())
178  .getCapabilities(capabilities, storage);
179  })
180  .Default([](Type) { llvm_unreachable("invalid composite type"); });
181 }
182 
183 std::optional<int64_t> CompositeType::getSizeInBytes() {
184  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
185  return arrayType.getSizeInBytes();
186  if (auto structType = llvm::dyn_cast<StructType>(*this))
187  return structType.getSizeInBytes();
188  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
189  std::optional<int64_t> elementSize =
190  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
191  if (!elementSize)
192  return std::nullopt;
193  return *elementSize * vectorType.getNumElements();
194  }
195  return std::nullopt;
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // CooperativeMatrixType
200 //===----------------------------------------------------------------------===//
201 
203  using KeyTy =
204  std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
205 
207  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
208  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
210  }
211 
212  bool operator==(const KeyTy &key) const {
213  return key == KeyTy(elementType, rows, columns, scope, use);
214  }
215 
217  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
218  columns(std::get<2>(key)), scope(std::get<3>(key)),
219  use(std::get<4>(key)) {}
220 
222  uint32_t rows;
223  uint32_t columns;
224  Scope scope;
225  CooperativeMatrixUseKHR use;
226 };
227 
229  uint32_t rows,
230  uint32_t columns, Scope scope,
231  CooperativeMatrixUseKHR use) {
232  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
233  use);
234 }
235 
237  return getImpl()->elementType;
238 }
239 
240 uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
241 
243  return getImpl()->columns;
244 }
245 
246 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
247 
248 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
249  return getImpl()->use;
250 }
251 
254  std::optional<StorageClass> storage) {
255  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
256  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
257  extensions.push_back(exts);
258 }
259 
262  std::optional<StorageClass> storage) {
263  llvm::cast<SPIRVType>(getElementType())
264  .getCapabilities(capabilities, storage);
265  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
266  capabilities.push_back(caps);
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // JointMatrixType
271 //===----------------------------------------------------------------------===//
272 
274  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
275 
277  const KeyTy &key) {
278  return new (allocator.allocate<JointMatrixTypeStorage>())
280  }
281 
282  bool operator==(const KeyTy &key) const {
283  return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
284  }
285 
287  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
288  columns(std::get<2>(key)), scope(std::get<4>(key)),
289  matrixLayout(std::get<3>(key)) {}
290 
292  unsigned rows;
293  unsigned columns;
294  Scope scope;
295  MatrixLayout matrixLayout;
296 };
297 
299  unsigned rows, unsigned columns,
300  MatrixLayout matrixLayout) {
301  return Base::get(elementType.getContext(), elementType, rows, columns,
302  matrixLayout, scope);
303 }
304 
306  return getImpl()->elementType;
307 }
308 
309 Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
310 
311 unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
312 
313 unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
314 
316  return getImpl()->matrixLayout;
317 }
318 
321  std::optional<StorageClass> storage) {
322  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
323  static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
324  ArrayRef<Extension> ref(exts, std::size(exts));
325  extensions.push_back(ref);
326 }
327 
330  std::optional<StorageClass> storage) {
331  llvm::cast<SPIRVType>(getElementType())
332  .getCapabilities(capabilities, storage);
333  static const Capability caps[] = {Capability::JointMatrixINTEL};
334  ArrayRef<Capability> ref(caps, std::size(caps));
335  capabilities.push_back(ref);
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // ImageType
340 //===----------------------------------------------------------------------===//
341 
342 template <typename T>
343 static constexpr unsigned getNumBits() {
344  return 0;
345 }
346 template <>
347 constexpr unsigned getNumBits<Dim>() {
348  static_assert((1 << 3) > getMaxEnumValForDim(),
349  "Not enough bits to encode Dim value");
350  return 3;
351 }
352 template <>
353 constexpr unsigned getNumBits<ImageDepthInfo>() {
354  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
355  "Not enough bits to encode ImageDepthInfo value");
356  return 2;
357 }
358 template <>
359 constexpr unsigned getNumBits<ImageArrayedInfo>() {
360  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
361  "Not enough bits to encode ImageArrayedInfo value");
362  return 1;
363 }
364 template <>
365 constexpr unsigned getNumBits<ImageSamplingInfo>() {
366  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
367  "Not enough bits to encode ImageSamplingInfo value");
368  return 1;
369 }
370 template <>
371 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
372  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
373  "Not enough bits to encode ImageSamplerUseInfo value");
374  return 2;
375 }
376 template <>
377 constexpr unsigned getNumBits<ImageFormat>() {
378  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
379  "Not enough bits to encode ImageFormat value");
380  return 6;
381 }
382 
384 public:
385  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
386  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
387 
389  const KeyTy &key) {
390  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
391  }
392 
393  bool operator==(const KeyTy &key) const {
394  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
395  samplerUseInfo, format);
396  }
397 
399  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
400  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
401  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
402  format(std::get<6>(key)) {}
403 
411 };
412 
413 ImageType
414 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
415  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
416  value) {
417  return Base::get(std::get<0>(value).getContext(), value);
418 }
419 
420 Type ImageType::getElementType() const { return getImpl()->elementType; }
421 
422 Dim ImageType::getDim() const { return getImpl()->dim; }
423 
424 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
425 
426 ImageArrayedInfo ImageType::getArrayedInfo() const {
427  return getImpl()->arrayedInfo;
428 }
429 
430 ImageSamplingInfo ImageType::getSamplingInfo() const {
431  return getImpl()->samplingInfo;
432 }
433 
434 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
435  return getImpl()->samplerUseInfo;
436 }
437 
438 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
439 
441  std::optional<StorageClass>) {
442  // Image types do not require extra extensions thus far.
443 }
444 
447  std::optional<StorageClass>) {
448  if (auto dimCaps = spirv::getCapabilities(getDim()))
449  capabilities.push_back(*dimCaps);
450 
451  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
452  capabilities.push_back(*fmtCaps);
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // PointerType
457 //===----------------------------------------------------------------------===//
458 
460  // (Type, StorageClass) as the key: Type stored in this struct, and
461  // StorageClass stored as TypeStorage's subclass data.
462  using KeyTy = std::pair<Type, StorageClass>;
463 
465  const KeyTy &key) {
466  return new (allocator.allocate<PointerTypeStorage>())
467  PointerTypeStorage(key);
468  }
469 
470  bool operator==(const KeyTy &key) const {
471  return key == KeyTy(pointeeType, storageClass);
472  }
473 
475  : pointeeType(key.first), storageClass(key.second) {}
476 
478  StorageClass storageClass;
479 };
480 
481 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
482  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
483 }
484 
485 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
486 
487 StorageClass PointerType::getStorageClass() const {
488  return getImpl()->storageClass;
489 }
490 
492  std::optional<StorageClass> storage) {
493  // Use this pointer type's storage class because this pointer indicates we are
494  // using the pointee type in that specific storage class.
495  llvm::cast<SPIRVType>(getPointeeType())
496  .getExtensions(extensions, getStorageClass());
497 
498  if (auto scExts = spirv::getExtensions(getStorageClass()))
499  extensions.push_back(*scExts);
500 }
501 
504  std::optional<StorageClass> storage) {
505  // Use this pointer type's storage class because this pointer indicates we are
506  // using the pointee type in that specific storage class.
507  llvm::cast<SPIRVType>(getPointeeType())
508  .getCapabilities(capabilities, getStorageClass());
509 
510  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
511  capabilities.push_back(*scCaps);
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // RuntimeArrayType
516 //===----------------------------------------------------------------------===//
517 
519  using KeyTy = std::pair<Type, unsigned>;
520 
522  const KeyTy &key) {
523  return new (allocator.allocate<RuntimeArrayTypeStorage>())
525  }
526 
527  bool operator==(const KeyTy &key) const {
528  return key == KeyTy(elementType, stride);
529  }
530 
532  : elementType(key.first), stride(key.second) {}
533 
535  unsigned stride;
536 };
537 
539  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
540 }
541 
542 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
543  return Base::get(elementType.getContext(), elementType, stride);
544 }
545 
546 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
547 
548 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
549 
552  std::optional<StorageClass> storage) {
553  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
554 }
555 
558  std::optional<StorageClass> storage) {
559  {
560  static const Capability caps[] = {Capability::Shader};
561  ArrayRef<Capability> ref(caps, std::size(caps));
562  capabilities.push_back(ref);
563  }
564  llvm::cast<SPIRVType>(getElementType())
565  .getCapabilities(capabilities, storage);
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // ScalarType
570 //===----------------------------------------------------------------------===//
571 
573  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
574  return isValid(floatType);
575  }
576  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
577  return isValid(intType);
578  }
579  return false;
580 }
581 
583  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
584 }
585 
586 bool ScalarType::isValid(IntegerType type) {
587  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
588 }
589 
591  std::optional<StorageClass> storage) {
592  // 8- or 16-bit integer/floating-point numbers will require extra extensions
593  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
594  // SPV_KHR_8bit_storage for more details.
595  if (!storage)
596  return;
597 
598  switch (*storage) {
599  case StorageClass::PushConstant:
600  case StorageClass::StorageBuffer:
601  case StorageClass::Uniform:
602  if (getIntOrFloatBitWidth() == 8) {
603  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
604  ArrayRef<Extension> ref(exts, std::size(exts));
605  extensions.push_back(ref);
606  }
607  [[fallthrough]];
608  case StorageClass::Input:
609  case StorageClass::Output:
610  if (getIntOrFloatBitWidth() == 16) {
611  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
612  ArrayRef<Extension> ref(exts, std::size(exts));
613  extensions.push_back(ref);
614  }
615  break;
616  default:
617  break;
618  }
619 }
620 
623  std::optional<StorageClass> storage) {
624  unsigned bitwidth = getIntOrFloatBitWidth();
625 
626  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
627  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
628  // SPV_KHR_8bit_storage for more details.
629 
630 #define STORAGE_CASE(storage, cap8, cap16) \
631  case StorageClass::storage: { \
632  if (bitwidth == 8) { \
633  static const Capability caps[] = {Capability::cap8}; \
634  ArrayRef<Capability> ref(caps, std::size(caps)); \
635  capabilities.push_back(ref); \
636  return; \
637  } \
638  if (bitwidth == 16) { \
639  static const Capability caps[] = {Capability::cap16}; \
640  ArrayRef<Capability> ref(caps, std::size(caps)); \
641  capabilities.push_back(ref); \
642  return; \
643  } \
644  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
645  /* storage classes. Fall through to the next section. */ \
646  } break
647 
648  // This part only handles the cases where special bitwidths appearing in
649  // interface storage classes.
650  if (storage) {
651  switch (*storage) {
652  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
653  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
654  StorageBuffer16BitAccess);
655  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
656  StorageUniform16);
657  case StorageClass::Input:
658  case StorageClass::Output: {
659  if (bitwidth == 16) {
660  static const Capability caps[] = {Capability::StorageInputOutput16};
661  ArrayRef<Capability> ref(caps, std::size(caps));
662  capabilities.push_back(ref);
663  return;
664  }
665  break;
666  }
667  default:
668  break;
669  }
670  }
671 #undef STORAGE_CASE
672 
673  // For other non-interface storage classes, require a different set of
674  // capabilities for special bitwidths.
675 
676 #define WIDTH_CASE(type, width) \
677  case width: { \
678  static const Capability caps[] = {Capability::type##width}; \
679  ArrayRef<Capability> ref(caps, std::size(caps)); \
680  capabilities.push_back(ref); \
681  } break
682 
683  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
684  switch (bitwidth) {
685  WIDTH_CASE(Int, 8);
686  WIDTH_CASE(Int, 16);
687  WIDTH_CASE(Int, 64);
688  case 1:
689  case 32:
690  break;
691  default:
692  llvm_unreachable("invalid bitwidth to getCapabilities");
693  }
694  } else {
695  assert(llvm::isa<FloatType>(*this));
696  switch (bitwidth) {
697  WIDTH_CASE(Float, 16);
698  WIDTH_CASE(Float, 64);
699  case 32:
700  break;
701  default:
702  llvm_unreachable("invalid bitwidth to getCapabilities");
703  }
704  }
705 
706 #undef WIDTH_CASE
707 }
708 
709 std::optional<int64_t> ScalarType::getSizeInBytes() {
710  auto bitWidth = getIntOrFloatBitWidth();
711  // According to the SPIR-V spec:
712  // "There is no physical size or bit pattern defined for values with boolean
713  // type. If they are stored (in conjunction with OpVariable), they can only
714  // be used with logical addressing operations, not physical, and only with
715  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
716  // Private, Function, Input, and Output."
717  if (bitWidth == 1)
718  return std::nullopt;
719  return bitWidth / 8;
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // SPIRVType
724 //===----------------------------------------------------------------------===//
725 
727  // Allow SPIR-V dialect types
728  if (llvm::isa<SPIRVDialect>(type.getDialect()))
729  return true;
730  if (llvm::isa<ScalarType>(type))
731  return true;
732  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
733  return CompositeType::isValid(vectorType);
734  return false;
735 }
736 
738  return isIntOrFloat() || llvm::isa<VectorType>(*this);
739 }
740 
742  std::optional<StorageClass> storage) {
743  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
744  scalarType.getExtensions(extensions, storage);
745  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
746  compositeType.getExtensions(extensions, storage);
747  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
748  imageType.getExtensions(extensions, storage);
749  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
750  sampledImageType.getExtensions(extensions, storage);
751  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
752  matrixType.getExtensions(extensions, storage);
753  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
754  ptrType.getExtensions(extensions, storage);
755  } else {
756  llvm_unreachable("invalid SPIR-V Type to getExtensions");
757  }
758 }
759 
762  std::optional<StorageClass> storage) {
763  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
764  scalarType.getCapabilities(capabilities, storage);
765  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
766  compositeType.getCapabilities(capabilities, storage);
767  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
768  imageType.getCapabilities(capabilities, storage);
769  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
770  sampledImageType.getCapabilities(capabilities, storage);
771  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
772  matrixType.getCapabilities(capabilities, storage);
773  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
774  ptrType.getCapabilities(capabilities, storage);
775  } else {
776  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
777  }
778 }
779 
780 std::optional<int64_t> SPIRVType::getSizeInBytes() {
781  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
782  return scalarType.getSizeInBytes();
783  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
784  return compositeType.getSizeInBytes();
785  return std::nullopt;
786 }
787 
788 //===----------------------------------------------------------------------===//
789 // SampledImageType
790 //===----------------------------------------------------------------------===//
792  using KeyTy = Type;
793 
794  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
795 
796  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
797 
799  const KeyTy &key) {
800  return new (allocator.allocate<SampledImageTypeStorage>())
802  }
803 
805 };
806 
808  return Base::get(imageType.getContext(), imageType);
809 }
810 
813  Type imageType) {
814  return Base::getChecked(emitError, imageType.getContext(), imageType);
815 }
816 
817 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
818 
821  Type imageType) {
822  if (!llvm::isa<ImageType>(imageType))
823  return emitError() << "expected image type";
824 
825  return success();
826 }
827 
830  std::optional<StorageClass> storage) {
831  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
832 }
833 
836  std::optional<StorageClass> storage) {
837  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // StructType
842 //===----------------------------------------------------------------------===//
843 
844 /// Type storage for SPIR-V structure types:
845 ///
846 /// Structures are uniqued using:
847 /// - for identified structs:
848 /// - a string identifier;
849 /// - for literal structs:
850 /// - a list of member types;
851 /// - a list of member offset info;
852 /// - a list of member decoration info.
853 ///
854 /// Identified structures only have a mutable component consisting of:
855 /// - a list of member types;
856 /// - a list of member offset info;
857 /// - a list of member decoration info.
859  /// Construct a storage object for an identified struct type. A struct type
860  /// associated with such storage must call StructType::trySetBody(...) later
861  /// in order to mutate the storage object providing the actual content.
862  StructTypeStorage(StringRef identifier)
863  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
864  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
865  identifier(identifier) {}
866 
867  /// Construct a storage object for a literal struct type. A struct type
868  /// associated with such storage is immutable.
870  unsigned numMembers, Type const *memberTypes,
871  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
872  StructType::MemberDecorationInfo const *memberDecorationsInfo)
873  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
874  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
875  memberDecorationsInfo(memberDecorationsInfo) {}
876 
877  /// A storage key is divided into 2 parts:
878  /// - for identified structs:
879  /// - a StringRef representing the struct identifier;
880  /// - for literal structs:
881  /// - an ArrayRef<Type> for member types;
882  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
883  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
884  /// info.
885  ///
886  /// An identified struct type is uniqued only by the first part (field 0)
887  /// of the key.
888  ///
889  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
890  /// 3) of the key. The identifier field (field 0) must be empty.
891  using KeyTy =
892  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
894 
895  /// For identified structs, return true if the given key contains the same
896  /// identifier.
897  ///
898  /// For literal structs, return true if the given key contains a matching list
899  /// of member types + offset info + decoration info.
900  bool operator==(const KeyTy &key) const {
901  if (isIdentified()) {
902  // Identified types are uniqued by their identifier.
903  return getIdentifier() == std::get<0>(key);
904  }
905 
906  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
907  getMemberDecorationsInfo());
908  }
909 
910  /// If the given key contains a non-empty identifier, this method constructs
911  /// an identified struct and leaves the rest of the struct type data to be set
912  /// through a later call to StructType::trySetBody(...).
913  ///
914  /// If, on the other hand, the key contains an empty identifier, a literal
915  /// struct is constructed using the other fields of the key.
917  const KeyTy &key) {
918  StringRef keyIdentifier = std::get<0>(key);
919 
920  if (!keyIdentifier.empty()) {
921  StringRef identifier = allocator.copyInto(keyIdentifier);
922 
923  // Identified StructType body/members will be set through trySetBody(...)
924  // later.
925  return new (allocator.allocate<StructTypeStorage>())
926  StructTypeStorage(identifier);
927  }
928 
929  ArrayRef<Type> keyTypes = std::get<1>(key);
930 
931  // Copy the member type and layout information into the bump pointer
932  const Type *typesList = nullptr;
933  if (!keyTypes.empty()) {
934  typesList = allocator.copyInto(keyTypes).data();
935  }
936 
937  const StructType::OffsetInfo *offsetInfoList = nullptr;
938  if (!std::get<2>(key).empty()) {
939  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
940  assert(keyOffsetInfo.size() == keyTypes.size() &&
941  "size of offset information must be same as the size of number of "
942  "elements");
943  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
944  }
945 
946  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
947  unsigned numMemberDecorations = 0;
948  if (!std::get<3>(key).empty()) {
949  auto keyMemberDecorations = std::get<3>(key);
950  numMemberDecorations = keyMemberDecorations.size();
951  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
952  }
953 
954  return new (allocator.allocate<StructTypeStorage>())
955  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
956  numMemberDecorations, memberDecorationList);
957  }
958 
960  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
961  }
962 
964  if (offsetInfo) {
965  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
966  }
967  return {};
968  }
969 
971  if (memberDecorationsInfo) {
972  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
973  numMemberDecorations);
974  }
975  return {};
976  }
977 
978  StringRef getIdentifier() const { return identifier; }
979 
980  bool isIdentified() const { return !identifier.empty(); }
981 
982  /// Sets the struct type content for identified structs. Calling this method
983  /// is only valid for identified structs.
984  ///
985  /// Fails under the following conditions:
986  /// - If called for a literal struct;
987  /// - If called for an identified struct whose body was set before (through a
988  /// call to this method) but with different contents from the passed
989  /// arguments.
991  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
992  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
993  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
994  if (!isIdentified())
995  return failure();
996 
997  if (memberTypesAndIsBodySet.getInt() &&
998  (getMemberTypes() != structMemberTypes ||
999  getOffsetInfo() != structOffsetInfo ||
1000  getMemberDecorationsInfo() != structMemberDecorationInfo))
1001  return failure();
1002 
1003  memberTypesAndIsBodySet.setInt(true);
1004  numMembers = structMemberTypes.size();
1005 
1006  // Copy the member type and layout information into the bump pointer.
1007  if (!structMemberTypes.empty())
1008  memberTypesAndIsBodySet.setPointer(
1009  allocator.copyInto(structMemberTypes).data());
1010 
1011  if (!structOffsetInfo.empty()) {
1012  assert(structOffsetInfo.size() == structMemberTypes.size() &&
1013  "size of offset information must be same as the size of number of "
1014  "elements");
1015  offsetInfo = allocator.copyInto(structOffsetInfo).data();
1016  }
1017 
1018  if (!structMemberDecorationInfo.empty()) {
1019  numMemberDecorations = structMemberDecorationInfo.size();
1020  memberDecorationsInfo =
1021  allocator.copyInto(structMemberDecorationInfo).data();
1022  }
1023 
1024  return success();
1025  }
1026 
1027  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1029  unsigned numMembers;
1032  StringRef identifier;
1033 };
1034 
1035 StructType
1038  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
1039  assert(!memberTypes.empty() && "Struct needs at least one member type");
1040  // Sort the decorations.
1042  memberDecorations.begin(), memberDecorations.end());
1043  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1044  return Base::get(memberTypes.vec().front().getContext(),
1045  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1046  sortedDecorations);
1047 }
1048 
1050  StringRef identifier) {
1051  assert(!identifier.empty() &&
1052  "StructType identifier must be non-empty string");
1053 
1054  return Base::get(context, identifier, ArrayRef<Type>(),
1057 }
1058 
1059 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1060  StructType newStructType = Base::get(
1061  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1063  // Set an empty body in case this is a identified struct.
1064  if (newStructType.isIdentified() &&
1065  failed(newStructType.trySetBody(
1068  return StructType();
1069 
1070  return newStructType;
1071 }
1072 
1073 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1074 
1075 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1076 
1077 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1078 
1079 Type StructType::getElementType(unsigned index) const {
1080  assert(getNumElements() > index && "member index out of range");
1081  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1082 }
1083 
1085  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1086  getNumElements());
1087 }
1088 
1089 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1090 
1091 uint64_t StructType::getMemberOffset(unsigned index) const {
1092  assert(getNumElements() > index && "member index out of range");
1093  return getImpl()->offsetInfo[index];
1094 }
1095 
1098  const {
1099  memberDecorations.clear();
1100  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1101  memberDecorations.append(implMemberDecorations.begin(),
1102  implMemberDecorations.end());
1103 }
1104 
1106  unsigned index,
1107  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1108  assert(getNumElements() > index && "member index out of range");
1109  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1110  decorationsInfo.clear();
1111  for (const auto &memberDecoration : memberDecorations) {
1112  if (memberDecoration.memberIndex == index) {
1113  decorationsInfo.push_back(memberDecoration);
1114  }
1115  if (memberDecoration.memberIndex > index) {
1116  // Early exit since the decorations are stored sorted.
1117  return;
1118  }
1119  }
1120 }
1121 
1124  ArrayRef<OffsetInfo> offsetInfo,
1125  ArrayRef<MemberDecorationInfo> memberDecorations) {
1126  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1127 }
1128 
1130  std::optional<StorageClass> storage) {
1131  for (Type elementType : getElementTypes())
1132  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1133 }
1134 
1137  std::optional<StorageClass> storage) {
1138  for (Type elementType : getElementTypes())
1139  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1140 }
1141 
1142 llvm::hash_code spirv::hash_value(
1143  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1144  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1145  memberDecorationInfo.decoration);
1146 }
1147 
1148 //===----------------------------------------------------------------------===//
1149 // MatrixType
1150 //===----------------------------------------------------------------------===//
1151 
1153  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1154  : columnType(columnType), columnCount(columnCount) {}
1155 
1156  using KeyTy = std::tuple<Type, uint32_t>;
1157 
1159  const KeyTy &key) {
1160 
1161  // Initialize the memory using placement new.
1162  return new (allocator.allocate<MatrixTypeStorage>())
1163  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1164  }
1165 
1166  bool operator==(const KeyTy &key) const {
1167  return key == KeyTy(columnType, columnCount);
1168  }
1169 
1171  const uint32_t columnCount;
1172 };
1173 
1174 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1175  return Base::get(columnType.getContext(), columnType, columnCount);
1176 }
1177 
1179  Type columnType, uint32_t columnCount) {
1180  return Base::getChecked(emitError, columnType.getContext(), columnType,
1181  columnCount);
1182 }
1183 
1185  Type columnType, uint32_t columnCount) {
1186  if (columnCount < 2 || columnCount > 4)
1187  return emitError() << "matrix can have 2, 3, or 4 columns only";
1188 
1189  if (!isValidColumnType(columnType))
1190  return emitError() << "matrix columns must be vectors of floats";
1191 
1192  /// The underlying vectors (columns) must be of size 2, 3, or 4
1193  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1194  if (columnShape.size() != 1)
1195  return emitError() << "matrix columns must be 1D vectors";
1196 
1197  if (columnShape[0] < 2 || columnShape[0] > 4)
1198  return emitError() << "matrix columns must be of size 2, 3, or 4";
1199 
1200  return success();
1201 }
1202 
1203 /// Returns true if the matrix elements are vectors of float elements
1205  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1206  if (llvm::isa<FloatType>(vectorType.getElementType()))
1207  return true;
1208  }
1209  return false;
1210 }
1211 
1212 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1213 
1215  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1216 }
1217 
1218 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1219 
1220 unsigned MatrixType::getNumRows() const {
1221  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1222 }
1223 
1224 unsigned MatrixType::getNumElements() const {
1225  return (getImpl()->columnCount) * getNumRows();
1226 }
1227 
1229  std::optional<StorageClass> storage) {
1230  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1231 }
1232 
1235  std::optional<StorageClass> storage) {
1236  {
1237  static const Capability caps[] = {Capability::Matrix};
1238  ArrayRef<Capability> ref(caps, std::size(caps));
1239  capabilities.push_back(ref);
1240  }
1241  // Add any capabilities associated with the underlying vectors (i.e., columns)
1242  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1243 }
1244 
1245 //===----------------------------------------------------------------------===//
1246 // SPIR-V Dialect
1247 //===----------------------------------------------------------------------===//
1248 
1249 void SPIRVDialect::registerTypes() {
1252  StructType>();
1253 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:371
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:377
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:343
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:359
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:365
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:347
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:353
unsigned getWidth()
Return the bitwidth of this float type.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isBF16() const
Definition: Types.cpp:48
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:82
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:70
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:75
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:163
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:183
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:144
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:149
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:120
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:108
static bool classof(Type type)
Definition: SPIRVTypes.cpp:94
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:246
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:240
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:242
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:228
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:252
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:260
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:248
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:169
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:424
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:426
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:438
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:434
Type getElementType() const
Definition: SPIRVTypes.cpp:420
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:430
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:445
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:440
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:309
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:328
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:319
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:313
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:298
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:311
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Definition: SPIRVTypes.cpp:315
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., the single scalar element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:491
Type getPointeeType() const
Definition: SPIRVTypes.cpp:485
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:487
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:502
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:556
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:550
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:780
static bool classof(Type type)
Definition: SPIRVTypes.cpp:726
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:760
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:741
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:820
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:834
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:812
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:828
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:807
static bool classof(Type type)
Definition: SPIRVTypes.cpp:572
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:590
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:582
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:621
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:709
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:32
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:34
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:39
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:207
std::tuple< Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:204
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:393
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:388
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:386
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
Definition: SPIRVTypes.cpp:274
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:276
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:282
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:470
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:462
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:464
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:521
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:527
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:796
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:798
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:858
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:963
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:869
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:900
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:990
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:916
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:959
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:862
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:970
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:893
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet