MLIR  16.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 #include <iterator>
21 
22 using namespace mlir;
23 using namespace mlir::spirv;
24 
25 //===----------------------------------------------------------------------===//
26 // ArrayType
27 //===----------------------------------------------------------------------===//
28 
30  using KeyTy = std::tuple<Type, unsigned, unsigned>;
31 
33  const KeyTy &key) {
34  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
35  }
36 
37  bool operator==(const KeyTy &key) const {
38  return key == KeyTy(elementType, elementCount, stride);
39  }
40 
41  ArrayTypeStorage(const KeyTy &key)
42  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
43  stride(std::get<2>(key)) {}
44 
46  unsigned elementCount;
47  unsigned stride;
48 };
49 
50 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
51  assert(elementCount && "ArrayType needs at least one element");
52  return Base::get(elementType.getContext(), elementType, elementCount,
53  /*stride=*/0);
54 }
55 
56 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
57  unsigned stride) {
58  assert(elementCount && "ArrayType needs at least one element");
59  return Base::get(elementType.getContext(), elementType, elementCount, stride);
60 }
61 
62 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
63 
64 Type ArrayType::getElementType() const { return getImpl()->elementType; }
65 
66 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
67 
69  Optional<StorageClass> storage) {
70  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
71 }
72 
75  Optional<StorageClass> storage) {
76  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
77 }
78 
80  auto elementType = getElementType().cast<SPIRVType>();
81  Optional<int64_t> size = elementType.getSizeInBytes();
82  if (!size)
83  return llvm::None;
84  return (*size + getArrayStride()) * getNumElements();
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // CompositeType
89 //===----------------------------------------------------------------------===//
90 
92  if (auto vectorType = type.dyn_cast<VectorType>())
93  return isValid(vectorType);
97 }
98 
99 bool CompositeType::isValid(VectorType type) {
100  switch (type.getNumElements()) {
101  case 2:
102  case 3:
103  case 4:
104  case 8:
105  case 16:
106  break;
107  default:
108  return false;
109  }
110  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
111 }
112 
113 Type CompositeType::getElementType(unsigned index) const {
114  return TypeSwitch<Type, Type>(*this)
116  RuntimeArrayType, VectorType>(
117  [](auto type) { return type.getElementType(); })
118  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
119  .Case<StructType>(
120  [index](StructType type) { return type.getElementType(index); })
121  .Default(
122  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
123 }
124 
126  if (auto arrayType = dyn_cast<ArrayType>())
127  return arrayType.getNumElements();
128  if (auto matrixType = dyn_cast<MatrixType>())
129  return matrixType.getNumColumns();
130  if (auto structType = dyn_cast<StructType>())
131  return structType.getNumElements();
132  if (auto vectorType = dyn_cast<VectorType>())
133  return vectorType.getNumElements();
134  if (isa<CooperativeMatrixNVType>()) {
135  llvm_unreachable(
136  "invalid to query number of elements of spirv::CooperativeMatrix type");
137  }
138  if (isa<JointMatrixINTELType>()) {
139  llvm_unreachable(
140  "invalid to query number of elements of spirv::JointMatrix type");
141  }
142  if (isa<RuntimeArrayType>()) {
143  llvm_unreachable(
144  "invalid to query number of elements of spirv::RuntimeArray type");
145  }
146  llvm_unreachable("invalid composite type");
147 }
148 
151  RuntimeArrayType>();
152 }
153 
156  Optional<StorageClass> storage) {
157  TypeSwitch<Type>(*this)
160  [&](auto type) { type.getExtensions(extensions, storage); })
161  .Case<VectorType>([&](VectorType type) {
162  return type.getElementType().cast<ScalarType>().getExtensions(
163  extensions, storage);
164  })
165  .Default([](Type) { llvm_unreachable("invalid composite type"); });
166 }
167 
170  Optional<StorageClass> storage) {
171  TypeSwitch<Type>(*this)
174  [&](auto type) { type.getCapabilities(capabilities, storage); })
175  .Case<VectorType>([&](VectorType type) {
176  auto vecSize = getNumElements();
177  if (vecSize == 8 || vecSize == 16) {
178  static const Capability caps[] = {Capability::Vector16};
179  ArrayRef<Capability> ref(caps, std::size(caps));
180  capabilities.push_back(ref);
181  }
182  return type.getElementType().cast<ScalarType>().getCapabilities(
183  capabilities, storage);
184  })
185  .Default([](Type) { llvm_unreachable("invalid composite type"); });
186 }
187 
189  if (auto arrayType = dyn_cast<ArrayType>())
190  return arrayType.getSizeInBytes();
191  if (auto structType = dyn_cast<StructType>())
192  return structType.getSizeInBytes();
193  if (auto vectorType = dyn_cast<VectorType>()) {
194  Optional<int64_t> elementSize =
195  vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
196  if (!elementSize)
197  return llvm::None;
198  return *elementSize * vectorType.getNumElements();
199  }
200  return llvm::None;
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // CooperativeMatrixType
205 //===----------------------------------------------------------------------===//
206 
208  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
209 
211  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
212  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
214  }
215 
216  bool operator==(const KeyTy &key) const {
217  return key == KeyTy(elementType, scope, rows, columns);
218  }
219 
221  : elementType(std::get<0>(key)), rows(std::get<2>(key)),
222  columns(std::get<3>(key)), scope(std::get<1>(key)) {}
223 
225  unsigned rows;
226  unsigned columns;
227  Scope scope;
228 };
229 
231  Scope scope, unsigned rows,
232  unsigned columns) {
233  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
234 }
235 
237  return getImpl()->elementType;
238 }
239 
240 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
241 
242 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
243 
245  return getImpl()->columns;
246 }
247 
250  Optional<StorageClass> storage) {
251  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
252  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
253  ArrayRef<Extension> ref(exts, std::size(exts));
254  extensions.push_back(ref);
255 }
256 
259  Optional<StorageClass> storage) {
260  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
261  static const Capability caps[] = {Capability::CooperativeMatrixNV};
262  ArrayRef<Capability> ref(caps, std::size(caps));
263  capabilities.push_back(ref);
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // JointMatrixType
268 //===----------------------------------------------------------------------===//
269 
271  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
272 
274  const KeyTy &key) {
275  return new (allocator.allocate<JointMatrixTypeStorage>())
277  }
278 
279  bool operator==(const KeyTy &key) const {
280  return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
281  }
282 
284  : elementType(std::get<0>(key)), rows(std::get<1>(key)),
285  columns(std::get<2>(key)), scope(std::get<4>(key)),
286  matrixLayout(std::get<3>(key)) {}
287 
289  unsigned rows;
290  unsigned columns;
291  Scope scope;
292  MatrixLayout matrixLayout;
293 };
294 
296  unsigned rows, unsigned columns,
297  MatrixLayout matrixLayout) {
298  return Base::get(elementType.getContext(), elementType, rows, columns,
299  matrixLayout, scope);
300 }
301 
303  return getImpl()->elementType;
304 }
305 
306 Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
307 
308 unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
309 
310 unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
311 
313  return getImpl()->matrixLayout;
314 }
315 
318  Optional<StorageClass> storage) {
319  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
320  static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
321  ArrayRef<Extension> ref(exts, std::size(exts));
322  extensions.push_back(ref);
323 }
324 
327  Optional<StorageClass> storage) {
328  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
329  static const Capability caps[] = {Capability::JointMatrixINTEL};
330  ArrayRef<Capability> ref(caps, std::size(caps));
331  capabilities.push_back(ref);
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // ImageType
336 //===----------------------------------------------------------------------===//
337 
338 template <typename T>
339 static constexpr unsigned getNumBits() {
340  return 0;
341 }
342 template <>
343 constexpr unsigned getNumBits<Dim>() {
344  static_assert((1 << 3) > getMaxEnumValForDim(),
345  "Not enough bits to encode Dim value");
346  return 3;
347 }
348 template <>
349 constexpr unsigned getNumBits<ImageDepthInfo>() {
350  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
351  "Not enough bits to encode ImageDepthInfo value");
352  return 2;
353 }
354 template <>
355 constexpr unsigned getNumBits<ImageArrayedInfo>() {
356  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
357  "Not enough bits to encode ImageArrayedInfo value");
358  return 1;
359 }
360 template <>
361 constexpr unsigned getNumBits<ImageSamplingInfo>() {
362  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
363  "Not enough bits to encode ImageSamplingInfo value");
364  return 1;
365 }
366 template <>
367 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
368  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
369  "Not enough bits to encode ImageSamplerUseInfo value");
370  return 2;
371 }
372 template <>
373 constexpr unsigned getNumBits<ImageFormat>() {
374  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
375  "Not enough bits to encode ImageFormat value");
376  return 6;
377 }
378 
380 public:
381  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
382  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
383 
385  const KeyTy &key) {
386  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
387  }
388 
389  bool operator==(const KeyTy &key) const {
390  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
391  samplerUseInfo, format);
392  }
393 
395  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
396  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
397  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
398  format(std::get<6>(key)) {}
399 
407 };
408 
409 ImageType
410 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
411  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
412  value) {
413  return Base::get(std::get<0>(value).getContext(), value);
414 }
415 
416 Type ImageType::getElementType() const { return getImpl()->elementType; }
417 
418 Dim ImageType::getDim() const { return getImpl()->dim; }
419 
420 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
421 
422 ImageArrayedInfo ImageType::getArrayedInfo() const {
423  return getImpl()->arrayedInfo;
424 }
425 
426 ImageSamplingInfo ImageType::getSamplingInfo() const {
427  return getImpl()->samplingInfo;
428 }
429 
430 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
431  return getImpl()->samplerUseInfo;
432 }
433 
434 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
435 
438  // Image types do not require extra extensions thus far.
439 }
440 
443  if (auto dimCaps = spirv::getCapabilities(getDim()))
444  capabilities.push_back(*dimCaps);
445 
446  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
447  capabilities.push_back(*fmtCaps);
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // PointerType
452 //===----------------------------------------------------------------------===//
453 
455  // (Type, StorageClass) as the key: Type stored in this struct, and
456  // StorageClass stored as TypeStorage's subclass data.
457  using KeyTy = std::pair<Type, StorageClass>;
458 
460  const KeyTy &key) {
461  return new (allocator.allocate<PointerTypeStorage>())
462  PointerTypeStorage(key);
463  }
464 
465  bool operator==(const KeyTy &key) const {
466  return key == KeyTy(pointeeType, storageClass);
467  }
468 
470  : pointeeType(key.first), storageClass(key.second) {}
471 
473  StorageClass storageClass;
474 };
475 
476 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
477  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
478 }
479 
480 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
481 
482 StorageClass PointerType::getStorageClass() const {
483  return getImpl()->storageClass;
484 }
485 
487  Optional<StorageClass> storage) {
488  // Use this pointer type's storage class because this pointer indicates we are
489  // using the pointee type in that specific storage class.
490  getPointeeType().cast<SPIRVType>().getExtensions(extensions,
491  getStorageClass());
492 
493  if (auto scExts = spirv::getExtensions(getStorageClass()))
494  extensions.push_back(*scExts);
495 }
496 
499  Optional<StorageClass> storage) {
500  // Use this pointer type's storage class because this pointer indicates we are
501  // using the pointee type in that specific storage class.
502  getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
503  getStorageClass());
504 
505  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
506  capabilities.push_back(*scCaps);
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // RuntimeArrayType
511 //===----------------------------------------------------------------------===//
512 
514  using KeyTy = std::pair<Type, unsigned>;
515 
517  const KeyTy &key) {
518  return new (allocator.allocate<RuntimeArrayTypeStorage>())
520  }
521 
522  bool operator==(const KeyTy &key) const {
523  return key == KeyTy(elementType, stride);
524  }
525 
527  : elementType(key.first), stride(key.second) {}
528 
530  unsigned stride;
531 };
532 
534  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
535 }
536 
537 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
538  return Base::get(elementType.getContext(), elementType, stride);
539 }
540 
541 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
542 
543 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
544 
547  Optional<StorageClass> storage) {
548  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
549 }
550 
553  Optional<StorageClass> storage) {
554  {
555  static const Capability caps[] = {Capability::Shader};
556  ArrayRef<Capability> ref(caps, std::size(caps));
557  capabilities.push_back(ref);
558  }
559  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // ScalarType
564 //===----------------------------------------------------------------------===//
565 
567  if (auto floatType = type.dyn_cast<FloatType>()) {
568  return isValid(floatType);
569  }
570  if (auto intType = type.dyn_cast<IntegerType>()) {
571  return isValid(intType);
572  }
573  return false;
574 }
575 
576 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
577 
578 bool ScalarType::isValid(IntegerType type) {
579  switch (type.getWidth()) {
580  case 1:
581  case 8:
582  case 16:
583  case 32:
584  case 64:
585  return true;
586  default:
587  return false;
588  }
589 }
590 
592  Optional<StorageClass> storage) {
593  // 8- or 16-bit integer/floating-point numbers will require extra extensions
594  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
595  // SPV_KHR_8bit_storage for more details.
596  if (!storage)
597  return;
598 
599  switch (*storage) {
600  case StorageClass::PushConstant:
601  case StorageClass::StorageBuffer:
602  case StorageClass::Uniform:
603  if (getIntOrFloatBitWidth() == 8) {
604  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
605  ArrayRef<Extension> ref(exts, std::size(exts));
606  extensions.push_back(ref);
607  }
608  [[fallthrough]];
609  case StorageClass::Input:
610  case StorageClass::Output:
611  if (getIntOrFloatBitWidth() == 16) {
612  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
613  ArrayRef<Extension> ref(exts, std::size(exts));
614  extensions.push_back(ref);
615  }
616  break;
617  default:
618  break;
619  }
620 }
621 
624  Optional<StorageClass> storage) {
625  unsigned bitwidth = getIntOrFloatBitWidth();
626 
627  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
628  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
629  // SPV_KHR_8bit_storage for more details.
630 
631 #define STORAGE_CASE(storage, cap8, cap16) \
632  case StorageClass::storage: { \
633  if (bitwidth == 8) { \
634  static const Capability caps[] = {Capability::cap8}; \
635  ArrayRef<Capability> ref(caps, std::size(caps)); \
636  capabilities.push_back(ref); \
637  return; \
638  } \
639  if (bitwidth == 16) { \
640  static const Capability caps[] = {Capability::cap16}; \
641  ArrayRef<Capability> ref(caps, std::size(caps)); \
642  capabilities.push_back(ref); \
643  return; \
644  } \
645  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
646  /* storage classes. Fall through to the next section. */ \
647  } break
648 
649  // This part only handles the cases where special bitwidths appearing in
650  // interface storage classes.
651  if (storage) {
652  switch (*storage) {
653  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
654  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
655  StorageBuffer16BitAccess);
656  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
657  StorageUniform16);
658  case StorageClass::Input:
659  case StorageClass::Output: {
660  if (bitwidth == 16) {
661  static const Capability caps[] = {Capability::StorageInputOutput16};
662  ArrayRef<Capability> ref(caps, std::size(caps));
663  capabilities.push_back(ref);
664  return;
665  }
666  break;
667  }
668  default:
669  break;
670  }
671  }
672 #undef STORAGE_CASE
673 
674  // For other non-interface storage classes, require a different set of
675  // capabilities for special bitwidths.
676 
677 #define WIDTH_CASE(type, width) \
678  case width: { \
679  static const Capability caps[] = {Capability::type##width}; \
680  ArrayRef<Capability> ref(caps, std::size(caps)); \
681  capabilities.push_back(ref); \
682  } break
683 
684  if (auto intType = dyn_cast<IntegerType>()) {
685  switch (bitwidth) {
686  WIDTH_CASE(Int, 8);
687  WIDTH_CASE(Int, 16);
688  WIDTH_CASE(Int, 64);
689  case 1:
690  case 32:
691  break;
692  default:
693  llvm_unreachable("invalid bitwidth to getCapabilities");
694  }
695  } else {
696  assert(isa<FloatType>());
697  switch (bitwidth) {
698  WIDTH_CASE(Float, 16);
699  WIDTH_CASE(Float, 64);
700  case 32:
701  break;
702  default:
703  llvm_unreachable("invalid bitwidth to getCapabilities");
704  }
705  }
706 
707 #undef WIDTH_CASE
708 }
709 
711  auto bitWidth = getIntOrFloatBitWidth();
712  // According to the SPIR-V spec:
713  // "There is no physical size or bit pattern defined for values with boolean
714  // type. If they are stored (in conjunction with OpVariable), they can only
715  // be used with logical addressing operations, not physical, and only with
716  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
717  // Private, Function, Input, and Output."
718  if (bitWidth == 1)
719  return llvm::None;
720  return bitWidth / 8;
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // SPIRVType
725 //===----------------------------------------------------------------------===//
726 
728  // Allow SPIR-V dialect types
729  if (llvm::isa<SPIRVDialect>(type.getDialect()))
730  return true;
731  if (type.isa<ScalarType>())
732  return true;
733  if (auto vectorType = type.dyn_cast<VectorType>())
734  return CompositeType::isValid(vectorType);
735  return false;
736 }
737 
739  return isIntOrFloat() || isa<VectorType>();
740 }
741 
743  Optional<StorageClass> storage) {
744  if (auto scalarType = dyn_cast<ScalarType>()) {
745  scalarType.getExtensions(extensions, storage);
746  } else if (auto compositeType = dyn_cast<CompositeType>()) {
747  compositeType.getExtensions(extensions, storage);
748  } else if (auto imageType = dyn_cast<ImageType>()) {
749  imageType.getExtensions(extensions, storage);
750  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
751  sampledImageType.getExtensions(extensions, storage);
752  } else if (auto matrixType = dyn_cast<MatrixType>()) {
753  matrixType.getExtensions(extensions, storage);
754  } else if (auto ptrType = dyn_cast<PointerType>()) {
755  ptrType.getExtensions(extensions, storage);
756  } else {
757  llvm_unreachable("invalid SPIR-V Type to getExtensions");
758  }
759 }
760 
763  Optional<StorageClass> storage) {
764  if (auto scalarType = dyn_cast<ScalarType>()) {
765  scalarType.getCapabilities(capabilities, storage);
766  } else if (auto compositeType = dyn_cast<CompositeType>()) {
767  compositeType.getCapabilities(capabilities, storage);
768  } else if (auto imageType = dyn_cast<ImageType>()) {
769  imageType.getCapabilities(capabilities, storage);
770  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
771  sampledImageType.getCapabilities(capabilities, storage);
772  } else if (auto matrixType = dyn_cast<MatrixType>()) {
773  matrixType.getCapabilities(capabilities, storage);
774  } else if (auto ptrType = dyn_cast<PointerType>()) {
775  ptrType.getCapabilities(capabilities, storage);
776  } else {
777  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
778  }
779 }
780 
782  if (auto scalarType = dyn_cast<ScalarType>())
783  return scalarType.getSizeInBytes();
784  if (auto compositeType = dyn_cast<CompositeType>())
785  return compositeType.getSizeInBytes();
786  return llvm::None;
787 }
788 
789 //===----------------------------------------------------------------------===//
790 // SampledImageType
791 //===----------------------------------------------------------------------===//
793  using KeyTy = Type;
794 
795  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
796 
797  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
798 
800  const KeyTy &key) {
801  return new (allocator.allocate<SampledImageTypeStorage>())
803  }
804 
806 };
807 
809  return Base::get(imageType.getContext(), imageType);
810 }
811 
814  Type imageType) {
815  return Base::getChecked(emitError, imageType.getContext(), imageType);
816 }
817 
818 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
819 
822  Type imageType) {
823  if (!imageType.isa<ImageType>())
824  return emitError() << "expected image type";
825 
826  return success();
827 }
828 
831  Optional<StorageClass> storage) {
832  getImageType().cast<ImageType>().getExtensions(extensions, storage);
833 }
834 
837  Optional<StorageClass> storage) {
838  getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // StructType
843 //===----------------------------------------------------------------------===//
844 
845 /// Type storage for SPIR-V structure types:
846 ///
847 /// Structures are uniqued using:
848 /// - for identified structs:
849 /// - a string identifier;
850 /// - for literal structs:
851 /// - a list of member types;
852 /// - a list of member offset info;
853 /// - a list of member decoration info.
854 ///
855 /// Identified structures only have a mutable component consisting of:
856 /// - a list of member types;
857 /// - a list of member offset info;
858 /// - a list of member decoration info.
860  /// Construct a storage object for an identified struct type. A struct type
861  /// associated with such storage must call StructType::trySetBody(...) later
862  /// in order to mutate the storage object providing the actual content.
863  StructTypeStorage(StringRef identifier)
864  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
865  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
866  identifier(identifier) {}
867 
868  /// Construct a storage object for a literal struct type. A struct type
869  /// associated with such storage is immutable.
871  unsigned numMembers, Type const *memberTypes,
872  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
873  StructType::MemberDecorationInfo const *memberDecorationsInfo)
874  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
875  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
876  memberDecorationsInfo(memberDecorationsInfo) {}
877 
878  /// A storage key is divided into 2 parts:
879  /// - for identified structs:
880  /// - a StringRef representing the struct identifier;
881  /// - for literal structs:
882  /// - an ArrayRef<Type> for member types;
883  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
884  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
885  /// info.
886  ///
887  /// An identified struct type is uniqued only by the first part (field 0)
888  /// of the key.
889  ///
890  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
891  /// 3) of the key. The identifier field (field 0) must be empty.
892  using KeyTy =
893  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
895 
896  /// For identified structs, return true if the given key contains the same
897  /// identifier.
898  ///
899  /// For literal structs, return true if the given key contains a matching list
900  /// of member types + offset info + decoration info.
901  bool operator==(const KeyTy &key) const {
902  if (isIdentified()) {
903  // Identified types are uniqued by their identifier.
904  return getIdentifier() == std::get<0>(key);
905  }
906 
907  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
908  getMemberDecorationsInfo());
909  }
910 
911  /// If the given key contains a non-empty identifier, this method constructs
912  /// an identified struct and leaves the rest of the struct type data to be set
913  /// through a later call to StructType::trySetBody(...).
914  ///
915  /// If, on the other hand, the key contains an empty identifier, a literal
916  /// struct is constructed using the other fields of the key.
918  const KeyTy &key) {
919  StringRef keyIdentifier = std::get<0>(key);
920 
921  if (!keyIdentifier.empty()) {
922  StringRef identifier = allocator.copyInto(keyIdentifier);
923 
924  // Identified StructType body/members will be set through trySetBody(...)
925  // later.
926  return new (allocator.allocate<StructTypeStorage>())
927  StructTypeStorage(identifier);
928  }
929 
930  ArrayRef<Type> keyTypes = std::get<1>(key);
931 
932  // Copy the member type and layout information into the bump pointer
933  const Type *typesList = nullptr;
934  if (!keyTypes.empty()) {
935  typesList = allocator.copyInto(keyTypes).data();
936  }
937 
938  const StructType::OffsetInfo *offsetInfoList = nullptr;
939  if (!std::get<2>(key).empty()) {
940  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
941  assert(keyOffsetInfo.size() == keyTypes.size() &&
942  "size of offset information must be same as the size of number of "
943  "elements");
944  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
945  }
946 
947  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
948  unsigned numMemberDecorations = 0;
949  if (!std::get<3>(key).empty()) {
950  auto keyMemberDecorations = std::get<3>(key);
951  numMemberDecorations = keyMemberDecorations.size();
952  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
953  }
954 
955  return new (allocator.allocate<StructTypeStorage>())
956  StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
957  numMemberDecorations, memberDecorationList);
958  }
959 
961  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
962  }
963 
965  if (offsetInfo) {
966  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
967  }
968  return {};
969  }
970 
972  if (memberDecorationsInfo) {
973  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
974  numMemberDecorations);
975  }
976  return {};
977  }
978 
979  StringRef getIdentifier() const { return identifier; }
980 
981  bool isIdentified() const { return !identifier.empty(); }
982 
983  /// Sets the struct type content for identified structs. Calling this method
984  /// is only valid for identified structs.
985  ///
986  /// Fails under the following conditions:
987  /// - If called for a literal struct;
988  /// - If called for an identified struct whose body was set before (through a
989  /// call to this method) but with different contents from the passed
990  /// arguments.
992  TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
993  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
994  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
995  if (!isIdentified())
996  return failure();
997 
998  if (memberTypesAndIsBodySet.getInt() &&
999  (getMemberTypes() != structMemberTypes ||
1000  getOffsetInfo() != structOffsetInfo ||
1001  getMemberDecorationsInfo() != structMemberDecorationInfo))
1002  return failure();
1003 
1004  memberTypesAndIsBodySet.setInt(true);
1005  numMembers = structMemberTypes.size();
1006 
1007  // Copy the member type and layout information into the bump pointer.
1008  if (!structMemberTypes.empty())
1009  memberTypesAndIsBodySet.setPointer(
1010  allocator.copyInto(structMemberTypes).data());
1011 
1012  if (!structOffsetInfo.empty()) {
1013  assert(structOffsetInfo.size() == structMemberTypes.size() &&
1014  "size of offset information must be same as the size of number of "
1015  "elements");
1016  offsetInfo = allocator.copyInto(structOffsetInfo).data();
1017  }
1018 
1019  if (!structMemberDecorationInfo.empty()) {
1020  numMemberDecorations = structMemberDecorationInfo.size();
1021  memberDecorationsInfo =
1022  allocator.copyInto(structMemberDecorationInfo).data();
1023  }
1024 
1025  return success();
1026  }
1027 
1028  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1030  unsigned numMembers;
1033  StringRef identifier;
1034 };
1035 
1036 StructType
1039  ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
1040  assert(!memberTypes.empty() && "Struct needs at least one member type");
1041  // Sort the decorations.
1043  memberDecorations.begin(), memberDecorations.end());
1044  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1045  return Base::get(memberTypes.vec().front().getContext(),
1046  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1047  sortedDecorations);
1048 }
1049 
1051  StringRef identifier) {
1052  assert(!identifier.empty() &&
1053  "StructType identifier must be non-empty string");
1054 
1055  return Base::get(context, identifier, ArrayRef<Type>(),
1058 }
1059 
1060 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1061  StructType newStructType = Base::get(
1062  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1064  // Set an empty body in case this is a identified struct.
1065  if (newStructType.isIdentified() &&
1066  failed(newStructType.trySetBody(
1069  return StructType();
1070 
1071  return newStructType;
1072 }
1073 
1074 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1075 
1076 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1077 
1078 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1079 
1080 Type StructType::getElementType(unsigned index) const {
1081  assert(getNumElements() > index && "member index out of range");
1082  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1083 }
1084 
1086  return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1087  getNumElements());
1088 }
1089 
1090 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1091 
1092 uint64_t StructType::getMemberOffset(unsigned index) const {
1093  assert(getNumElements() > index && "member index out of range");
1094  return getImpl()->offsetInfo[index];
1095 }
1096 
1099  const {
1100  memberDecorations.clear();
1101  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1102  memberDecorations.append(implMemberDecorations.begin(),
1103  implMemberDecorations.end());
1104 }
1105 
1107  unsigned index,
1108  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1109  assert(getNumElements() > index && "member index out of range");
1110  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1111  decorationsInfo.clear();
1112  for (const auto &memberDecoration : memberDecorations) {
1113  if (memberDecoration.memberIndex == index) {
1114  decorationsInfo.push_back(memberDecoration);
1115  }
1116  if (memberDecoration.memberIndex > index) {
1117  // Early exit since the decorations are stored sorted.
1118  return;
1119  }
1120  }
1121 }
1122 
1125  ArrayRef<OffsetInfo> offsetInfo,
1126  ArrayRef<MemberDecorationInfo> memberDecorations) {
1127  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1128 }
1129 
1131  Optional<StorageClass> storage) {
1132  for (Type elementType : getElementTypes())
1133  elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1134 }
1135 
1138  Optional<StorageClass> storage) {
1139  for (Type elementType : getElementTypes())
1140  elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1141 }
1142 
1143 llvm::hash_code spirv::hash_value(
1144  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1145  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1146  memberDecorationInfo.decoration);
1147 }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // MatrixType
1151 //===----------------------------------------------------------------------===//
1152 
1154  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1155  : columnType(columnType), columnCount(columnCount) {}
1156 
1157  using KeyTy = std::tuple<Type, uint32_t>;
1158 
1160  const KeyTy &key) {
1161 
1162  // Initialize the memory using placement new.
1163  return new (allocator.allocate<MatrixTypeStorage>())
1164  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1165  }
1166 
1167  bool operator==(const KeyTy &key) const {
1168  return key == KeyTy(columnType, columnCount);
1169  }
1170 
1172  const uint32_t columnCount;
1173 };
1174 
1175 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1176  return Base::get(columnType.getContext(), columnType, columnCount);
1177 }
1178 
1180  Type columnType, uint32_t columnCount) {
1181  return Base::getChecked(emitError, columnType.getContext(), columnType,
1182  columnCount);
1183 }
1184 
1186  Type columnType, uint32_t columnCount) {
1187  if (columnCount < 2 || columnCount > 4)
1188  return emitError() << "matrix can have 2, 3, or 4 columns only";
1189 
1190  if (!isValidColumnType(columnType))
1191  return emitError() << "matrix columns must be vectors of floats";
1192 
1193  /// The underlying vectors (columns) must be of size 2, 3, or 4
1194  ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1195  if (columnShape.size() != 1)
1196  return emitError() << "matrix columns must be 1D vectors";
1197 
1198  if (columnShape[0] < 2 || columnShape[0] > 4)
1199  return emitError() << "matrix columns must be of size 2, 3, or 4";
1200 
1201  return success();
1202 }
1203 
1204 /// Returns true if the matrix elements are vectors of float elements
1206  if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1207  if (vectorType.getElementType().isa<FloatType>())
1208  return true;
1209  }
1210  return false;
1211 }
1212 
1213 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1214 
1216  return getImpl()->columnType.cast<VectorType>().getElementType();
1217 }
1218 
1219 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1220 
1221 unsigned MatrixType::getNumRows() const {
1222  return getImpl()->columnType.cast<VectorType>().getShape()[0];
1223 }
1224 
1225 unsigned MatrixType::getNumElements() const {
1226  return (getImpl()->columnCount) * getNumRows();
1227 }
1228 
1230  Optional<StorageClass> storage) {
1231  getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1232 }
1233 
1236  Optional<StorageClass> storage) {
1237  {
1238  static const Capability caps[] = {Capability::Matrix};
1239  ArrayRef<Capability> ref(caps, std::size(caps));
1240  capabilities.push_back(ref);
1241  }
1242  // Add any capabilities associated with the underlying vectors (i.e., columns)
1243  getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // SPIR-V Dialect
1248 //===----------------------------------------------------------------------===//
1249 
1250 void SPIRVDialect::registerTypes() {
1253  StructType>();
1254 }
static constexpr const bool value
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:367
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:373
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:339
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:355
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:361
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:343
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:349
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
Base storage class appearing in a Type.
Definition: TypeSupport.h:122
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:121
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
U dyn_cast() const
Definition: Types.h:270
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:89
bool isa() const
Definition: Types.h:260
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
bool isBF16() const
Definition: Types.cpp:23
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:73
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:68
Optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:79
Optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:188
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:149
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:125
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:99
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:113
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:168
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:154
static bool classof(Type type)
Definition: SPIRVTypes.cpp:91
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:242
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:248
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:244
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:230
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:257
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:240
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:164
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:420
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:441
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:422
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:436
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:434
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:430
Type getElementType() const
Definition: SPIRVTypes.cpp:416
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:426
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:306
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:316
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:310
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:295
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:308
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:325
MatrixLayout getMatrixLayout() const
return the layout of the matrix
Definition: SPIRVTypes.cpp:312
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix columns are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., the type of a single element).
unsigned getNumRows() const
Returns the number of rows.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Type getPointeeType() const
Definition: SPIRVTypes.cpp:480
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:482
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:486
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:476
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:497
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:545
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:543
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:533
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:551
void getExtensions(ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:742
void getCapabilities(CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:761
static bool classof(Type type)
Definition: SPIRVTypes.cpp:727
Optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:781
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:821
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< spirv::StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:829
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:813
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:808
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< spirv::StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:835
static bool classof(Type type)
Definition: SPIRVTypes.cpp:566
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:622
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
Definition: SPIRVTypes.cpp:591
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:576
Optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:710
Range class for element types.
Definition: SPIRVTypes.h:350
SPIR-V struct type.
Definition: SPIRVTypes.h:281
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, Optional< StorageClass > storage=llvm::None)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional< StorageClass > storage=llvm::None)
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
ElementTypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:30
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:32
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:37
std::tuple< Type, Scope, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:208
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:211
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:389
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:384
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:382
std::tuple< Type, unsigned, unsigned, MatrixLayout, Scope > KeyTy
Definition: SPIRVTypes.cpp:271
static JointMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:273
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:279
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:465
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:457
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:459
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:516
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:522
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:797
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:799
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:859
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:964
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:870
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:901
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:991
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:917
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:960
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:863
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:971
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:894
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet