MLIR  22.0.0git
SPIRVTypes.cpp
Go to the documentation of this file.
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 #include <cstdint>
21 
22 using namespace mlir;
23 using namespace mlir::spirv;
24 
25 //===----------------------------------------------------------------------===//
26 // ArrayType
27 //===----------------------------------------------------------------------===//
28 
30  using KeyTy = std::tuple<Type, unsigned, unsigned>;
31 
33  const KeyTy &key) {
34  return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
35  }
36 
37  bool operator==(const KeyTy &key) const {
38  return key == KeyTy(elementType, elementCount, stride);
39  }
40 
41  ArrayTypeStorage(const KeyTy &key)
42  : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
43  stride(std::get<2>(key)) {}
44 
46  unsigned elementCount;
47  unsigned stride;
48 };
49 
50 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
51  assert(elementCount && "ArrayType needs at least one element");
52  return Base::get(elementType.getContext(), elementType, elementCount,
53  /*stride=*/0);
54 }
55 
56 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
57  unsigned stride) {
58  assert(elementCount && "ArrayType needs at least one element");
59  return Base::get(elementType.getContext(), elementType, elementCount, stride);
60 }
61 
62 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
63 
64 Type ArrayType::getElementType() const { return getImpl()->elementType; }
65 
66 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
67 
69  std::optional<StorageClass> storage) {
70  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
71 }
72 
75  std::optional<StorageClass> storage) {
76  llvm::cast<SPIRVType>(getElementType())
77  .getCapabilities(capabilities, storage);
78 }
79 
80 std::optional<int64_t> ArrayType::getSizeInBytes() {
81  auto elementType = llvm::cast<SPIRVType>(getElementType());
82  std::optional<int64_t> size = elementType.getSizeInBytes();
83  if (!size)
84  return std::nullopt;
85  return (*size + getArrayStride()) * getNumElements();
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // CompositeType
90 //===----------------------------------------------------------------------===//
91 
93  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
94  return isValid(vectorType);
98 }
99 
100 bool CompositeType::isValid(VectorType type) {
101  return type.getRank() == 1 &&
102  llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
103  llvm::isa<ScalarType>(type.getElementType());
104 }
105 
106 Type CompositeType::getElementType(unsigned index) const {
107  return TypeSwitch<Type, Type>(*this)
109  TensorArmType>([](auto type) { return type.getElementType(); })
110  .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
111  .Case<StructType>(
112  [index](StructType type) { return type.getElementType(index); })
113  .Default(
114  [](Type) -> Type { llvm_unreachable("invalid composite type"); });
115 }
116 
118  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
119  return arrayType.getNumElements();
120  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
121  return matrixType.getNumColumns();
122  if (auto structType = llvm::dyn_cast<StructType>(*this))
123  return structType.getNumElements();
124  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
125  return vectorType.getNumElements();
126  if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
127  return tensorArmType.getNumElements();
128  if (llvm::isa<CooperativeMatrixType>(*this)) {
129  llvm_unreachable(
130  "invalid to query number of elements of spirv Cooperative Matrix type");
131  }
132  if (llvm::isa<RuntimeArrayType>(*this)) {
133  llvm_unreachable(
134  "invalid to query number of elements of spirv::RuntimeArray type");
135  }
136  llvm_unreachable("invalid composite type");
137 }
138 
140  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142 
145  std::optional<StorageClass> storage) {
146  TypeSwitch<Type>(*this)
148  StructType>(
149  [&](auto type) { type.getExtensions(extensions, storage); })
150  .Case<VectorType>([&](VectorType type) {
151  return llvm::cast<ScalarType>(type.getElementType())
152  .getExtensions(extensions, storage);
153  })
154  .Case<TensorArmType>([&](TensorArmType type) {
155  static constexpr Extension ext{Extension::SPV_ARM_tensors};
156  extensions.push_back(ext);
157  return llvm::cast<ScalarType>(type.getElementType())
158  .getExtensions(extensions, storage);
159  })
160 
161  .Default([](Type) { llvm_unreachable("invalid composite type"); });
162 }
163 
166  std::optional<StorageClass> storage) {
167  TypeSwitch<Type>(*this)
169  StructType>(
170  [&](auto type) { type.getCapabilities(capabilities, storage); })
171  .Case<VectorType>([&](VectorType type) {
172  auto vecSize = getNumElements();
173  if (vecSize == 8 || vecSize == 16) {
174  static const Capability caps[] = {Capability::Vector16};
175  ArrayRef<Capability> ref(caps, std::size(caps));
176  capabilities.push_back(ref);
177  }
178  return llvm::cast<ScalarType>(type.getElementType())
179  .getCapabilities(capabilities, storage);
180  })
181  .Case<TensorArmType>([&](TensorArmType type) {
182  static constexpr Capability cap{Capability::TensorsARM};
183  capabilities.push_back(cap);
184  return llvm::cast<ScalarType>(type.getElementType())
185  .getCapabilities(capabilities, storage);
186  })
187  .Default([](Type) { llvm_unreachable("invalid composite type"); });
188 }
189 
190 std::optional<int64_t> CompositeType::getSizeInBytes() {
191  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
192  return arrayType.getSizeInBytes();
193  if (auto structType = llvm::dyn_cast<StructType>(*this))
194  return structType.getSizeInBytes();
195  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
196  std::optional<int64_t> elementSize =
197  llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
198  if (!elementSize)
199  return std::nullopt;
200  return *elementSize * vectorType.getNumElements();
201  }
202  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
203  std::optional<int64_t> elementSize =
204  llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
205  if (!elementSize)
206  return std::nullopt;
207  return *elementSize * tensorArmType.getNumElements();
208  }
209  return std::nullopt;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // CooperativeMatrixType
214 //===----------------------------------------------------------------------===//
215 
217  // In the specification dimensions of the Cooperative Matrix are 32-bit
218  // integers --- the initial implementation kept those values as such. However,
219  // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
220  // as 32-bits and expose it as int64_t through `getShape`, however, this
221  // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
222  // 32-bits integers would require an extra logic and storage. So, we diverge
223  // from the spec and internally represent the dimensions as 64-bit integers,
224  // so we can easily return an `ArrayRef` from `getShape` without any extra
225  // logic. Alternatively, we could store both rows and columns (both 32-bits)
226  // and shape (64-bits), assigning rows and columns to shape whenever
227  // `getShape` is called. This would be at the cost of extra logic and storage.
228  // Note: Because `ArrayRef` is returned we cannot construct an object in
229  // `getShape` on the fly.
230  using KeyTy =
231  std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
232 
234  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
235  return new (allocator.allocate<CooperativeMatrixTypeStorage>())
237  }
238 
239  bool operator==(const KeyTy &key) const {
240  return key == KeyTy(elementType, shape[0], shape[1], scope, use);
241  }
242 
244  : elementType(std::get<0>(key)),
245  shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
246  use(std::get<4>(key)) {}
247 
249  // [#rows, #columns]
250  std::array<int64_t, 2> shape;
251  Scope scope;
252  CooperativeMatrixUseKHR use;
253 };
254 
256  uint32_t rows,
257  uint32_t columns, Scope scope,
258  CooperativeMatrixUseKHR use) {
259  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
260  use);
261 }
262 
264  return getImpl()->elementType;
265 }
266 
268  assert(getImpl()->shape[0] != ShapedType::kDynamic);
269  return static_cast<uint32_t>(getImpl()->shape[0]);
270 }
271 
273  assert(getImpl()->shape[1] != ShapedType::kDynamic);
274  return static_cast<uint32_t>(getImpl()->shape[1]);
275 }
276 
278  return getImpl()->shape;
279 }
280 
281 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
282 
283 CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284  return getImpl()->use;
285 }
286 
289  std::optional<StorageClass> storage) {
290  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
291  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292  extensions.push_back(exts);
293 }
294 
297  std::optional<StorageClass> storage) {
298  llvm::cast<SPIRVType>(getElementType())
299  .getCapabilities(capabilities, storage);
300  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
301  capabilities.push_back(caps);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // ImageType
306 //===----------------------------------------------------------------------===//
307 
308 template <typename T>
309 static constexpr unsigned getNumBits() {
310  return 0;
311 }
312 template <>
313 constexpr unsigned getNumBits<Dim>() {
314  static_assert((1 << 3) > getMaxEnumValForDim(),
315  "Not enough bits to encode Dim value");
316  return 3;
317 }
318 template <>
319 constexpr unsigned getNumBits<ImageDepthInfo>() {
320  static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
321  "Not enough bits to encode ImageDepthInfo value");
322  return 2;
323 }
324 template <>
325 constexpr unsigned getNumBits<ImageArrayedInfo>() {
326  static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
327  "Not enough bits to encode ImageArrayedInfo value");
328  return 1;
329 }
330 template <>
331 constexpr unsigned getNumBits<ImageSamplingInfo>() {
332  static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
333  "Not enough bits to encode ImageSamplingInfo value");
334  return 1;
335 }
336 template <>
337 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
338  static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
339  "Not enough bits to encode ImageSamplerUseInfo value");
340  return 2;
341 }
342 template <>
343 constexpr unsigned getNumBits<ImageFormat>() {
344  static_assert((1 << 6) > getMaxEnumValForImageFormat(),
345  "Not enough bits to encode ImageFormat value");
346  return 6;
347 }
348 
350 public:
351  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
352  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
353 
355  const KeyTy &key) {
356  return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
357  }
358 
359  bool operator==(const KeyTy &key) const {
360  return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
361  samplerUseInfo, format);
362  }
363 
365  : elementType(std::get<0>(key)), dim(std::get<1>(key)),
366  depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
367  samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
368  format(std::get<6>(key)) {}
369 
377 };
378 
379 ImageType
380 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
381  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
382  value) {
383  return Base::get(std::get<0>(value).getContext(), value);
384 }
385 
386 Type ImageType::getElementType() const { return getImpl()->elementType; }
387 
388 Dim ImageType::getDim() const { return getImpl()->dim; }
389 
390 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
391 
392 ImageArrayedInfo ImageType::getArrayedInfo() const {
393  return getImpl()->arrayedInfo;
394 }
395 
396 ImageSamplingInfo ImageType::getSamplingInfo() const {
397  return getImpl()->samplingInfo;
398 }
399 
400 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
401  return getImpl()->samplerUseInfo;
402 }
403 
404 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
405 
407  std::optional<StorageClass>) {
408  // Image types do not require extra extensions thus far.
409 }
410 
413  std::optional<StorageClass>) {
414  if (auto dimCaps = spirv::getCapabilities(getDim()))
415  capabilities.push_back(*dimCaps);
416 
417  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
418  capabilities.push_back(*fmtCaps);
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // PointerType
423 //===----------------------------------------------------------------------===//
424 
426  // (Type, StorageClass) as the key: Type stored in this struct, and
427  // StorageClass stored as TypeStorage's subclass data.
428  using KeyTy = std::pair<Type, StorageClass>;
429 
431  const KeyTy &key) {
432  return new (allocator.allocate<PointerTypeStorage>())
433  PointerTypeStorage(key);
434  }
435 
436  bool operator==(const KeyTy &key) const {
437  return key == KeyTy(pointeeType, storageClass);
438  }
439 
441  : pointeeType(key.first), storageClass(key.second) {}
442 
444  StorageClass storageClass;
445 };
446 
447 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
448  return Base::get(pointeeType.getContext(), pointeeType, storageClass);
449 }
450 
451 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
452 
453 StorageClass PointerType::getStorageClass() const {
454  return getImpl()->storageClass;
455 }
456 
458  std::optional<StorageClass> storage) {
459  // Use this pointer type's storage class because this pointer indicates we are
460  // using the pointee type in that specific storage class.
461  llvm::cast<SPIRVType>(getPointeeType())
462  .getExtensions(extensions, getStorageClass());
463 
464  if (auto scExts = spirv::getExtensions(getStorageClass()))
465  extensions.push_back(*scExts);
466 }
467 
470  std::optional<StorageClass> storage) {
471  // Use this pointer type's storage class because this pointer indicates we are
472  // using the pointee type in that specific storage class.
473  llvm::cast<SPIRVType>(getPointeeType())
474  .getCapabilities(capabilities, getStorageClass());
475 
476  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
477  capabilities.push_back(*scCaps);
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // RuntimeArrayType
482 //===----------------------------------------------------------------------===//
483 
485  using KeyTy = std::pair<Type, unsigned>;
486 
488  const KeyTy &key) {
489  return new (allocator.allocate<RuntimeArrayTypeStorage>())
491  }
492 
493  bool operator==(const KeyTy &key) const {
494  return key == KeyTy(elementType, stride);
495  }
496 
498  : elementType(key.first), stride(key.second) {}
499 
501  unsigned stride;
502 };
503 
505  return Base::get(elementType.getContext(), elementType, /*stride=*/0);
506 }
507 
508 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
509  return Base::get(elementType.getContext(), elementType, stride);
510 }
511 
512 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
513 
514 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
515 
518  std::optional<StorageClass> storage) {
519  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
520 }
521 
524  std::optional<StorageClass> storage) {
525  {
526  static const Capability caps[] = {Capability::Shader};
527  ArrayRef<Capability> ref(caps, std::size(caps));
528  capabilities.push_back(ref);
529  }
530  llvm::cast<SPIRVType>(getElementType())
531  .getCapabilities(capabilities, storage);
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // ScalarType
536 //===----------------------------------------------------------------------===//
537 
539  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
540  return isValid(floatType);
541  }
542  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
543  return isValid(intType);
544  }
545  return false;
546 }
547 
548 bool ScalarType::isValid(FloatType type) {
549  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
550 }
551 
552 bool ScalarType::isValid(IntegerType type) {
553  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
554 }
555 
557  std::optional<StorageClass> storage) {
558  if (isa<BFloat16Type>(*this)) {
559  static const Extension ext = Extension::SPV_KHR_bfloat16;
560  extensions.push_back(ext);
561  }
562 
563  // 8- or 16-bit integer/floating-point numbers will require extra extensions
564  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
565  // SPV_KHR_8bit_storage for more details.
566  if (!storage)
567  return;
568 
569  switch (*storage) {
570  case StorageClass::PushConstant:
571  case StorageClass::StorageBuffer:
572  case StorageClass::Uniform:
573  if (getIntOrFloatBitWidth() == 8) {
574  static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575  ArrayRef<Extension> ref(exts, std::size(exts));
576  extensions.push_back(ref);
577  }
578  [[fallthrough]];
579  case StorageClass::Input:
580  case StorageClass::Output:
581  if (getIntOrFloatBitWidth() == 16) {
582  static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583  ArrayRef<Extension> ref(exts, std::size(exts));
584  extensions.push_back(ref);
585  }
586  break;
587  default:
588  break;
589  }
590 }
591 
594  std::optional<StorageClass> storage) {
595  unsigned bitwidth = getIntOrFloatBitWidth();
596 
597  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
598  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
599  // SPV_KHR_8bit_storage for more details.
600 
601 #define STORAGE_CASE(storage, cap8, cap16) \
602  case StorageClass::storage: { \
603  if (bitwidth == 8) { \
604  static const Capability caps[] = {Capability::cap8}; \
605  ArrayRef<Capability> ref(caps, std::size(caps)); \
606  capabilities.push_back(ref); \
607  return; \
608  } \
609  if (bitwidth == 16) { \
610  static const Capability caps[] = {Capability::cap16}; \
611  ArrayRef<Capability> ref(caps, std::size(caps)); \
612  capabilities.push_back(ref); \
613  return; \
614  } \
615  /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
616  /* storage classes. Fall through to the next section. */ \
617  } break
618 
619  // This part only handles the cases where special bitwidths appearing in
620  // interface storage classes.
621  if (storage) {
622  switch (*storage) {
623  STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
624  STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
625  StorageBuffer16BitAccess);
626  STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
627  StorageUniform16);
628  case StorageClass::Input:
629  case StorageClass::Output: {
630  if (bitwidth == 16) {
631  static const Capability caps[] = {Capability::StorageInputOutput16};
632  ArrayRef<Capability> ref(caps, std::size(caps));
633  capabilities.push_back(ref);
634  return;
635  }
636  break;
637  }
638  default:
639  break;
640  }
641  }
642 #undef STORAGE_CASE
643 
644  // For other non-interface storage classes, require a different set of
645  // capabilities for special bitwidths.
646 
647 #define WIDTH_CASE(type, width) \
648  case width: { \
649  static const Capability caps[] = {Capability::type##width}; \
650  ArrayRef<Capability> ref(caps, std::size(caps)); \
651  capabilities.push_back(ref); \
652  } break
653 
654  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
655  switch (bitwidth) {
656  WIDTH_CASE(Int, 8);
657  WIDTH_CASE(Int, 16);
658  WIDTH_CASE(Int, 64);
659  case 1:
660  case 32:
661  break;
662  default:
663  llvm_unreachable("invalid bitwidth to getCapabilities");
664  }
665  } else {
666  assert(llvm::isa<FloatType>(*this));
667  switch (bitwidth) {
668  case 16: {
669  if (isa<BFloat16Type>(*this)) {
670  static const Capability cap = Capability::BFloat16TypeKHR;
671  capabilities.push_back(cap);
672  } else {
673  static const Capability cap = Capability::Float16;
674  capabilities.push_back(cap);
675  }
676  break;
677  }
678  WIDTH_CASE(Float, 64);
679  case 32:
680  break;
681  default:
682  llvm_unreachable("invalid bitwidth to getCapabilities");
683  }
684  }
685 
686 #undef WIDTH_CASE
687 }
688 
689 std::optional<int64_t> ScalarType::getSizeInBytes() {
690  auto bitWidth = getIntOrFloatBitWidth();
691  // According to the SPIR-V spec:
692  // "There is no physical size or bit pattern defined for values with boolean
693  // type. If they are stored (in conjunction with OpVariable), they can only
694  // be used with logical addressing operations, not physical, and only with
695  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
696  // Private, Function, Input, and Output."
697  if (bitWidth == 1)
698  return std::nullopt;
699  return bitWidth / 8;
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // SPIRVType
704 //===----------------------------------------------------------------------===//
705 
707  // Allow SPIR-V dialect types
708  if (llvm::isa<SPIRVDialect>(type.getDialect()))
709  return true;
710  if (llvm::isa<ScalarType>(type))
711  return true;
712  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
713  return CompositeType::isValid(vectorType);
714  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
715  return llvm::isa<ScalarType>(tensorArmType.getElementType());
716  return false;
717 }
718 
720  return isIntOrFloat() || llvm::isa<VectorType>(*this);
721 }
722 
724  std::optional<StorageClass> storage) {
725  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
726  scalarType.getExtensions(extensions, storage);
727  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
728  compositeType.getExtensions(extensions, storage);
729  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
730  imageType.getExtensions(extensions, storage);
731  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
732  sampledImageType.getExtensions(extensions, storage);
733  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
734  matrixType.getExtensions(extensions, storage);
735  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
736  ptrType.getExtensions(extensions, storage);
737  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
738  tensorArmType.getExtensions(extensions, storage);
739  } else {
740  llvm_unreachable("invalid SPIR-V Type to getExtensions");
741  }
742 }
743 
746  std::optional<StorageClass> storage) {
747  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
748  scalarType.getCapabilities(capabilities, storage);
749  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
750  compositeType.getCapabilities(capabilities, storage);
751  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
752  imageType.getCapabilities(capabilities, storage);
753  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
754  sampledImageType.getCapabilities(capabilities, storage);
755  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
756  matrixType.getCapabilities(capabilities, storage);
757  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
758  ptrType.getCapabilities(capabilities, storage);
759  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
760  tensorArmType.getCapabilities(capabilities, storage);
761  } else {
762  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
763  }
764 }
765 
766 std::optional<int64_t> SPIRVType::getSizeInBytes() {
767  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
768  return scalarType.getSizeInBytes();
769  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
770  return compositeType.getSizeInBytes();
771  return std::nullopt;
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // SampledImageType
776 //===----------------------------------------------------------------------===//
778  using KeyTy = Type;
779 
780  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
781 
782  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
783 
785  const KeyTy &key) {
786  return new (allocator.allocate<SampledImageTypeStorage>())
788  }
789 
791 };
792 
794  return Base::get(imageType.getContext(), imageType);
795 }
796 
799  Type imageType) {
800  return Base::getChecked(emitError, imageType.getContext(), imageType);
801 }
802 
803 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
804 
805 LogicalResult
807  Type imageType) {
808  if (!llvm::isa<ImageType>(imageType))
809  return emitError() << "expected image type";
810 
811  return success();
812 }
813 
816  std::optional<StorageClass> storage) {
817  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
818 }
819 
822  std::optional<StorageClass> storage) {
823  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // StructType
828 //===----------------------------------------------------------------------===//
829 
/// Type storage for SPIR-V structure types:
///
/// Structures are uniqued using:
/// - for identified structs:
///   - a string identifier;
/// - for literal structs:
///   - a list of member types;
///   - a list of member offset info;
///   - a list of member decoration info;
///   - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
/// - a list of member decoration info;
/// - a list of struct decoration info.
847  /// Construct a storage object for an identified struct type. A struct type
848  /// associated with such storage must call StructType::trySetBody(...) later
849  /// in order to mutate the storage object providing the actual content.
850  StructTypeStorage(StringRef identifier)
851  : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
852  numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
853  numStructDecorations(0), structDecorationsInfo(nullptr),
854  identifier(identifier) {}
855 
856  /// Construct a storage object for a literal struct type. A struct type
857  /// associated with such storage is immutable.
859  unsigned numMembers, Type const *memberTypes,
860  StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
861  StructType::MemberDecorationInfo const *memberDecorationsInfo,
862  unsigned numStructDecorations,
863  StructType::StructDecorationInfo const *structDecorationsInfo)
864  : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
865  numMembers(numMembers), numMemberDecorations(numMemberDecorations),
866  memberDecorationsInfo(memberDecorationsInfo),
867  numStructDecorations(numStructDecorations),
868  structDecorationsInfo(structDecorationsInfo) {}
869 
870  /// A storage key is divided into 2 parts:
871  /// - for identified structs:
872  /// - a StringRef representing the struct identifier;
873  /// - for literal structs:
874  /// - an ArrayRef<Type> for member types;
875  /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
876  /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
877  /// info;
878  /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
879  /// info.
880  ///
881  /// An identified struct type is uniqued only by the first part (field 0)
882  /// of the key.
883  ///
884  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
885  /// and 4) of the key. The identifier field (field 0) must be empty.
886  using KeyTy =
887  std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
890 
891  /// For identified structs, return true if the given key contains the same
892  /// identifier.
893  ///
894  /// For literal structs, return true if the given key contains a matching list
895  /// of member types + offset info + decoration info.
896  bool operator==(const KeyTy &key) const {
897  if (isIdentified()) {
898  // Identified types are uniqued by their identifier.
899  return getIdentifier() == std::get<0>(key);
900  }
901 
902  return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
903  getMemberDecorationsInfo(), getStructDecorationsInfo());
904  }
905 
906  /// If the given key contains a non-empty identifier, this method constructs
907  /// an identified struct and leaves the rest of the struct type data to be set
908  /// through a later call to StructType::trySetBody(...).
909  ///
910  /// If, on the other hand, the key contains an empty identifier, a literal
911  /// struct is constructed using the other fields of the key.
913  const KeyTy &key) {
914  StringRef keyIdentifier = std::get<0>(key);
915 
916  if (!keyIdentifier.empty()) {
917  StringRef identifier = allocator.copyInto(keyIdentifier);
918 
919  // Identified StructType body/members will be set through trySetBody(...)
920  // later.
921  return new (allocator.allocate<StructTypeStorage>())
922  StructTypeStorage(identifier);
923  }
924 
925  ArrayRef<Type> keyTypes = std::get<1>(key);
926 
927  // Copy the member type and layout information into the bump pointer
928  const Type *typesList = nullptr;
929  if (!keyTypes.empty()) {
930  typesList = allocator.copyInto(keyTypes).data();
931  }
932 
933  const StructType::OffsetInfo *offsetInfoList = nullptr;
934  if (!std::get<2>(key).empty()) {
935  ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
936  assert(keyOffsetInfo.size() == keyTypes.size() &&
937  "size of offset information must be same as the size of number of "
938  "elements");
939  offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
940  }
941 
942  const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
943  unsigned numMemberDecorations = 0;
944  if (!std::get<3>(key).empty()) {
945  auto keyMemberDecorations = std::get<3>(key);
946  numMemberDecorations = keyMemberDecorations.size();
947  memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
948  }
949 
950  const StructType::StructDecorationInfo *structDecorationList = nullptr;
951  unsigned numStructDecorations = 0;
952  if (!std::get<4>(key).empty()) {
953  auto keyStructDecorations = std::get<4>(key);
954  numStructDecorations = keyStructDecorations.size();
955  structDecorationList = allocator.copyInto(keyStructDecorations).data();
956  }
957 
958  return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
959  keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
960  memberDecorationList, numStructDecorations, structDecorationList);
961  }
962 
964  return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
965  }
966 
968  if (offsetInfo) {
969  return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
970  }
971  return {};
972  }
973 
975  if (memberDecorationsInfo) {
976  return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
977  numMemberDecorations);
978  }
979  return {};
980  }
981 
983  if (structDecorationsInfo)
984  return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
985  numStructDecorations);
986  return {};
987  }
988 
989  StringRef getIdentifier() const { return identifier; }
990 
991  bool isIdentified() const { return !identifier.empty(); }
992 
993  /// Sets the struct type content for identified structs. Calling this method
994  /// is only valid for identified structs.
995  ///
996  /// Fails under the following conditions:
997  /// - If called for a literal struct;
998  /// - If called for an identified struct whose body was set before (through a
999  /// call to this method) but with different contents from the passed
1000  /// arguments.
1001  LogicalResult
1002  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
1003  ArrayRef<StructType::OffsetInfo> structOffsetInfo,
1004  ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
1005  ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
1006  if (!isIdentified())
1007  return failure();
1008 
1009  if (memberTypesAndIsBodySet.getInt() &&
1010  (getMemberTypes() != structMemberTypes ||
1011  getOffsetInfo() != structOffsetInfo ||
1012  getMemberDecorationsInfo() != structMemberDecorationInfo ||
1013  getStructDecorationsInfo() != structDecorationInfo))
1014  return failure();
1015 
1016  memberTypesAndIsBodySet.setInt(true);
1017  numMembers = structMemberTypes.size();
1018 
1019  // Copy the member type and layout information into the bump pointer.
1020  if (!structMemberTypes.empty())
1021  memberTypesAndIsBodySet.setPointer(
1022  allocator.copyInto(structMemberTypes).data());
1023 
1024  if (!structOffsetInfo.empty()) {
1025  assert(structOffsetInfo.size() == structMemberTypes.size() &&
1026  "size of offset information must be same as the size of number of "
1027  "elements");
1028  offsetInfo = allocator.copyInto(structOffsetInfo).data();
1029  }
1030 
1031  if (!structMemberDecorationInfo.empty()) {
1032  numMemberDecorations = structMemberDecorationInfo.size();
1033  memberDecorationsInfo =
1034  allocator.copyInto(structMemberDecorationInfo).data();
1035  }
1036 
1037  if (!structDecorationInfo.empty()) {
1038  numStructDecorations = structDecorationInfo.size();
1039  structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
1040  }
1041 
1042  return success();
1043  }
1044 
1045  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1047  unsigned numMembers;
1052  StringRef identifier;
1053 };
1054 
1055 StructType
1059  ArrayRef<StructType::StructDecorationInfo> structDecorations) {
1060  assert(!memberTypes.empty() && "Struct needs at least one member type");
1061  // Sort the decorations.
1062  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
1063  memberDecorations);
1064  llvm::array_pod_sort(sortedMemberDecorations.begin(),
1065  sortedMemberDecorations.end());
1066  SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
1067  structDecorations);
1068  llvm::array_pod_sort(sortedStructDecorations.begin(),
1069  sortedStructDecorations.end());
1070 
1071  return Base::get(memberTypes.vec().front().getContext(),
1072  /*identifier=*/StringRef(), memberTypes, offsetInfo,
1073  sortedMemberDecorations, sortedStructDecorations);
1074 }
1075 
1077  StringRef identifier) {
1078  assert(!identifier.empty() &&
1079  "StructType identifier must be non-empty string");
1080 
1081  return Base::get(context, identifier, ArrayRef<Type>(),
1085 }
1086 
1087 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1088  StructType newStructType = Base::get(
1089  context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1092  // Set an empty body in case this is a identified struct.
1093  if (newStructType.isIdentified() &&
1094  failed(newStructType.trySetBody(
1098  return StructType();
1099 
1100  return newStructType;
1101 }
1102 
1103 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1104 
1105 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1106 
1107 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1108 
1109 Type StructType::getElementType(unsigned index) const {
1110  assert(getNumElements() > index && "member index out of range");
1111  return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1112 }
1113 
1115  return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1116  getNumElements());
1117 }
1118 
1119 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1120 
1121 bool StructType::hasDecoration(spirv::Decoration decoration) const {
1123  getImpl()->getStructDecorationsInfo())
1124  if (info.decoration == decoration)
1125  return true;
1126 
1127  return false;
1128 }
1129 
1130 uint64_t StructType::getMemberOffset(unsigned index) const {
1131  assert(getNumElements() > index && "member index out of range");
1132  return getImpl()->offsetInfo[index];
1133 }
1134 
1137  const {
1138  memberDecorations.clear();
1139  auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1140  memberDecorations.append(implMemberDecorations.begin(),
1141  implMemberDecorations.end());
1142 }
1143 
1145  unsigned index,
1146  SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1147  assert(getNumElements() > index && "member index out of range");
1148  auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1149  decorationsInfo.clear();
1150  for (const auto &memberDecoration : memberDecorations) {
1151  if (memberDecoration.memberIndex == index) {
1152  decorationsInfo.push_back(memberDecoration);
1153  }
1154  if (memberDecoration.memberIndex > index) {
1155  // Early exit since the decorations are stored sorted.
1156  return;
1157  }
1158  }
1159 }
1160 
1163  const {
1164  structDecorations.clear();
1165  auto implDecorations = getImpl()->getStructDecorationsInfo();
1166  structDecorations.append(implDecorations.begin(), implDecorations.end());
1167 }
1168 
1169 LogicalResult
1171  ArrayRef<OffsetInfo> offsetInfo,
1172  ArrayRef<MemberDecorationInfo> memberDecorations,
1173  ArrayRef<StructDecorationInfo> structDecorations) {
1174  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
1175  structDecorations);
1176 }
1177 
1179  std::optional<StorageClass> storage) {
1180  for (Type elementType : getElementTypes())
1181  llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1182 }
1183 
1186  std::optional<StorageClass> storage) {
1187  for (Type elementType : getElementTypes())
1188  llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
1189 }
1190 
1191 llvm::hash_code spirv::hash_value(
1192  const StructType::MemberDecorationInfo &memberDecorationInfo) {
1193  return llvm::hash_combine(memberDecorationInfo.memberIndex,
1194  memberDecorationInfo.decoration);
1195 }
1196 
1197 llvm::hash_code spirv::hash_value(
1198  const StructType::StructDecorationInfo &structDecorationInfo) {
1199  return llvm::hash_value(structDecorationInfo.decoration);
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // MatrixType
1204 //===----------------------------------------------------------------------===//
1205 
1207  MatrixTypeStorage(Type columnType, uint32_t columnCount)
1208  : columnType(columnType), columnCount(columnCount) {}
1209 
1210  using KeyTy = std::tuple<Type, uint32_t>;
1211 
1213  const KeyTy &key) {
1214 
1215  // Initialize the memory using placement new.
1216  return new (allocator.allocate<MatrixTypeStorage>())
1217  MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1218  }
1219 
1220  bool operator==(const KeyTy &key) const {
1221  return key == KeyTy(columnType, columnCount);
1222  }
1223 
1225  const uint32_t columnCount;
1226 };
1227 
1228 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1229  return Base::get(columnType.getContext(), columnType, columnCount);
1230 }
1231 
1233  Type columnType, uint32_t columnCount) {
1234  return Base::getChecked(emitError, columnType.getContext(), columnType,
1235  columnCount);
1236 }
1237 
1238 LogicalResult
1240  Type columnType, uint32_t columnCount) {
1241  if (columnCount < 2 || columnCount > 4)
1242  return emitError() << "matrix can have 2, 3, or 4 columns only";
1243 
1244  if (!isValidColumnType(columnType))
1245  return emitError() << "matrix columns must be vectors of floats";
1246 
1247  /// The underlying vectors (columns) must be of size 2, 3, or 4
1248  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1249  if (columnShape.size() != 1)
1250  return emitError() << "matrix columns must be 1D vectors";
1251 
1252  if (columnShape[0] < 2 || columnShape[0] > 4)
1253  return emitError() << "matrix columns must be of size 2, 3, or 4";
1254 
1255  return success();
1256 }
1257 
1258 /// Returns true if the matrix elements are vectors of float elements
1260  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1261  if (llvm::isa<FloatType>(vectorType.getElementType()))
1262  return true;
1263  }
1264  return false;
1265 }
1266 
1267 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1268 
1270  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1271 }
1272 
1273 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1274 
1275 unsigned MatrixType::getNumRows() const {
1276  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1277 }
1278 
1279 unsigned MatrixType::getNumElements() const {
1280  return (getImpl()->columnCount) * getNumRows();
1281 }
1282 
1284  std::optional<StorageClass> storage) {
1285  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1286 }
1287 
1290  std::optional<StorageClass> storage) {
1291  {
1292  static const Capability caps[] = {Capability::Matrix};
1293  ArrayRef<Capability> ref(caps, std::size(caps));
1294  capabilities.push_back(ref);
1295  }
1296  // Add any capabilities associated with the underlying vectors (i.e., columns)
1297  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
1298 }
1299 
1300 //===----------------------------------------------------------------------===//
1301 // TensorArmType
1302 //===----------------------------------------------------------------------===//
1303 
1305  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1306 
1308  const KeyTy &key) {
1309  auto [shape, elementType] = key;
1310  shape = allocator.copyInto(shape);
1311  return new (allocator.allocate<TensorArmTypeStorage>())
1312  TensorArmTypeStorage(shape, elementType);
1313  }
1314 
1315  static llvm::hash_code hashKey(const KeyTy &key) {
1316  auto [shape, elementType] = key;
1317  return llvm::hash_combine(shape, elementType);
1318  }
1319 
1320  bool operator==(const KeyTy &key) const {
1321  return key == KeyTy(shape, elementType);
1322  }
1323 
1325  : shape(shape), elementType(elementType) {}
1326 
1329 };
1330 
1332  return Base::get(elementType.getContext(), shape, elementType);
1333 }
1334 
1336  Type elementType) const {
1337  return TensorArmType::get(shape.value_or(getShape()), elementType);
1338 }
1339 
1340 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1342 
1345  std::optional<StorageClass> storage) {
1346 
1347  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
1348  static constexpr Extension ext{Extension::SPV_ARM_tensors};
1349  extensions.push_back(ext);
1350 }
1351 
1354  std::optional<StorageClass> storage) {
1355  llvm::cast<SPIRVType>(getElementType())
1356  .getCapabilities(capabilities, storage);
1357  static constexpr Capability cap{Capability::TensorsARM};
1358  capabilities.push_back(cap);
1359 }
1360 
1361 LogicalResult
1363  ArrayRef<int64_t> shape, Type elementType) {
1364  if (llvm::is_contained(shape, 0))
1365  return emitError() << "arm.tensor do not support dimensions = 0";
1366  if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
1367  llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
1368  return emitError()
1369  << "arm.tensor shape dimensions must be either fully dynamic or "
1370  "completed shaped";
1371  return success();
1372 }
1373 
1374 //===----------------------------------------------------------------------===//
1375 // SPIR-V Dialect
1376 //===----------------------------------------------------------------------===//
1377 
1378 void SPIRVDialect::registerTypes() {
1381 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
Definition: SPIRVTypes.cpp:337
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
Definition: SPIRVTypes.cpp:343
static constexpr unsigned getNumBits()
Definition: SPIRVTypes.cpp:309
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
Definition: SPIRVTypes.cpp:325
constexpr unsigned getNumBits< ImageSamplingInfo >()
Definition: SPIRVTypes.cpp:331
constexpr unsigned getNumBits< Dim >()
Definition: SPIRVTypes.cpp:313
constexpr unsigned getNumBits< ImageDepthInfo >()
Definition: SPIRVTypes.cpp:319
int64_t rows
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Base storage class appearing in a Type.
Definition: TypeSupport.h:166
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
Definition: SPIRVTypes.cpp:80
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:68
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:73
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:164
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:190
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
Definition: SPIRVTypes.cpp:139
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:143
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:117
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:100
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:106
static bool classof(Type type)
Definition: SPIRVTypes.cpp:92
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:281
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:267
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:272
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:255
ArrayRef< int64_t > getShape() const
Definition: SPIRVTypes.cpp:277
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:287
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:295
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:283
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:390
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:392
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:404
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:400
Type getElementType() const
Definition: SPIRVTypes.cpp:386
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:396
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:411
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:406
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:457
Type getPointeeType() const
Definition: SPIRVTypes.cpp:451
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:453
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:468
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:522
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:514
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:516
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:504
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:766
static bool classof(Type type)
Definition: SPIRVTypes.cpp:706
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:744
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
Definition: SPIRVTypes.cpp:723
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:806
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:820
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Definition: SPIRVTypes.cpp:798
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:814
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:793
static bool classof(Type type)
Definition: SPIRVTypes.cpp:538
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:556
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:548
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:592
std::optional< int64_t > getSizeInBytes()
Definition: SPIRVTypes.cpp:689
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition: SPIRVTypes.h:524
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
ArrayRef< int64_t > getShape() const
TensorArmType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::tuple< Type, unsigned, unsigned > KeyTy
Definition: SPIRVTypes.cpp:30
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:32
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:37
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
Definition: SPIRVTypes.cpp:231
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:234
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:359
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:354
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
Definition: SPIRVTypes.cpp:352
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:436
std::pair< Type, StorageClass > KeyTy
Definition: SPIRVTypes.cpp:428
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:430
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:487
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:493
bool operator==(const KeyTy &key) const
Definition: SPIRVTypes.cpp:782
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Definition: SPIRVTypes.cpp:784
Type storage for SPIR-V structure types:
Definition: SPIRVTypes.cpp:846
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
Definition: SPIRVTypes.cpp:967
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
Definition: SPIRVTypes.cpp:896
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and lea...
Definition: SPIRVTypes.cpp:912
ArrayRef< Type > getMemberTypes() const
Definition: SPIRVTypes.cpp:963
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
Definition: SPIRVTypes.cpp:850
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo >, ArrayRef< StructType::StructDecorationInfo > > KeyTy
A storage key is divided into 2 parts:
Definition: SPIRVTypes.cpp:889
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
Definition: SPIRVTypes.cpp:974
StructType::MemberDecorationInfo const * memberDecorationsInfo
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo, unsigned numStructDecorations, StructType::StructDecorationInfo const *structDecorationsInfo)
Construct a storage object for a literal struct type.
Definition: SPIRVTypes.cpp:858
StructType::StructDecorationInfo const * structDecorationsInfo
ArrayRef< StructType::StructDecorationInfo > getStructDecorationsInfo() const
Definition: SPIRVTypes.cpp:982
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo, ArrayRef< StructType::StructDecorationInfo > structDecorationInfo)
Sets the struct type content for identified structs.
static TensorArmTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
static llvm::hash_code hashKey(const KeyTy &key)
std::tuple< ArrayRef< int64_t >, Type > KeyTy
TensorArmTypeStorage(ArrayRef< int64_t > shape, Type elementType)
bool operator==(const KeyTy &key) const