MLIR  22.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 #include <cstdint>
21 
22 using namespace mlir;
23 using namespace mlir::spirv;
24 
25 //===----------------------------------------------------------------------===//
26 // ArrayType
27 //===----------------------------------------------------------------------===//
28 
30  using KeyTy = std::tuple<Type, unsigned, unsigned>;
31 
33  const KeyTy &key) {
34  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
35  }
36 
37  bool operator==(const KeyTy &key) const {
38  return key == KeyTy(elementType, elementCount, stride);
39  }
40 
41  ArrayTypeStorage(const KeyTy &key)
42  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
43  stride(std::get<2>(key)) {}
44 
46  unsigned elementCount;
47  unsigned stride;
48 };
49 
50 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
51  assert(elementCount && "ArrayType needs at least one element");
52  return Base::get(elementType.getContext(), elementType, elementCount,
53  /*stride=*/0);
54 }
55 
56 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
57  unsigned stride) {
58  assert(elementCount && "ArrayType needs at least one element");
59  return Base::get(elementType.getContext(), elementType, elementCount, stride);
60 }
61 
62 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
63 
64 Type ArrayType::getElementType() const { return getImpl()->elementType; }
65 
66 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
67 
69  std::optional<StorageClass> storage) {
70  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
71 }
72 
75  std::optional<StorageClass> storage) {
76  llvm::cast<SPIRVType>(getElementType())
77  .getCapabilities(capabilities, storage);
78 }
79 
80 std::optional<int64_t> ArrayType::getSizeInBytes() {
81  auto elementType = llvm::cast<SPIRVType>(getElementType());
82  std::optional<int64_t> size = elementType.getSizeInBytes();
83  if (!size)
84  return std::nullopt;
85  return (*size + getArrayStride()) * getNumElements();
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // CompositeType
90 //===----------------------------------------------------------------------===//
91 
93  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
94  return isValid(vectorType);
98 }
99 
100 bool CompositeType::isValid(VectorType type) {
101  return type.getRank() == 1 &&
102  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
103  llvm::isa<ScalarType>(type.getElementType());
104 }
105 
106 Type CompositeType::getElementType(unsigned index) const {
107  return TypeSwitch<Type, Type>(*this)
109  TensorArmType>([](auto type) { return type.getElementType(); })
110  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
111  .Case<StructType>(
112  [index](StructType type) { return type.getElementType(index); })
113  .Default(
114  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
115 }
116 
118  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
119  return arrayType.getNumElements();
120  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
121  return matrixType.getNumColumns();
122  if (auto structType = llvm::dyn_cast<StructType>(*this))
123  return structType.getNumElements();
124  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
125  return vectorType.getNumElements();
126  if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
127  return tensorArmType.getNumElements();
128  if (llvm::isa<CooperativeMatrixType>(*this)) {
129  llvm_unreachable(
130  "invalid to query number of elements of spirv Cooperative Matrix type");
131  }
132  if (llvm::isa<RuntimeArrayType>(*this)) {
133  llvm_unreachable(
134  "invalid to query number of elements of spirv::RuntimeArray type");
135  }
136  llvm_unreachable("invalid composite type");
137 }
138 
140  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142 
145  std::optional<StorageClass> storage) {
146  TypeSwitch<Type>(*this)
148  StructType>(
149  [&](auto type) { type.getExtensions(extensions, storage); })
150  .Case<VectorType>([&](VectorType type) {
151  return llvm::cast<ScalarType>(type.getElementType())
152  .getExtensions(extensions, storage);
153  })
154  .Case<TensorArmType>([&](TensorArmType type) {
155  static constexpr Extension ext{Extension::SPV_ARM_tensors};
156  extensions.push_back(ext);
157  return llvm::cast<ScalarType>(type.getElementType())
158  .getExtensions(extensions, storage);
159  })
160 
161  .Default([](Type) { llvm_unreachable("invalid composite type"); });
162 }
163 
166  std::optional<StorageClass> storage) {
167  TypeSwitch<Type>(*this)
169  StructType>(
170  [&](auto type) { type.getCapabilities(capabilities, storage); })
171  .Case<VectorType>([&](VectorType type) {
172  auto vecSize = getNumElements();
173  if (vecSize == 8 || vecSize == 16) {
174  static const Capability caps[] = {Capability::Vector16};
175  ArrayRef<Capability> ref(caps, std::size(caps));
176  capabilities.push_back(ref);
177  }
178  return llvm::cast<ScalarType>(type.getElementType())
179  .getCapabilities(capabilities, storage);
180  })
181  .Case<TensorArmType>([&](TensorArmType type) {
182  static constexpr Capability cap{Capability::TensorsARM};
183  capabilities.push_back(cap);
184  return llvm::cast<ScalarType>(type.getElementType())
185  .getCapabilities(capabilities, storage);
186  })
187  .Default([](Type) { llvm_unreachable("invalid composite type"); });
188 }
189 
190 std::optional<int64_t> CompositeType::getSizeInBytes() {
191  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
192  return arrayType.getSizeInBytes();
193  if (auto structType = llvm::dyn_cast<StructType>(*this))
194  return structType.getSizeInBytes();
195  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
196  std::optional<int64_t> elementSize =
197  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
198  if (!elementSize)
199  return std::nullopt;
200  return *elementSize * vectorType.getNumElements();
201  }
202  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
203  std::optional<int64_t> elementSize =
204  llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
205  if (!elementSize)
206  return std::nullopt;
207  return *elementSize * tensorArmType.getNumElements();
208  }
209  return std::nullopt;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // CooperativeMatrixType
214 //===----------------------------------------------------------------------===//
215 
217  // In the specification dimensions of the Cooperative Matrix are 32-bit
218  // integers --- the initial implementation kept those values as such. However,
219  // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
220  // as 32-bits and expose it as int64_t through `getShape`, however, this
221  // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
222  // 32-bits integers would require an extra logic and storage. So, we diverge
223  // from the spec and internally represent the dimensions as 64-bit integers,
224  // so we can easily return an `ArrayRef` from `getShape` without any extra
225  // logic. Alternatively, we could store both rows and columns (both 32-bits)
226  // and shape (64-bits), assigning rows and columns to shape whenever
227  // `getShape` is called. This would be at the cost of extra logic and storage.
228  // Note: Because `ArrayRef` is returned we cannot construct an object in
229  // `getShape` on the fly.
230  using KeyTy =
231  std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
232 
234  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
235  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
237  }
238 
239  bool operator==(const KeyTy &key) const {
240  return key == KeyTy(elementType, shape[0], shape[1], scope, use);
241  }
242 
244  : elementType(std::get<0>(key)),
245  shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
246  use(std::get<4>(key)) {}
247 
249  // [#rows, #columns]
250  std::array<int64_t, 2> shape;
251  Scope scope;
252  CooperativeMatrixUseKHR use;
253 };
254 
256  uint32_t rows,
257  uint32_t columns, Scope scope,
258  CooperativeMatrixUseKHR use) {
259  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
260  use);
261 }
262 
264  return getImpl()->elementType;
265 }
266 
268  assert(getImpl()->shape[0] != ShapedType::kDynamic);
269  return static_cast<uint32_t>(getImpl()->shape[0]);
270 }
271 
273  assert(getImpl()->shape[1] != ShapedType::kDynamic);
274  return static_cast<uint32_t>(getImpl()->shape[1]);
275 }
276 
278  return getImpl()->shape;
279 }
280 
281 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
282 
283 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284  return getImpl()->use;
285 }
286 
289  std::optional<StorageClass> storage) {
290  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
291  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292  extensions.push_back(exts);
293 }
294 
297  std::optional<StorageClass> storage) {
298  llvm::cast<SPIRVType>(getElementType())
299  .getCapabilities(capabilities, storage);
300  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
301  capabilities.push_back(caps);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // ImageType
306 //===----------------------------------------------------------------------===//
307 
308 template <typename T>
309 static constexpr unsigned getNumBits() {
310  return 0;
311 }
312 template <>
313 constexpr unsigned getNumBits<Dim>() {
314  static_assert((1 << 3) > getMaxEnumValForDim(),
315  "Not enough bits to encode Dim value");
316  return 3;
317 }
318 template <>
319 constexpr unsigned getNumBits<ImageDepthInfo>() {
320  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
321  "Not enough bits to encode ImageDepthInfo value");
322  return 2;
323 }
324 template <>
325 constexpr unsigned getNumBits<ImageArrayedInfo>() {
326  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
327  "Not enough bits to encode ImageArrayedInfo value");
328  return 1;
329 }
330 template <>
331 constexpr unsigned getNumBits<ImageSamplingInfo>() {
332  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
333  "Not enough bits to encode ImageSamplingInfo value");
334  return 1;
335 }
336 template <>
337 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
338  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
339  "Not enough bits to encode ImageSamplerUseInfo value");
340  return 2;
341 }
342 template <>
343 constexpr unsigned getNumBits<ImageFormat>() {
344  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
345  "Not enough bits to encode ImageFormat value");
346  return 6;
347 }
348 
350 public:
351  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
352  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
353 
355  const KeyTy &key) {
356  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
357  }
358 
359  bool operator==(const KeyTy &key) const {
360  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
361  samplerUseInfo, format);
362  }
363 
365  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
366  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
367  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
368  format(std::get<6>(key)) {}
369 
377 };
378 
379 ImageType
380 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
381  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
382  value) {
383  return Base::get(std::get<0>(value).getContext(), value);
384 }
385 
386 Type ImageType::getElementType() const { return getImpl()->elementType; }
387 
388 Dim ImageType::getDim() const { return getImpl()->dim; }
389 
390 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
391 
392 ImageArrayedInfo ImageType::getArrayedInfo() const {
393  return getImpl()->arrayedInfo;
394 }
395 
396 ImageSamplingInfo ImageType::getSamplingInfo() const {
397  return getImpl()->samplingInfo;
398 }
399 
400 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
401  return getImpl()->samplerUseInfo;
402 }
403 
404 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
405 
407  std::optional<StorageClass>) {
408  // Image types do not require extra extensions thus far.
409 }
410 
413  std::optional<StorageClass>) {
414  if (auto dimCaps = spirv::getCapabilities(getDim()))
415  capabilities.push_back(*dimCaps);
416 
417  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
418  capabilities.push_back(*fmtCaps);
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // PointerType
423 //===----------------------------------------------------------------------===//
424 
426  // (Type, StorageClass) as the key: Type stored in this struct, and
427  // StorageClass stored as TypeStorage's subclass data.
428  using KeyTy = std::pair<Type, StorageClass>;
429 
431  const KeyTy &key) {
432  return new (allocator.allocate<PointerTypeStorage>())
433  PointerTypeStorage(key);
434  }
435 
436  bool operator==(const KeyTy &key) const {
437  return key == KeyTy(pointeeType, storageClass);
438  }
439 
441  : pointeeType(key.first), storageClass(key.second) {}
442 
444  StorageClass storageClass;
445 };
446 
447 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
448  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
449 }
450 
451 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
452 
453 StorageClass PointerType::getStorageClass() const {
454  return getImpl()->storageClass;
455 }
456 
458  std::optional<StorageClass> storage) {
459  // Use this pointer type's storage class because this pointer indicates we are
460  // using the pointee type in that specific storage class.
461  llvm::cast<SPIRVType>(getPointeeType())
462  .getExtensions(extensions, getStorageClass());
463 
464  if (auto scExts = spirv::getExtensions(getStorageClass()))
465  extensions.push_back(*scExts);
466 }
467 
470  std::optional<StorageClass> storage) {
471  // Use this pointer type's storage class because this pointer indicates we are
472  // using the pointee type in that specific storage class.
473  llvm::cast<SPIRVType>(getPointeeType())
474  .getCapabilities(capabilities, getStorageClass());
475 
476  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
477  capabilities.push_back(*scCaps);
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // RuntimeArrayType
482 //===----------------------------------------------------------------------===//
483 
485  using KeyTy = std::pair<Type, unsigned>;
486 
488  const KeyTy &key) {
489  return new (allocator.allocate<RuntimeArrayTypeStorage>())
491  }
492 
493  bool operator==(const KeyTy &key) const {
494  return key == KeyTy(elementType, stride);
495  }
496 
498  : elementType(key.first), stride(key.second) {}
499 
501  unsigned stride;
502 };
503 
505  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
506 }
507 
508 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
509  return Base::get(elementType.getContext(), elementType, stride);
510 }
511 
512 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
513 
514 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
515 
518  std::optional<StorageClass> storage) {
519  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
520 }
521 
524  std::optional<StorageClass> storage) {
525  {
526  static const Capability caps[] = {Capability::Shader};
527  ArrayRef<Capability> ref(caps, std::size(caps));
528  capabilities.push_back(ref);
529  }
530  llvm::cast<SPIRVType>(getElementType())
531  .getCapabilities(capabilities, storage);
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // ScalarType
536 //===----------------------------------------------------------------------===//
537 
539  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
540  return isValid(floatType);
541  }
542  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
543  return isValid(intType);
544  }
545  return false;
546 }
547 
548 bool ScalarType::isValid(FloatType type) {
549  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
550 }
551 
552 bool ScalarType::isValid(IntegerType type) {
553  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
554 }
555 
557  std::optional<StorageClass> storage) {
558  if (isa<BFloat16Type>(*this)) {
559  static const Extension ext = Extension::SPV_KHR_bfloat16;
560  extensions.push_back(ext);
561  }
562 
563  // 8- or 16-bit integer/floating-point numbers will require extra extensions
564  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
565  // SPV_KHR_8bit_storage for more details.
566  if (!storage)
567  return;
568 
569  switch (*storage) {
570  case StorageClass::PushConstant:
571  case StorageClass::StorageBuffer:
572  case StorageClass::Uniform:
573  if (getIntOrFloatBitWidth() == 8) {
574  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575  ArrayRef<Extension> ref(exts, std::size(exts));
576  extensions.push_back(ref);
577  }
578  [[fallthrough]];
579  case StorageClass::Input:
580  case StorageClass::Output:
581  if (getIntOrFloatBitWidth() == 16) {
582  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583  ArrayRef<Extension> ref(exts, std::size(exts));
584  extensions.push_back(ref);
585  }
586  break;
587  default:
588  break;
589  }
590 }
591 
594  std::optional<StorageClass> storage) {
595  unsigned bitwidth = getIntOrFloatBitWidth();
596 
597  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
598  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
599  // SPV_KHR_8bit_storage for more details.
600 
601 #define STORAGE_CASE(storage, cap8, cap16) \
602  case StorageClass::storage: { \
603  if (bitwidth == 8) { \
604  static const Capability caps[] = {Capability::cap8}; \
605  ArrayRef<Capability> ref(caps, std::size(caps)); \
606  capabilities.push_back(ref); \
607  return; \
608  } \
609  if (bitwidth == 16) { \
610  static const Capability caps[] = {Capability::cap16}; \
611  ArrayRef<Capability> ref(caps, std::size(caps)); \
612  capabilities.push_back(ref); \
613  return; \
614  } \
615  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
616  /* storage classes. Fall through to the next section. */ \
617  } break
618 
619  // This part only handles the cases where special bitwidths appearing in
620  // interface storage classes.
621  if (storage) {
622  switch (*storage) {
623  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
624  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
625  StorageBuffer16BitAccess);
626  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
627  StorageUniform16);
628  case StorageClass::Input:
629  case StorageClass::Output: {
630  if (bitwidth == 16) {
631  static const Capability caps[] = {Capability::StorageInputOutput16};
632  ArrayRef<Capability> ref(caps, std::size(caps));
633  capabilities.push_back(ref);
634  return;
635  }
636  break;
637  }
638  default:
639  break;
640  }
641  }
642 #undef STORAGE_CASE
643 
644  // For other non-interface storage classes, require a different set of
645  // capabilities for special bitwidths.
646 
647 #define WIDTH_CASE(type, width) \
648  case width: { \
649  static const Capability caps[] = {Capability::type##width}; \
650  ArrayRef<Capability> ref(caps, std::size(caps)); \
651  capabilities.push_back(ref); \
652  } break
653 
654  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
655  switch (bitwidth) {
656  WIDTH_CASE(Int, 8);
657  WIDTH_CASE(Int, 16);
658  WIDTH_CASE(Int, 64);
659  case 1:
660  case 32:
661  break;
662  default:
663  llvm_unreachable("invalid bitwidth to getCapabilities");
664  }
665  } else {
666  assert(llvm::isa<FloatType>(*this));
667  switch (bitwidth) {
668  case 16: {
669  if (isa<BFloat16Type>(*this)) {
670  static const Capability cap = Capability::BFloat16TypeKHR;
671  capabilities.push_back(cap);
672  } else {
673  static const Capability cap = Capability::Float16;
674  capabilities.push_back(cap);
675  }
676  break;
677  }
678  WIDTH_CASE(Float, 64);
679  case 32:
680  break;
681  default:
682  llvm_unreachable("invalid bitwidth to getCapabilities");
683  }
684  }
685 
686 #undef WIDTH_CASE
687 }
688 
689 std::optional<int64_t> ScalarType::getSizeInBytes() {
690  auto bitWidth = getIntOrFloatBitWidth();
691  // According to the SPIR-V spec:
692  // "There is no physical size or bit pattern defined for values with boolean
693  // type. If they are stored (in conjunction with OpVariable), they can only
694  // be used with logical addressing operations, not physical, and only with
695  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
696  // Private, Function, Input, and Output."
697  if (bitWidth == 1)
698  return std::nullopt;
699  return bitWidth / 8;
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // SPIRVType
704 //===----------------------------------------------------------------------===//
705 
707  // Allow SPIR-V dialect types
708  if (llvm::isa<SPIRVDialect>(type.getDialect()))
709  return true;
710  if (llvm::isa<ScalarType>(type))
711  return true;
712  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
713  return CompositeType::isValid(vectorType);
714  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
715  return llvm::isa<ScalarType>(tensorArmType.getElementType());
716  return false;
717 }
718 
720  return isIntOrFloat() || llvm::isa<VectorType>(*this);
721 }
722 
724  std::optional<StorageClass> storage) {
725  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
726  scalarType.getExtensions(extensions, storage);
727  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
728  compositeType.getExtensions(extensions, storage);
729  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
730  imageType.getExtensions(extensions, storage);
731  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
732  sampledImageType.getExtensions(extensions, storage);
733  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
734  matrixType.getExtensions(extensions, storage);
735  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
736  ptrType.getExtensions(extensions, storage);
737  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
738  tensorArmType.getExtensions(extensions, storage);
739  } else {
740  llvm_unreachable("invalid SPIR-V Type to getExtensions");
741  }
742 }
743 
746  std::optional<StorageClass> storage) {
747  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
748  scalarType.getCapabilities(capabilities, storage);
749  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
750  compositeType.getCapabilities(capabilities, storage);
751  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
752  imageType.getCapabilities(capabilities, storage);
753  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
754  sampledImageType.getCapabilities(capabilities, storage);
755  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
756  matrixType.getCapabilities(capabilities, storage);
757  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
758  ptrType.getCapabilities(capabilities, storage);
759  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
760  tensorArmType.getCapabilities(capabilities, storage);
761  } else {
762  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
763  }
764 }
765 
766 std::optional<int64_t> SPIRVType::getSizeInBytes() {
767  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
768  return scalarType.getSizeInBytes();
769  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
770  return compositeType.getSizeInBytes();
771  return std::nullopt;
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // SampledImageType
776 //===----------------------------------------------------------------------===//
778  using KeyTy = Type;
779 
780  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
781 
782  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
783 
785  const KeyTy &key) {
786  return new (allocator.allocate<SampledImageTypeStorage>())
788  }
789 
791 };
792 
794  return Base::get(imageType.getContext(), imageType);
795 }
796 
799  Type imageType) {
800  return Base::getChecked(emitError, imageType.getContext(), imageType);
801 }
802 
803 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
804 
805 LogicalResult
807  Type imageType) {
808  auto image = dyn_cast<ImageType>(imageType);
809  if (!image)
810  return emitError() << "expected image type";
811 
812  // As per SPIR-V spec: "It [ImageType] must not have a Dim of SubpassData.
813  // Additionally, starting with version 1.6, it must not have a Dim of Buffer.
814  // ("3.3.6. Type-Declaration Instructions")
815  if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, image.getDim()))
816  return emitError() << "Dim must not be SubpassData or Buffer";
817 
818  return success();
819 }
820 
823  std::optional<StorageClass> storage) {
824  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
825 }
826 
829  std::optional<StorageClass> storage) {
830  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
831 }
832 
833 //===----------------------------------------------------------------------===//
834 // StructType
835 //===----------------------------------------------------------------------===//
836 
837 /// Type storage for SPIR-V structure types:
838 ///
839 /// Structures are uniqued using:
840 /// - for identified structs:
841 /// - a string identifier;
842 /// - for literal structs:
843 /// - a list of member types;
844 /// - a list of member offset info;
845 /// - a list of member decoration info;
846 /// - a list of struct decoration info.
847 ///
848 /// Identified structures only have a mutable component consisting of:
849 /// - a list of member types;
850 /// - a list of member offset info;
851 /// - a list of member decoration info;
852 /// - a list of struct decoration info.
854  /// Construct a storage object for an identified struct type. A struct type
855  /// associated with such storage must call StructType::trySetBody(...) later
856  /// in order to mutate the storage object providing the actual content.
857  StructTypeStorage(StringRef identifier)
858  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
859  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
860  numStructDecorations(0), structDecorationsInfo(nullptr),
861  identifier(identifier) {}
862 
863  /// Construct a storage object for a literal struct type. A struct type
864  /// associated with such storage is immutable.
866  unsigned numMembers, Type const *memberTypes,
867  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
868  StructType::MemberDecorationInfo const *memberDecorationsInfo,
869  unsigned numStructDecorations,
870  StructType::StructDecorationInfo const *structDecorationsInfo)
871  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
872  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
873  memberDecorationsInfo(memberDecorationsInfo),
874  numStructDecorations(numStructDecorations),
875  structDecorationsInfo(structDecorationsInfo) {}
876 
877  /// A storage key is divided into 2 parts:
878  /// - for identified structs:
879  /// - a StringRef representing the struct identifier;
880  /// - for literal structs:
881  /// - an ArrayRef<Type> for member types;
882  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
883  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
884  /// info;
885  /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
886  /// info.
887  ///
888  /// An identified struct type is uniqued only by the first part (field 0)
889  /// of the key.
890  ///
891  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
892  /// and 4) of the key. The identifier field (field 0) must be empty.
893  using KeyTy =
894  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
897 
898  /// For identified structs, return true if the given key contains the same
899  /// identifier.
900  ///
901  /// For literal structs, return true if the given key contains a matching list
902  /// of member types + offset info + decoration info.
903  bool operator==(const KeyTy &key) const {
904  if (isIdentified()) {
905  // Identified types are uniqued by their identifier.
906  return getIdentifier() == std::get<0>(key);
907  }
908 
909  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
910  getMemberDecorationsInfo(), getStructDecorationsInfo());
911  }
912 
913  /// If the given key contains a non-empty identifier, this method constructs
914  /// an identified struct and leaves the rest of the struct type data to be set
915  /// through a later call to StructType::trySetBody(...).
916  ///
917  /// If, on the other hand, the key contains an empty identifier, a literal
918  /// struct is constructed using the other fields of the key.
920  const KeyTy &key) {
921  StringRef keyIdentifier = std::get<0>(key);
922 
923  if (!keyIdentifier.empty()) {
924  StringRef identifier = allocator.copyInto(keyIdentifier);
925 
926  // Identified StructType body/members will be set through trySetBody(...)
927  // later.
928  return new (allocator.allocate<StructTypeStorage>())
929  StructTypeStorage(identifier);
930  }
931 
932  ArrayRef<Type> keyTypes = std::get<1>(key);
933 
934  // Copy the member type and layout information into the bump pointer
935  const Type *typesList = nullptr;
936  if (!keyTypes.empty()) {
937  typesList = allocator.copyInto(keyTypes).data();
938  }
939 
940  const StructType::OffsetInfo *offsetInfoList = nullptr;
941  if (!std::get<2>(key).empty()) {
942  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
943  assert(keyOffsetInfo.size() == keyTypes.size() &&
944  "size of offset information must be same as the size of number of "
945  "elements");
946  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
947  }
948 
949  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
950  unsigned numMemberDecorations = 0;
951  if (!std::get<3>(key).empty()) {
952  auto keyMemberDecorations = std::get<3>(key);
953  numMemberDecorations = keyMemberDecorations.size();
954  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
955  }
956 
957  const StructType::StructDecorationInfo *structDecorationList = nullptr;
958  unsigned numStructDecorations = 0;
959  if (!std::get<4>(key).empty()) {
960  auto keyStructDecorations = std::get<4>(key);
961  numStructDecorations = keyStructDecorations.size();
962  structDecorationList = allocator.copyInto(keyStructDecorations).data();
963  }
964 
965  return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
966  keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
967  memberDecorationList, numStructDecorations, structDecorationList);
968  }
969 
971  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
972  }
973 
975  if (offsetInfo) {
976  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
977  }
978  return {};
979  }
980 
982  if (memberDecorationsInfo) {
983  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
984  numMemberDecorations);
985  }
986  return {};
987  }
988 
990  if (structDecorationsInfo)
991  return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
992  numStructDecorations);
993  return {};
994  }
995 
996  StringRef getIdentifier() const { return identifier; }
997 
998  bool isIdentified() const { return !identifier.empty(); }
999 
1000  /// Sets the struct type content for identified structs. Calling this method
1001  /// is only valid for identified structs.
1002  ///
1003  /// Fails under the following conditions:
1004  /// - If called for a literal struct;
1005  /// - If called for an identified struct whose body was set before (through a
1006  /// call to this method) but with different contents from the passed
1007  /// arguments.
1008  LogicalResult
1009  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
1010  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
1011  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
1012  ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
1013  if (!isIdentified())
1014  return failure();
1015 
1016  if (memberTypesAndIsBodySet.getInt() &&
1017  (getMemberTypes() != structMemberTypes ||
1018  getOffsetInfo() != structOffsetInfo ||
1019  getMemberDecorationsInfo() != structMemberDecorationInfo ||
1020  getStructDecorationsInfo() != structDecorationInfo))
1021  return failure();
1022 
1023  memberTypesAndIsBodySet.setInt(true);
1024  numMembers = structMemberTypes.size();
1025 
1026  // Copy the member type and layout information into the bump pointer.
1027  if (!structMemberTypes.empty())
1028  memberTypesAndIsBodySet.setPointer(
1029  allocator.copyInto(structMemberTypes).data());
1030 
1031  if (!structOffsetInfo.empty()) {
1032  assert(structOffsetInfo.size() == structMemberTypes.size() &&
1033  "size of offset information must be same as the size of number of "
1034  "elements");
1035  offsetInfo = allocator.copyInto(structOffsetInfo).data();
1036  }
1037 
1038  if (!structMemberDecorationInfo.empty()) {
1039  numMemberDecorations = structMemberDecorationInfo.size();
1040  memberDecorationsInfo =
1041  allocator.copyInto(structMemberDecorationInfo).data();
1042  }
1043 
1044  if (!structDecorationInfo.empty()) {
1045  numStructDecorations = structDecorationInfo.size();
1046  structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
1047  }
1048 
1049  return success();
1050  }
1051 
1052  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1054  unsigned numMembers;
1059  StringRef identifier;
1060 };
1061 
1062 StructType
1066  ArrayRef<StructType::StructDecorationInfo> structDecorations) {
1067  assert(!memberTypes.empty() && "Struct needs at least one member type");
1068  // Sort the decorations.
1069  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
1070  memberDecorations);
1071  llvm::array_pod_sort(sortedMemberDecorations.begin(),
1072  sortedMemberDecorations.end());
1073  SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
1074  structDecorations);
1075  llvm::array_pod_sort(sortedStructDecorations.begin(),
1076  sortedStructDecorations.end());
1077 
1078  return Base::get(memberTypes.vec().front().getContext(),
1079  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1080  sortedMemberDecorations, sortedStructDecorations);
1081 }
1082 
1084  StringRef identifier) {
1085  assert(!identifier.empty() &&
1086  "StructType identifier must be non-empty string");
1087 
1088  return Base::get(context, identifier, ArrayRef<Type>(),
1092 }
1093 
1094 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1095  StructType newStructType = Base::get(
1096  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1099  // Set an empty body in case this is a identified struct.
1100  if (newStructType.isIdentified() &&
1101  failed(newStructType.trySetBody(
1105  return StructType();
1106 
1107  return newStructType;
1108 }
1109 
1110 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1111 
1112 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1113 
1114 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1115 
1116 Type StructType::getElementType(unsigned index) const {
1117  assert(getNumElements() > index && "member index out of range");
1118  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1119 }
1120 
1122  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1123  getNumElements());
1124 }
1125 
1126 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1127 
1128 bool StructType::hasDecoration(spirv::Decoration decoration) const {
1130  getImpl()->getStructDecorationsInfo())
1131  if (info.decoration == decoration)
1132  return true;
1133 
1134  return false;
1135 }
1136 
1137 uint64_t StructType::getMemberOffset(unsigned index) const {
1138  assert(getNumElements() > index && "member index out of range");
1139  return getImpl()->offsetInfo[index];
1140 }
1141 
1144  const {
1145  memberDecorations.clear();
1146  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1147  memberDecorations.append(implMemberDecorations.begin(),
1148  implMemberDecorations.end());
1149 }
1150 
1152  unsigned index,
1153  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1154  assert(getNumElements() > index && "member index out of range");
1155  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1156  decorationsInfo.clear();
1157  for (const auto &memberDecoration : memberDecorations) {
1158  if (memberDecoration.memberIndex == index) {
1159  decorationsInfo.push_back(memberDecoration);
1160  }
1161  if (memberDecoration.memberIndex > index) {
1162  // Early exit since the decorations are stored sorted.
1163  return;
1164  }
1165  }
1166 }
1167 
1170  const {
1171  structDecorations.clear();
1172  auto implDecorations = getImpl()->getStructDecorationsInfo();
1173  structDecorations.append(implDecorations.begin(), implDecorations.end());
1174 }
1175 
1176 LogicalResult
1178  ArrayRef<OffsetInfo> offsetInfo,
1179  ArrayRef<MemberDecorationInfo> memberDecorations,
1180  ArrayRef<StructDecorationInfo> structDecorations) {
1181  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1182  structDecorations);
1183 }
1184 
1186  std::optional<StorageClass> storage) {
1187  for (Type elementType : getElementTypes())
1188  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1189 }
1190 
1193  std::optional<StorageClass> storage) {
1194  for (Type elementType : getElementTypes())
1195  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1196 }
1197 
1198 llvm::hash_code spirv::hash_value(
1199  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1200  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1201  memberDecorationInfo.decoration);
1202 }
1203 
1204 llvm::hash_code spirv::hash_value(
1205  const StructType::StructDecorationInfo &structDecorationInfo) {
1206  return llvm::hash_value(structDecorationInfo.decoration);
1207 }
1208 
1209 //===----------------------------------------------------------------------===//
1210 // MatrixType
1211 //===----------------------------------------------------------------------===//
1212 
1214  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1215  : columnType(columnType), columnCount(columnCount) {}
1216 
1217  using KeyTy = std::tuple<Type, uint32_t>;
1218 
1220  const KeyTy &key) {
1221 
1222  // Initialize the memory using placement new.
1223  return new (allocator.allocate<MatrixTypeStorage>())
1224  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1225  }
1226 
1227  bool operator==(const KeyTy &key) const {
1228  return key == KeyTy(columnType, columnCount);
1229  }
1230 
1232  const uint32_t columnCount;
1233 };
1234 
1235 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1236  return Base::get(columnType.getContext(), columnType, columnCount);
1237 }
1238 
1240  Type columnType, uint32_t columnCount) {
1241  return Base::getChecked(emitError, columnType.getContext(), columnType,
1242  columnCount);
1243 }
1244 
1245 LogicalResult
1247  Type columnType, uint32_t columnCount) {
1248  if (columnCount < 2 || columnCount > 4)
1249  return emitError() << "matrix can have 2, 3, or 4 columns only";
1250 
1251  if (!isValidColumnType(columnType))
1252  return emitError() << "matrix columns must be vectors of floats";
1253 
1254  /// The underlying vectors (columns) must be of size 2, 3, or 4
1255  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1256  if (columnShape.size() != 1)
1257  return emitError() << "matrix columns must be 1D vectors";
1258 
1259  if (columnShape[0] < 2 || columnShape[0] > 4)
1260  return emitError() << "matrix columns must be of size 2, 3, or 4";
1261 
1262  return success();
1263 }
1264 
1265 /// Returns true if the matrix elements are vectors of float elements
1267  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1268  if (llvm::isa<FloatType>(vectorType.getElementType()))
1269  return true;
1270  }
1271  return false;
1272 }
1273 
1274 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1275 
1277  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1278 }
1279 
1280 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1281 
1282 unsigned MatrixType::getNumRows() const {
1283  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1284 }
1285 
1286 unsigned MatrixType::getNumElements() const {
1287  return (getImpl()->columnCount) * getNumRows();
1288 }
1289 
1291  std::optional<StorageClass> storage) {
1292  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1293 }
1294 
1297  std::optional<StorageClass> storage) {
1298  {
1299  static const Capability caps[] = {Capability::Matrix};
1300  ArrayRef<Capability> ref(caps, std::size(caps));
1301  capabilities.push_back(ref);
1302  }
1303  // Add any capabilities associated with the underlying vectors (i.e., columns)
1304  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // TensorArmType
1309 //===----------------------------------------------------------------------===//
1310 
1312  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1313 
1315  const KeyTy &key) {
1316  auto [shape, elementType] = key;
1317  shape = allocator.copyInto(shape);
1318  return new (allocator.allocate<TensorArmTypeStorage>())
1319  TensorArmTypeStorage(shape, elementType);
1320  }
1321 
1322  static llvm::hash_code hashKey(const KeyTy &key) {
1323  auto [shape, elementType] = key;
1324  return llvm::hash_combine(shape, elementType);
1325  }
1326 
1327  bool operator==(const KeyTy &key) const {
1328  return key == KeyTy(shape, elementType);
1329  }
1330 
1332  : shape(shape), elementType(elementType) {}
1333 
1336 };
1337 
1339  return Base::get(elementType.getContext(), shape, elementType);
1340 }
1341 
1343  Type elementType) const {
1344  return TensorArmType::get(shape.value_or(getShape()), elementType);
1345 }
1346 
1347 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1349 
1352  std::optional<StorageClass> storage) {
1353 
1354  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
1355  static constexpr Extension ext{Extension::SPV_ARM_tensors};
1356  extensions.push_back(ext);
1357 }
1358 
1361  std::optional<StorageClass> storage) {
1362  llvm::cast<SPIRVType>(getElementType())
1363  .getCapabilities(capabilities, storage);
1364  static constexpr Capability cap{Capability::TensorsARM};
1365  capabilities.push_back(cap);
1366 }
1367 
1368 LogicalResult
1370  ArrayRef<int64_t> shape, Type elementType) {
1371  if (llvm::is_contained(shape, 0))
1372  return emitError() << "arm.tensor do not support dimensions = 0";
1373  if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1374  llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1375  return emitError()
1376  << "arm.tensor shape dimensions must be either fully dynamic or "
1377  "completed shaped";
1378  return success();
1379 }
1380 
1381 //===----------------------------------------------------------------------===//
1382 // SPIR-V Dialect
1383 //===----------------------------------------------------------------------===//
1384 
1385 void SPIRVDialect::registerTypes() {
1388 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:337
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:343
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:309
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:325
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:331
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:313
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:319
int64_t rows
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:80
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:68
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:73
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:164
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:190
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:139
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:143
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:117
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:100
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:106
static bool classof(Type type)
Definition: SPIRVTypes.cpp:92
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:281
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:267
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:272
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:255
ArrayRef< int64_t > getShape() const
Definition: SPIRVTypes.cpp:277
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:287
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:295
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:283
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:390
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:392
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:404
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:400
Type getElementType() const
Definition: SPIRVTypes.cpp:386
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:396
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:411
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:406
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., the single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:457
Type getPointeeType() const
Definition: SPIRVTypes.cpp:451
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:453
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:468
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:522
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:514
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:516
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:504
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:766
static bool classof(Type type)
Definition: SPIRVTypes.cpp:706
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:744
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:723
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:806
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:827
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:798
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:821
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:793
static bool classof(Type type)
Definition: SPIRVTypes.cpp:538
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:556
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:548
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:592
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:689
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition: SPIRVTypes.h:524
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:30
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:32
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:37
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:231
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:234
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:359
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:354
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:352
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:436
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:428
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:430
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:487
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:493
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:782
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:784
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:853
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:974
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:903
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:919
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:970
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:857
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:896
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:981
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:865
StructType::StructDecorationInfo const * structDecorationsInfo
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
Definition: SPIRVTypes.cpp:989
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const