MLIR  22.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/ErrorHandling.h"
21 
22 #include <cstdint>
23 #include <optional>
24 
25 using namespace mlir;
26 using namespace mlir::spirv;
27 
28 namespace {
29 // Helper class to collect extensions implied by a type by visiting all its
30 // subtypes. Maintains a set of `seen` (type, storage) pairs to avoid recursion.
31 //
32 // Serves as the source-of-truth for type extension information. All extension
33 // logic should be added to this class, while the
34 // `SPIRVType::getExtensions` function should not handle extension-related logic
35 // directly and only invoke `TypeExtensionVisitor::add(SPIRVType)`.
36 class TypeExtensionVisitor {
37 public:
38  TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
39  std::optional<StorageClass> storage)
40  : extensions(extensions), storage(storage) {}
41 
42  // Main visitor entry point. Adds all extensions to the vector. Saves `type`
43  // as seen, then calls the matching `addConcrete` or recurses into subtypes.
44  void add(SPIRVType type) {
45  if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
46  return;
47 
50  [this](auto concreteType) { addConcrete(concreteType); })
51  .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
52  [this](auto concreteType) { add(concreteType.getElementType()); })
53  .Case<SampledImageType>([this](SampledImageType concreteType) {
54  add(concreteType.getImageType());
55  })
56  .Case<StructType>([this](StructType concreteType) {
57  for (Type elementType : concreteType.getElementTypes())
58  add(elementType);
59  })
60  .DefaultUnreachable("Unhandled type");
61  }
62 
63  void add(Type type) { add(cast<SPIRVType>(type)); }
64 
65 private:
66  // Types that add unique extensions.
67  void addConcrete(CooperativeMatrixType type);
68  void addConcrete(PointerType type);
69  void addConcrete(ScalarType type);
70  void addConcrete(TensorArmType type);
71 
73  std::optional<StorageClass> storage;
74  llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
75 };
76 
77 // Helper class to collect capabilities implied by a type by visiting all its
78 // subtypes. Maintains a set of `seen` (type, storage) pairs to avoid recursion.
79 //
80 // Serves as the source-of-truth for type capability information. All capability
81 // logic should be added to this class, while the
82 // `SPIRVType::getCapabilities` function should not handle capability-related
83 // logic directly and only invoke `TypeCapabilityVisitor::add(SPIRVType)`.
84 class TypeCapabilityVisitor {
85 public:
86  TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
87  std::optional<StorageClass> storage)
88  : capabilities(capabilities), storage(storage) {}
89 
90  // Main visitor entry point. Adds all capabilities to the vector. Saves `type`
91  // as seen, then calls the matching `addConcrete` or recurses into subtypes.
92  void add(SPIRVType type) {
93  if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
94  return;
95 
99  [this](auto concreteType) { addConcrete(concreteType); })
100  .Case<ArrayType>([this](ArrayType concreteType) {
101  add(concreteType.getElementType());
102  })
103  .Case<SampledImageType>([this](SampledImageType concreteType) {
104  add(concreteType.getImageType());
105  })
106  .Case<StructType>([this](StructType concreteType) {
107  for (Type elementType : concreteType.getElementTypes())
108  add(elementType);
109  })
110  .DefaultUnreachable("Unhandled type");
111  }
112 
113  void add(Type type) { add(cast<SPIRVType>(type)); }
114 
115 private:
116  // Types that add unique capabilities.
117  void addConcrete(CooperativeMatrixType type);
118  void addConcrete(ImageType type);
119  void addConcrete(MatrixType type);
120  void addConcrete(PointerType type);
121  void addConcrete(RuntimeArrayType type);
122  void addConcrete(ScalarType type);
123  void addConcrete(TensorArmType type);
124  void addConcrete(VectorType type);
125 
127  std::optional<StorageClass> storage;
128  llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
129 };
130 
131 } // namespace
132 
133 //===----------------------------------------------------------------------===//
134 // ArrayType
135 //===----------------------------------------------------------------------===//
136 
138  using KeyTy = std::tuple<Type, unsigned, unsigned>;
139 
141  const KeyTy &key) {
142  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
143  }
144 
145  bool operator==(const KeyTy &key) const {
146  return key == KeyTy(elementType, elementCount, stride);
147  }
148 
150  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
151  stride(std::get<2>(key)) {}
152 
154  unsigned elementCount;
155  unsigned stride;
156 };
157 
158 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
159  assert(elementCount && "ArrayType needs at least one element");
160  return Base::get(elementType.getContext(), elementType, elementCount,
161  /*stride=*/0);
162 }
163 
164 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
165  unsigned stride) {
166  assert(elementCount && "ArrayType needs at least one element");
167  return Base::get(elementType.getContext(), elementType, elementCount, stride);
168 }
169 
170 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
171 
172 Type ArrayType::getElementType() const { return getImpl()->elementType; }
173 
174 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
175 
176 //===----------------------------------------------------------------------===//
177 // CompositeType
178 //===----------------------------------------------------------------------===//
179 
181  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
182  return isValid(vectorType);
186 }
187 
188 bool CompositeType::isValid(VectorType type) {
189  return type.getRank() == 1 &&
190  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
191  llvm::isa<ScalarType>(type.getElementType());
192 }
193 
194 Type CompositeType::getElementType(unsigned index) const {
195  return TypeSwitch<Type, Type>(*this)
197  TensorArmType>([](auto type) { return type.getElementType(); })
198  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
199  .Case<StructType>(
200  [index](StructType type) { return type.getElementType(index); })
201  .DefaultUnreachable("Invalid composite type");
202 }
203 
205  return TypeSwitch<SPIRVType, unsigned>(*this)
206  .Case<ArrayType, StructType, TensorArmType, VectorType>(
207  [](auto type) { return type.getNumElements(); })
208  .Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); })
209  .DefaultUnreachable("Invalid type for number of elements query");
210 }
211 
213  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
214 }
215 
216 void TypeCapabilityVisitor::addConcrete(VectorType type) {
217  add(type.getElementType());
218 
219  int64_t vecSize = type.getNumElements();
220  if (vecSize == 8 || vecSize == 16) {
221  static constexpr auto cap = Capability::Vector16;
222  capabilities.push_back(cap);
223  }
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // CooperativeMatrixType
228 //===----------------------------------------------------------------------===//
229 
231  // In the specification dimensions of the Cooperative Matrix are 32-bit
232  // integers --- the initial implementation kept those values as such. However,
233  // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
234  // as 32-bits and expose it as int64_t through `getShape`, however, this
235  // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
236  // 32-bits integers would require an extra logic and storage. So, we diverge
237  // from the spec and internally represent the dimensions as 64-bit integers,
238  // so we can easily return an `ArrayRef` from `getShape` without any extra
239  // logic. Alternatively, we could store both rows and columns (both 32-bits)
240  // and shape (64-bits), assigning rows and columns to shape whenever
241  // `getShape` is called. This would be at the cost of extra logic and storage.
242  // Note: Because `ArrayRef` is returned we cannot construct an object in
243  // `getShape` on the fly.
244  using KeyTy =
245  std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
246 
248  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
249  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
251  }
252 
253  bool operator==(const KeyTy &key) const {
254  return key == KeyTy(elementType, shape[0], shape[1], scope, use);
255  }
256 
258  : elementType(std::get<0>(key)),
259  shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
260  use(std::get<4>(key)) {}
261 
263  // [#rows, #columns]
264  std::array<int64_t, 2> shape;
265  Scope scope;
266  CooperativeMatrixUseKHR use;
267 };
268 
270  uint32_t rows,
271  uint32_t columns, Scope scope,
272  CooperativeMatrixUseKHR use) {
273  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
274  use);
275 }
276 
278  return getImpl()->elementType;
279 }
280 
282  assert(getImpl()->shape[0] != ShapedType::kDynamic);
283  return static_cast<uint32_t>(getImpl()->shape[0]);
284 }
285 
287  assert(getImpl()->shape[1] != ShapedType::kDynamic);
288  return static_cast<uint32_t>(getImpl()->shape[1]);
289 }
290 
292  return getImpl()->shape;
293 }
294 
295 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
296 
297 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
298  return getImpl()->use;
299 }
300 
301 void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
302  add(type.getElementType());
303  static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
304  extensions.push_back(ext);
305 }
306 
307 void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
308  add(type.getElementType());
309  static constexpr auto caps = Capability::CooperativeMatrixKHR;
310  capabilities.push_back(caps);
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // ImageType
315 //===----------------------------------------------------------------------===//
316 
317 template <typename T>
318 static constexpr unsigned getNumBits() {
319  return 0;
320 }
321 template <>
322 constexpr unsigned getNumBits<Dim>() {
323  static_assert((1 << 3) > getMaxEnumValForDim(),
324  "Not enough bits to encode Dim value");
325  return 3;
326 }
327 template <>
328 constexpr unsigned getNumBits<ImageDepthInfo>() {
329  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
330  "Not enough bits to encode ImageDepthInfo value");
331  return 2;
332 }
333 template <>
334 constexpr unsigned getNumBits<ImageArrayedInfo>() {
335  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
336  "Not enough bits to encode ImageArrayedInfo value");
337  return 1;
338 }
339 template <>
340 constexpr unsigned getNumBits<ImageSamplingInfo>() {
341  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
342  "Not enough bits to encode ImageSamplingInfo value");
343  return 1;
344 }
345 template <>
346 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
347  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
348  "Not enough bits to encode ImageSamplerUseInfo value");
349  return 2;
350 }
351 template <>
352 constexpr unsigned getNumBits<ImageFormat>() {
353  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
354  "Not enough bits to encode ImageFormat value");
355  return 6;
356 }
357 
359 public:
360  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
361  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
362 
364  const KeyTy &key) {
365  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
366  }
367 
368  bool operator==(const KeyTy &key) const {
369  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
370  samplerUseInfo, format);
371  }
372 
374  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
375  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
376  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
377  format(std::get<6>(key)) {}
378 
386 };
387 
388 ImageType
389 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
390  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
391  value) {
392  return Base::get(std::get<0>(value).getContext(), value);
393 }
394 
395 Type ImageType::getElementType() const { return getImpl()->elementType; }
396 
397 Dim ImageType::getDim() const { return getImpl()->dim; }
398 
399 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
400 
401 ImageArrayedInfo ImageType::getArrayedInfo() const {
402  return getImpl()->arrayedInfo;
403 }
404 
405 ImageSamplingInfo ImageType::getSamplingInfo() const {
406  return getImpl()->samplingInfo;
407 }
408 
409 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
410  return getImpl()->samplerUseInfo;
411 }
412 
413 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
414 
415 void TypeCapabilityVisitor::addConcrete(ImageType type) {
416  if (auto dimCaps = spirv::getCapabilities(type.getDim()))
417  capabilities.push_back(*dimCaps);
418 
419  if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
420  capabilities.push_back(*fmtCaps);
421 
422  add(type.getElementType());
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // PointerType
427 //===----------------------------------------------------------------------===//
428 
430  // (Type, StorageClass) as the key: Type stored in this struct, and
431  // StorageClass stored as TypeStorage's subclass data.
432  using KeyTy = std::pair<Type, StorageClass>;
433 
435  const KeyTy &key) {
436  return new (allocator.allocate<PointerTypeStorage>())
437  PointerTypeStorage(key);
438  }
439 
440  bool operator==(const KeyTy &key) const {
441  return key == KeyTy(pointeeType, storageClass);
442  }
443 
445  : pointeeType(key.first), storageClass(key.second) {}
446 
448  StorageClass storageClass;
449 };
450 
451 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
452  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
453 }
454 
455 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
456 
457 StorageClass PointerType::getStorageClass() const {
458  return getImpl()->storageClass;
459 }
460 
461 void TypeExtensionVisitor::addConcrete(PointerType type) {
462  // Use this pointer type's storage class because this pointer indicates we are
463  // using the pointee type in that specific storage class.
464  std::optional<StorageClass> oldStorageClass = storage;
465  storage = type.getStorageClass();
466  add(type.getPointeeType());
467  storage = oldStorageClass;
468 
469  if (auto scExts = spirv::getExtensions(type.getStorageClass()))
470  extensions.push_back(*scExts);
471 }
472 
473 void TypeCapabilityVisitor::addConcrete(PointerType type) {
474  // Use this pointer type's storage class because this pointer indicates we are
475  // using the pointee type in that specific storage class.
476  std::optional<StorageClass> oldStorageClass = storage;
477  storage = type.getStorageClass();
478  add(type.getPointeeType());
479  storage = oldStorageClass;
480 
481  if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
482  capabilities.push_back(*scCaps);
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // RuntimeArrayType
487 //===----------------------------------------------------------------------===//
488 
490  using KeyTy = std::pair<Type, unsigned>;
491 
493  const KeyTy &key) {
494  return new (allocator.allocate<RuntimeArrayTypeStorage>())
496  }
497 
498  bool operator==(const KeyTy &key) const {
499  return key == KeyTy(elementType, stride);
500  }
501 
503  : elementType(key.first), stride(key.second) {}
504 
506  unsigned stride;
507 };
508 
510  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
511 }
512 
513 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
514  return Base::get(elementType.getContext(), elementType, stride);
515 }
516 
517 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
518 
519 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
520 
521 void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
522  add(type.getElementType());
523  static constexpr auto cap = Capability::Shader;
524  capabilities.push_back(cap);
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // ScalarType
529 //===----------------------------------------------------------------------===//
530 
532  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
533  return isValid(floatType);
534  }
535  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
536  return isValid(intType);
537  }
538  return false;
539 }
540 
541 bool ScalarType::isValid(FloatType type) {
542  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
543 }
544 
545 bool ScalarType::isValid(IntegerType type) {
546  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
547 }
548 
549 void TypeExtensionVisitor::addConcrete(ScalarType type) {
550  if (isa<BFloat16Type>(type)) {
551  static constexpr auto ext = Extension::SPV_KHR_bfloat16;
552  extensions.push_back(ext);
553  }
554 
555  // 8- or 16-bit integer/floating-point numbers will require extra extensions
556  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
557  // SPV_KHR_8bit_storage for more details.
558  if (!storage)
559  return;
560 
561  switch (*storage) {
562  case StorageClass::PushConstant:
563  case StorageClass::StorageBuffer:
564  case StorageClass::Uniform:
565  if (type.getIntOrFloatBitWidth() == 8) {
566  static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
567  extensions.push_back(ext);
568  }
569  [[fallthrough]];
570  case StorageClass::Input:
571  case StorageClass::Output:
572  if (type.getIntOrFloatBitWidth() == 16) {
573  static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
574  extensions.push_back(ext);
575  }
576  break;
577  default:
578  break;
579  }
580 }
581 
582 void TypeCapabilityVisitor::addConcrete(ScalarType type) {
583  unsigned bitwidth = type.getIntOrFloatBitWidth();
584 
585  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
586  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
587  // SPV_KHR_8bit_storage for more details.
588 
589 #define STORAGE_CASE(storage, cap8, cap16) \
590  case StorageClass::storage: { \
591  if (bitwidth == 8) { \
592  static constexpr auto cap = Capability::cap8; \
593  capabilities.push_back(cap); \
594  return; \
595  } \
596  if (bitwidth == 16) { \
597  static constexpr auto cap = Capability::cap16; \
598  capabilities.push_back(cap); \
599  return; \
600  } \
601  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
602  /* storage classes. Fall through to the next section. */ \
603  } break
604 
605  // This part only handles the cases where special bitwidths appearing in
606  // interface storage classes.
607  if (storage) {
608  switch (*storage) {
609  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
610  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
611  StorageBuffer16BitAccess);
612  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
613  StorageUniform16);
614  case StorageClass::Input:
615  case StorageClass::Output: {
616  if (bitwidth == 16) {
617  static constexpr auto cap = Capability::StorageInputOutput16;
618  capabilities.push_back(cap);
619  return;
620  }
621  break;
622  }
623  default:
624  break;
625  }
626  }
627 #undef STORAGE_CASE
628 
629  // For other non-interface storage classes, require a different set of
630  // capabilities for special bitwidths.
631 
632 #define WIDTH_CASE(type, width) \
633  case width: { \
634  static constexpr auto cap = Capability::type##width; \
635  capabilities.push_back(cap); \
636  } break
637 
638  if (auto intType = dyn_cast<IntegerType>(type)) {
639  switch (bitwidth) {
640  WIDTH_CASE(Int, 8);
641  WIDTH_CASE(Int, 16);
642  WIDTH_CASE(Int, 64);
643  case 1:
644  case 32:
645  break;
646  default:
647  llvm_unreachable("invalid bitwidth to getCapabilities");
648  }
649  } else {
650  assert(isa<FloatType>(type));
651  switch (bitwidth) {
652  case 16: {
653  if (isa<BFloat16Type>(type)) {
654  static constexpr auto cap = Capability::BFloat16TypeKHR;
655  capabilities.push_back(cap);
656  } else {
657  static constexpr auto cap = Capability::Float16;
658  capabilities.push_back(cap);
659  }
660  break;
661  }
662  WIDTH_CASE(Float, 64);
663  case 32:
664  break;
665  default:
666  llvm_unreachable("invalid bitwidth to getCapabilities");
667  }
668  }
669 
670 #undef WIDTH_CASE
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // SPIRVType
675 //===----------------------------------------------------------------------===//
676 
678  // Allow SPIR-V dialect types
679  if (llvm::isa<SPIRVDialect>(type.getDialect()))
680  return true;
681  if (llvm::isa<ScalarType>(type))
682  return true;
683  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
684  return CompositeType::isValid(vectorType);
685  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
686  return llvm::isa<ScalarType>(tensorArmType.getElementType());
687  return false;
688 }
689 
691  return isIntOrFloat() || llvm::isa<VectorType>(*this);
692 }
693 
695  std::optional<StorageClass> storage) {
696  TypeExtensionVisitor{extensions, storage}.add(*this);
697 }
698 
701  std::optional<StorageClass> storage) {
702  TypeCapabilityVisitor{capabilities, storage}.add(*this);
703 }
704 
705 std::optional<int64_t> SPIRVType::getSizeInBytes() {
707  .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
708  // According to the SPIR-V spec:
709  // "There is no physical size or bit pattern defined for values with
710  // boolean type. If they are stored (in conjunction with OpVariable),
711  // they can only be used with logical addressing operations, not
712  // physical, and only with non-externally visible shader Storage
713  // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
714  // Output."
715  int64_t bitWidth = type.getIntOrFloatBitWidth();
716  if (bitWidth == 1)
717  return std::nullopt;
718  return bitWidth / 8;
719  })
720  .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
721  // Since array type may have an explicit stride declaration (in bytes),
722  // we also include it in the calculation.
723  auto elementType = cast<SPIRVType>(type.getElementType());
724  if (std::optional<int64_t> size = elementType.getSizeInBytes())
725  return (*size + type.getArrayStride()) * type.getNumElements();
726  return std::nullopt;
727  })
728  .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
729  if (std::optional<int64_t> elementSize =
730  cast<ScalarType>(type.getElementType()).getSizeInBytes())
731  return *elementSize * type.getNumElements();
732  return std::nullopt;
733  })
734  .Default(std::optional<int64_t>());
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // SampledImageType
739 //===----------------------------------------------------------------------===//
741  using KeyTy = Type;
742 
743  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
744 
745  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
746 
748  const KeyTy &key) {
749  return new (allocator.allocate<SampledImageTypeStorage>())
751  }
752 
754 };
755 
757  return Base::get(imageType.getContext(), imageType);
758 }
759 
762  Type imageType) {
763  return Base::getChecked(emitError, imageType.getContext(), imageType);
764 }
765 
766 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
767 
768 LogicalResult
770  Type imageType) {
771  auto image = dyn_cast<ImageType>(imageType);
772  if (!image)
773  return emitError() << "expected image type";
774 
775  // As per SPIR-V spec: "It [ImageType] must not have a Dim of SubpassData.
776  // Additionally, starting with version 1.6, it must not have a Dim of Buffer.
777  // ("3.3.6. Type-Declaration Instructions")
778  if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
779  return emitError() << "Dim must not be SubpassData or Buffer";
780 
781  return success();
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // StructType
786 //===----------------------------------------------------------------------===//
787 
788 /// Type storage for SPIR-V structure types:
789 ///
790 /// Structures are uniqued using:
791 /// - for identified structs:
792 /// - a string identifier;
793 /// - for literal structs:
794 /// - a list of member types;
795 /// - a list of member offset info;
796 /// - a list of member decoration info;
797 /// - a list of struct decoration info.
798 ///
799 /// Identified structures only have a mutable component consisting of:
800 /// - a list of member types;
801 /// - a list of member offset info;
802 /// - a list of member decoration info;
803 /// - a list of struct decoration info.
805  /// Construct a storage object for an identified struct type. A struct type
806  /// associated with such storage must call StructType::trySetBody(...) later
807  /// in order to mutate the storage object providing the actual content.
808  StructTypeStorage(StringRef identifier)
809  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
810  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
811  numStructDecorations(0), structDecorationsInfo(nullptr),
812  identifier(identifier) {}
813 
814  /// Construct a storage object for a literal struct type. A struct type
815  /// associated with such storage is immutable.
817  unsigned numMembers, Type const *memberTypes,
818  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
819  StructType::MemberDecorationInfo const *memberDecorationsInfo,
820  unsigned numStructDecorations,
821  StructType::StructDecorationInfo const *structDecorationsInfo)
822  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
823  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
824  memberDecorationsInfo(memberDecorationsInfo),
825  numStructDecorations(numStructDecorations),
826  structDecorationsInfo(structDecorationsInfo) {}
827 
828  /// A storage key is divided into 2 parts:
829  /// - for identified structs:
830  /// - a StringRef representing the struct identifier;
831  /// - for literal structs:
832  /// - an ArrayRef<Type> for member types;
833  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
834  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
835  /// info;
836  /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
837  /// info.
838  ///
839  /// An identified struct type is uniqued only by the first part (field 0)
840  /// of the key.
841  ///
842  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
843  /// and 4) of the key. The identifier field (field 0) must be empty.
844  using KeyTy =
845  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
848 
849  /// For identified structs, return true if the given key contains the same
850  /// identifier.
851  ///
852  /// For literal structs, return true if the given key contains a matching list
853  /// of member types + offset info + decoration info.
854  bool operator==(const KeyTy &key) const {
855  if (isIdentified()) {
856  // Identified types are uniqued by their identifier.
857  return getIdentifier() == std::get<0>(key);
858  }
859 
860  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
861  getMemberDecorationsInfo(), getStructDecorationsInfo());
862  }
863 
864  /// If the given key contains a non-empty identifier, this method constructs
865  /// an identified struct and leaves the rest of the struct type data to be set
866  /// through a later call to StructType::trySetBody(...).
867  ///
868  /// If, on the other hand, the key contains an empty identifier, a literal
869  /// struct is constructed using the other fields of the key.
871  const KeyTy &key) {
872  StringRef keyIdentifier = std::get<0>(key);
873 
874  if (!keyIdentifier.empty()) {
875  StringRef identifier = allocator.copyInto(keyIdentifier);
876 
877  // Identified StructType body/members will be set through trySetBody(...)
878  // later.
879  return new (allocator.allocate<StructTypeStorage>())
880  StructTypeStorage(identifier);
881  }
882 
883  ArrayRef<Type> keyTypes = std::get<1>(key);
884 
885  // Copy the member type and layout information into the bump pointer
886  const Type *typesList = nullptr;
887  if (!keyTypes.empty()) {
888  typesList = allocator.copyInto(keyTypes).data();
889  }
890 
891  const StructType::OffsetInfo *offsetInfoList = nullptr;
892  if (!std::get<2>(key).empty()) {
893  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
894  assert(keyOffsetInfo.size() == keyTypes.size() &&
895  "size of offset information must be same as the size of number of "
896  "elements");
897  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
898  }
899 
900  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
901  unsigned numMemberDecorations = 0;
902  if (!std::get<3>(key).empty()) {
903  auto keyMemberDecorations = std::get<3>(key);
904  numMemberDecorations = keyMemberDecorations.size();
905  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
906  }
907 
908  const StructType::StructDecorationInfo *structDecorationList = nullptr;
909  unsigned numStructDecorations = 0;
910  if (!std::get<4>(key).empty()) {
911  auto keyStructDecorations = std::get<4>(key);
912  numStructDecorations = keyStructDecorations.size();
913  structDecorationList = allocator.copyInto(keyStructDecorations).data();
914  }
915 
916  return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
917  keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
918  memberDecorationList, numStructDecorations, structDecorationList);
919  }
920 
922  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
923  }
924 
926  if (offsetInfo) {
927  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
928  }
929  return {};
930  }
931 
933  if (memberDecorationsInfo) {
934  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
935  numMemberDecorations);
936  }
937  return {};
938  }
939 
941  if (structDecorationsInfo)
942  return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
943  numStructDecorations);
944  return {};
945  }
946 
947  StringRef getIdentifier() const { return identifier; }
948 
949  bool isIdentified() const { return !identifier.empty(); }
950 
951  /// Sets the struct type content for identified structs. Calling this method
952  /// is only valid for identified structs.
953  ///
954  /// Fails under the following conditions:
955  /// - If called for a literal struct;
956  /// - If called for an identified struct whose body was set before (through a
957  /// call to this method) but with different contents from the passed
958  /// arguments.
959  LogicalResult
960  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
961  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
962  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
963  ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
964  if (!isIdentified())
965  return failure();
966 
967  if (memberTypesAndIsBodySet.getInt() &&
968  (getMemberTypes() != structMemberTypes ||
969  getOffsetInfo() != structOffsetInfo ||
970  getMemberDecorationsInfo() != structMemberDecorationInfo ||
971  getStructDecorationsInfo() != structDecorationInfo))
972  return failure();
973 
974  memberTypesAndIsBodySet.setInt(true);
975  numMembers = structMemberTypes.size();
976 
977  // Copy the member type and layout information into the bump pointer.
978  if (!structMemberTypes.empty())
979  memberTypesAndIsBodySet.setPointer(
980  allocator.copyInto(structMemberTypes).data());
981 
982  if (!structOffsetInfo.empty()) {
983  assert(structOffsetInfo.size() == structMemberTypes.size() &&
984  "size of offset information must be same as the size of number of "
985  "elements");
986  offsetInfo = allocator.copyInto(structOffsetInfo).data();
987  }
988 
989  if (!structMemberDecorationInfo.empty()) {
990  numMemberDecorations = structMemberDecorationInfo.size();
991  memberDecorationsInfo =
992  allocator.copyInto(structMemberDecorationInfo).data();
993  }
994 
995  if (!structDecorationInfo.empty()) {
996  numStructDecorations = structDecorationInfo.size();
997  structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
998  }
999 
1000  return success();
1001  }
1002 
1003  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1005  unsigned numMembers;
1010  StringRef identifier;
1011 };
1012 
1013 StructType
1017  ArrayRef<StructType::StructDecorationInfo> structDecorations) {
1018  assert(!memberTypes.empty() && "Struct needs at least one member type");
1019  // Sort the decorations.
1020  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
1021  memberDecorations);
1022  llvm::array_pod_sort(sortedMemberDecorations.begin(),
1023  sortedMemberDecorations.end());
1024  SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
1025  structDecorations);
1026  llvm::array_pod_sort(sortedStructDecorations.begin(),
1027  sortedStructDecorations.end());
1028 
1029  return Base::get(memberTypes.vec().front().getContext(),
1030  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1031  sortedMemberDecorations, sortedStructDecorations);
1032 }
1033 
1035  StringRef identifier) {
1036  assert(!identifier.empty() &&
1037  "StructType identifier must be non-empty string");
1038 
1039  return Base::get(context, identifier, ArrayRef<Type>(),
1043 }
1044 
1045 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1046  StructType newStructType = Base::get(
1047  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1050  // Set an empty body in case this is a identified struct.
1051  if (newStructType.isIdentified() &&
1052  failed(newStructType.trySetBody(
1056  return StructType();
1057 
1058  return newStructType;
1059 }
1060 
1061 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1062 
1063 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1064 
1065 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1066 
1067 Type StructType::getElementType(unsigned index) const {
1068  assert(getNumElements() > index && "member index out of range");
1069  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1070 }
1071 
1073  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1074  getNumElements());
1075 }
1076 
1077 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1078 
1079 bool StructType::hasDecoration(spirv::Decoration decoration) const {
1081  getImpl()->getStructDecorationsInfo())
1082  if (info.decoration == decoration)
1083  return true;
1084 
1085  return false;
1086 }
1087 
1088 uint64_t StructType::getMemberOffset(unsigned index) const {
1089  assert(getNumElements() > index && "member index out of range");
1090  return getImpl()->offsetInfo[index];
1091 }
1092 
1095  const {
1096  memberDecorations.clear();
1097  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1098  memberDecorations.append(implMemberDecorations.begin(),
1099  implMemberDecorations.end());
1100 }
1101 
1103  unsigned index,
1104  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1105  assert(getNumElements() > index && "member index out of range");
1106  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1107  decorationsInfo.clear();
1108  for (const auto &memberDecoration : memberDecorations) {
1109  if (memberDecoration.memberIndex == index) {
1110  decorationsInfo.push_back(memberDecoration);
1111  }
1112  if (memberDecoration.memberIndex > index) {
1113  // Early exit since the decorations are stored sorted.
1114  return;
1115  }
1116  }
1117 }
1118 
1121  const {
1122  structDecorations.clear();
1123  auto implDecorations = getImpl()->getStructDecorationsInfo();
1124  structDecorations.append(implDecorations.begin(), implDecorations.end());
1125 }
1126 
1127 LogicalResult
1129  ArrayRef<OffsetInfo> offsetInfo,
1130  ArrayRef<MemberDecorationInfo> memberDecorations,
1131  ArrayRef<StructDecorationInfo> structDecorations) {
1132  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1133  structDecorations);
1134 }
1135 
1136 llvm::hash_code spirv::hash_value(
1137  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1138  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1139  memberDecorationInfo.decoration);
1140 }
1141 
1142 llvm::hash_code spirv::hash_value(
1143  const StructType::StructDecorationInfo &structDecorationInfo) {
1144  return llvm::hash_value(structDecorationInfo.decoration);
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // MatrixType
1149 //===----------------------------------------------------------------------===//
1150 
1152  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1153  : columnType(columnType), columnCount(columnCount) {}
1154 
1155  using KeyTy = std::tuple<Type, uint32_t>;
1156 
1158  const KeyTy &key) {
1159 
1160  // Initialize the memory using placement new.
1161  return new (allocator.allocate<MatrixTypeStorage>())
1162  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1163  }
1164 
1165  bool operator==(const KeyTy &key) const {
1166  return key == KeyTy(columnType, columnCount);
1167  }
1168 
1170  const uint32_t columnCount;
1171 };
1172 
1173 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1174  return Base::get(columnType.getContext(), columnType, columnCount);
1175 }
1176 
1178  Type columnType, uint32_t columnCount) {
1179  return Base::getChecked(emitError, columnType.getContext(), columnType,
1180  columnCount);
1181 }
1182 
1183 LogicalResult
1185  Type columnType, uint32_t columnCount) {
1186  if (columnCount < 2 || columnCount > 4)
1187  return emitError() << "matrix can have 2, 3, or 4 columns only";
1188 
1189  if (!isValidColumnType(columnType))
1190  return emitError() << "matrix columns must be vectors of floats";
1191 
1192  /// The underlying vectors (columns) must be of size 2, 3, or 4
1193  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1194  if (columnShape.size() != 1)
1195  return emitError() << "matrix columns must be 1D vectors";
1196 
1197  if (columnShape[0] < 2 || columnShape[0] > 4)
1198  return emitError() << "matrix columns must be of size 2, 3, or 4";
1199 
1200  return success();
1201 }
1202 
1203 /// Returns true if the matrix elements are vectors of float elements
1205  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1206  if (llvm::isa<FloatType>(vectorType.getElementType()))
1207  return true;
1208  }
1209  return false;
1210 }
1211 
1212 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1213 
1215  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1216 }
1217 
1218 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1219 
1220 unsigned MatrixType::getNumRows() const {
1221  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1222 }
1223 
1224 unsigned MatrixType::getNumElements() const {
1225  return (getImpl()->columnCount) * getNumRows();
1226 }
1227 
1228 void TypeCapabilityVisitor::addConcrete(MatrixType type) {
1229  add(type.getColumnType());
1230  static constexpr auto cap = Capability::Matrix;
1231  capabilities.push_back(cap);
1232 }
1233 
1234 //===----------------------------------------------------------------------===//
1235 // TensorArmType
1236 //===----------------------------------------------------------------------===//
1237 
1239  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1240 
1242  const KeyTy &key) {
1243  auto [shape, elementType] = key;
1244  shape = allocator.copyInto(shape);
1245  return new (allocator.allocate<TensorArmTypeStorage>())
1246  TensorArmTypeStorage(shape, elementType);
1247  }
1248 
1249  static llvm::hash_code hashKey(const KeyTy &key) {
1250  auto [shape, elementType] = key;
1251  return llvm::hash_combine(shape, elementType);
1252  }
1253 
1254  bool operator==(const KeyTy &key) const {
1255  return key == KeyTy(shape, elementType);
1256  }
1257 
1259  : shape(shape), elementType(elementType) {}
1260 
1263 };
1264 
1266  return Base::get(elementType.getContext(), shape, elementType);
1267 }
1268 
1270  Type elementType) const {
1271  return TensorArmType::get(shape.value_or(getShape()), elementType);
1272 }
1273 
1274 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1276 
1277 void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1278  add(type.getElementType());
1279  static constexpr auto ext = Extension::SPV_ARM_tensors;
1280  extensions.push_back(ext);
1281 }
1282 
1283 void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
1284  add(type.getElementType());
1285  static constexpr auto cap = Capability::TensorsARM;
1286  capabilities.push_back(cap);
1287 }
1288 
1289 LogicalResult
1291  ArrayRef<int64_t> shape, Type elementType) {
1292  if (llvm::is_contained(shape, 0))
1293  return emitError() << "arm.tensor do not support dimensions = 0";
1294  if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1295  llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1296  return emitError()
1297  << "arm.tensor shape dimensions must be either fully dynamic or "
1298  "completed shaped";
1299  return success();
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // SPIR-V Dialect
1304 //===----------------------------------------------------------------------===//
1305 
1306 void SPIRVDialect::registerTypes() {
1309 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:346
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:352
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:318
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:334
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:340
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:322
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:328
int64_t rows
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:172
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:174
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:170
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:158
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:212
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:204
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:188
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:194
static bool classof(Type type)
Definition: SPIRVTypes.cpp:180
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:295
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:281
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:286
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:269
ArrayRef< int64_t > getShape() const
Definition: SPIRVTypes.cpp:291
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:297
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:399
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:401
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:413
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:409
Type getElementType() const
Definition: SPIRVTypes.cpp:395
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:405
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:455
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:457
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:451
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:519
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:509
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:705
static bool classof(Type type)
Definition: SPIRVTypes.cpp:677
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:699
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:694
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:769
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:761
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:756
static bool classof(Type type)
Definition: SPIRVTypes.cpp:531
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:541
SPIR-V struct type.
Definition: SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition: SPIRVTypes.h:465
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:138
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:140
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:145
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:245
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:248
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:368
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:363
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:361
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:440
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:432
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:434
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:492
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:498
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:745
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:747
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:804
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:925
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:854
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:870
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:921
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:808
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:847
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:932
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:816
StructType::StructDecorationInfo const * structDecorationsInfo
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
Definition: SPIRVTypes.cpp:940
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
Definition: SPIRVTypes.cpp:960
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const