MLIR  15.0.0git
SPIRVTypes.h
Go to the documentation of this file.
1 //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
15 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/TypeSupport.h"
21 #include "mlir/IR/Types.h"
22 
23 #include <tuple>
24 
25 namespace mlir {
26 namespace spirv {
27 
28 namespace detail {
   // Forward declarations of the storage structs that are named as the last
   // template argument of the Type::TypeBase instantiations later in this file.
29 struct ArrayTypeStorage;
30 struct CooperativeMatrixTypeStorage;
31 struct ImageTypeStorage;
32 struct MatrixTypeStorage;
33 struct PointerTypeStorage;
34 struct RuntimeArrayTypeStorage;
35 struct SampledImageTypeStorage;
36 struct StructTypeStorage;
37 
38 } // namespace detail
39 
40 // Base SPIR-V type for providing availability queries.
41 class SPIRVType : public Type {
42 public:
43  using Type::Type;
44 
45  static bool classof(Type type);
46 
47  bool isScalarOrVector();
48 
49  /// The extension requirements for each type are following the
50  /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
51  /// convention.
    // NOTE(review): listing line 52 is absent from this view; the
    // `ExtensionArrayRefVector` type used below is presumably declared there --
    // confirm against the original header.
53 
54  /// Appends to `extensions` the extensions needed for this type to appear in
55  /// the given `storage` class. This method does not guarantee the uniqueness
56  /// of extensions; the same extension may be appended multiple times.
57  void getExtensions(ExtensionArrayRefVector &extensions,
58  Optional<StorageClass> storage = llvm::None);
59 
60  /// The capability requirements for each type are following the
61  /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
62  /// convention.
    // NOTE(review): listing line 63 is absent from this view; the
    // `CapabilityArrayRefVector` type used below is presumably declared there --
    // confirm against the original header.
64 
65  /// Appends to `capabilities` the capabilities needed for this type to appear
66  /// in the given `storage` class. This method does not guarantee the
67  /// uniqueness of capabilities; the same capability may be appended multiple
68  /// times.
69  void getCapabilities(CapabilityArrayRefVector &capabilities,
70  Optional<StorageClass> storage = llvm::None);
71 
72  /// Returns the size in bytes for each type. If no size can be calculated,
73  /// returns `llvm::None`. Note that if the type has explicit layout, it is
74  /// also taken into account in calculation.
    // NOTE(review): the declaration this comment documents (listing line 75) is
    // absent from this view.
76 };
77 
78 // SPIR-V scalar type: bool type, integer type, floating point type.
79 class ScalarType : public SPIRVType {
80 public:
81  using SPIRVType::SPIRVType;
82 
83  static bool classof(Type type);
84 
85  /// Returns true if the given float type is valid for the SPIR-V dialect.
86  static bool isValid(FloatType);
87  /// Returns true if the given integer type is valid for the SPIR-V dialect.
88  static bool isValid(IntegerType);
89 
    /// Same parameters as the SPIRVType members of the same names; `storage`
    /// defaults to `llvm::None`.
90  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
91  Optional<StorageClass> storage = llvm::None);
92  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
93  Optional<StorageClass> storage = llvm::None);
94 
    // NOTE(review): listing line 95 is absent from this view.
96 };
97 
98 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
   // NOTE(review): CooperativeMatrixNVType and MatrixType in this file also pass
   // CompositeType as the second argument of Type::TypeBase, so the list above
   // looks incomplete -- confirm.
99 class CompositeType : public SPIRVType {
100 public:
101  using SPIRVType::SPIRVType;
102 
103  static bool classof(Type type);
104 
105  /// Returns true if the given vector type is valid for the SPIR-V dialect.
106  static bool isValid(VectorType);
107 
108  /// Return the number of elements of the type. This should only be called if
109  /// hasCompileTimeKnownNumElements is true.
110  unsigned getNumElements() const;
111 
     /// Returns an element type; the unnamed `unsigned` argument is presumably
     /// the element index -- TODO confirm.
112  Type getElementType(unsigned) const;
113 
114  /// Return true if the number of elements is known at compile time and is not
115  /// implementation dependent.
116  bool hasCompileTimeKnownNumElements() const;
117 
118  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
119  Optional<StorageClass> storage = llvm::None);
120  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
121  Optional<StorageClass> storage = llvm::None);
122 
     // NOTE(review): listing line 123 is absent from this view.
124 };
125 
126 // SPIR-V array type
127 class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
128  detail::ArrayTypeStorage> {
129 public:
130  using Base::Base;
131 
     /// Returns an array type with the given element type and element count.
     /// This overload takes no stride argument.
132  static ArrayType get(Type elementType, unsigned elementCount);
133 
134  /// Returns an array type with the given stride in bytes.
135  static ArrayType get(Type elementType, unsigned elementCount,
136  unsigned stride);
137 
     /// Accessors for the element count and element type of this array type.
138  unsigned getNumElements() const;
139 
140  Type getElementType() const;
141 
142  /// Returns the array stride in bytes. 0 means no stride decorated on this
143  /// type.
144  unsigned getArrayStride() const;
145 
146  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
147  Optional<StorageClass> storage = llvm::None);
148  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
149  Optional<StorageClass> storage = llvm::None);
150 
151  /// Returns the array size in bytes. Since array type may have an explicit
152  /// stride declaration (in bytes), we also include it in the calculation.
     // NOTE(review): the declaration this comment documents (listing line 153)
     // is absent from this view.
154 };
155 
156 // SPIR-V image type
    // NOTE(review): listing line 157 (presumably the `class ImageType` head) is
    // absent from this view; the base-class clause below continues it.
158  : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
159 public:
160  using Base::Base;
161 
     /// Convenience overload: packs the seven arguments, in declaration order,
     /// into a std::tuple and forwards to the tuple overload of get() below.
     /// Every argument after `dim` has a default (DepthUnknown, NonArrayed,
     /// SingleSampled, SamplerUnknown, ImageFormat::Unknown).
162  static ImageType
163  get(Type elementType, Dim dim,
164  ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
165  ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
166  ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
167  ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
168  ImageFormat format = ImageFormat::Unknown) {
169  return ImageType::get(
170  std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
171  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
172  elementType, dim, depth, arrayed, samplingInfo, samplerUse,
173  format));
174  }
175 
     /// Tuple overload of get(); declared here only, the definition is outside
     /// this view.
176  static ImageType
177  get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
178  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
179 
     /// Accessors, one per parameter of get() above.
180  Type getElementType() const;
181  Dim getDim() const;
182  ImageDepthInfo getDepthInfo() const;
183  ImageArrayedInfo getArrayedInfo() const;
184  ImageSamplingInfo getSamplingInfo() const;
185  ImageSamplerUseInfo getSamplerUseInfo() const;
186  ImageFormat getImageFormat() const;
187  // TODO: Add support for Access qualifier
188 
189  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
190  Optional<StorageClass> storage = llvm::None);
191  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
192  Optional<StorageClass> storage = llvm::None);
193 };
194 
195 // SPIR-V pointer type
196 class PointerType : public Type::TypeBase<PointerType, SPIRVType,
197  detail::PointerTypeStorage> {
198 public:
199  using Base::Base;
200 
     /// Returns a pointer type built from the given pointee type and storage
     /// class.
201  static PointerType get(Type pointeeType, StorageClass storageClass);
202 
     /// Accessors for the two parameters of get().
203  Type getPointeeType() const;
204 
205  StorageClass getStorageClass() const;
206 
207  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
208  Optional<StorageClass> storage = llvm::None);
209  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
210  Optional<StorageClass> storage = llvm::None);
211 };
212 
213 // SPIR-V run-time array type
    // NOTE(review): listing line 214 (presumably the `class RuntimeArrayType`
    // head) is absent from this view; the base-class clause below continues it.
215  : public Type::TypeBase<RuntimeArrayType, SPIRVType,
216  detail::RuntimeArrayTypeStorage> {
217 public:
218  using Base::Base;
219 
     /// Returns a runtime array type with the given element type. This overload
     /// takes no stride argument.
220  static RuntimeArrayType get(Type elementType);
221 
222  /// Returns a runtime array type with the given stride in bytes.
223  static RuntimeArrayType get(Type elementType, unsigned stride);
224 
225  Type getElementType() const;
226 
227  /// Returns the array stride in bytes. 0 means no stride decorated on this
228  /// type.
229  unsigned getArrayStride() const;
230 
231  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
232  Optional<StorageClass> storage = llvm::None);
233  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
234  Optional<StorageClass> storage = llvm::None);
235 };
236 
237 // SPIR-V sampled image type
    // NOTE(review): listing line 238 (presumably the `class SampledImageType`
    // head) is absent from this view; the base-class clause below continues it.
239  : public Type::TypeBase<SampledImageType, SPIRVType,
240  detail::SampledImageTypeStorage> {
241 public:
242  using Base::Base;
243 
     /// Returns a sampled image type wrapping the given image type.
244  static SampledImageType get(Type imageType);
245 
     /// Variant of get() that additionally takes an `emitError` callback
     /// producing an InFlightDiagnostic.
246  static SampledImageType
247  getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
248 
     // NOTE(review): listing line 249 is absent from this view; the line below is
     // the tail of a declaration whose head is not visible here.
250  Type imageType);
251 
252  Type getImageType() const;
253 
254  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
255  Optional<spirv::StorageClass> storage = llvm::None);
256  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
257  Optional<spirv::StorageClass> storage = llvm::None);
258 };
259 
260 /// SPIR-V struct type. Two kinds of struct types are supported:
261 /// - Literal: a literal struct type is uniqued by its fields (types + offset
262 /// info + decoration info).
263 /// - Identified: an identified struct type is uniqued by its string identifier
264 /// (name). This is useful in representing recursive structs. For example, the
265 /// following C struct:
266 ///
267 /// struct A {
268 /// A* next;
269 /// };
270 ///
271 /// would be represented in MLIR as:
272 ///
273 /// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)>
274 ///
275 /// In the above, expressing recursive struct types is accomplished by giving a
276 /// recursive struct a unique identifier and using that identifier in the struct
277 /// definition for recursive references.
278 class StructType : public Type::TypeBase<StructType, CompositeType,
279  detail::StructTypeStorage> {
280 public:
281  using Base::Base;
282 
283  // Type for specifying the offset of the struct members
284  using OffsetInfo = uint32_t;
285 
286  // Type for specifying the decoration(s) on struct members
     // NOTE(review): listing line 287 (presumably the `struct
     // MemberDecorationInfo {` head, given the constructor name below) is absent
     // from this view.
     // `memberIndex` and `hasValue` are bit-fields of 31 and 1 bits.
288  uint32_t memberIndex : 31;
289  uint32_t hasValue : 1;
290  Decoration decoration;
291  uint32_t decorationValue;
292 
     // Initializes all four fields from the corresponding arguments.
293  MemberDecorationInfo(uint32_t index, uint32_t hasValue,
294  Decoration decoration, uint32_t decorationValue)
295  : memberIndex(index), hasValue(hasValue), decoration(decoration),
296  decorationValue(decorationValue) {}
297 
     // Equality compares memberIndex, decoration and decorationValue.
     // NOTE(review): `hasValue` is not part of the comparison -- confirm this is
     // intended.
298  bool operator==(const MemberDecorationInfo &other) const {
299  return (this->memberIndex == other.memberIndex) &&
300  (this->decoration == other.decoration) &&
301  (this->decorationValue == other.decorationValue);
302  }
303 
     // Orders by memberIndex first, then by the decoration cast to uint32_t.
     // decorationValue and hasValue do not participate in the ordering.
304  bool operator<(const MemberDecorationInfo &other) const {
305  return this->memberIndex < other.memberIndex ||
306  (this->memberIndex == other.memberIndex &&
307  static_cast<uint32_t>(this->decoration) <
308  static_cast<uint32_t>(other.decoration));
309  }
310  };
311 
312  /// Construct a literal StructType with at least one member.
313  static StructType get(ArrayRef<Type> memberTypes,
314  ArrayRef<OffsetInfo> offsetInfo = {},
315  ArrayRef<MemberDecorationInfo> memberDecorations = {});
316 
317  /// Construct an identified StructType. This creates a StructType whose body
318  /// (member types, offset info, and decorations) is not set yet. A call to
319  /// StructType::trySetBody(...) must follow when the StructType contents are
320  /// available (e.g. parsed or deserialized).
321  ///
322  /// Note: If another thread creates (or had already created) a struct with the
323  /// same identifier, that struct will be returned as a result.
324  static StructType getIdentified(MLIRContext *context, StringRef identifier);
325 
326  /// Construct a (possibly identified) StructType with no members.
327  ///
328  /// Note: this method might fail in a multi-threaded setup if another thread
329  /// created an identified struct with the same identifier but with different
330  /// contents before returning. In which case, an empty (default-constructed)
331  /// StructType is returned.
332  static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
333 
334  /// For literal structs, return an empty string.
335  /// For identified structs, return the struct's identifier.
336  StringRef getIdentifier() const;
337 
338  /// Returns true if the StructType is identified.
339  bool isIdentified() const;
340 
341  unsigned getNumElements() const;
342 
     /// Returns an element type; the unnamed `unsigned` argument is presumably
     /// the member index -- TODO confirm.
343  Type getElementType(unsigned) const;
344 
345  /// Range class for element types.
     // NOTE(review): listing line 346 (presumably the `class ElementTypeRange`
     // head, given the template argument below) is absent from this view.
347  : public ::llvm::detail::indexed_accessor_range_base<
348  ElementTypeRange, const Type *, Type, Type, Type> {
349  private:
350  using RangeBaseT::RangeBaseT;
351 
352  /// See `llvm::detail::indexed_accessor_range_base` for details.
353  static const Type *offset_base(const Type *object, ptrdiff_t index) {
354  return object + index;
355  }
356  /// See `llvm::detail::indexed_accessor_range_base` for details.
357  static Type dereference_iterator(const Type *object, ptrdiff_t index) {
358  return object[index];
359  }
360 
361  /// Allow base class access to `offset_base` and `dereference_iterator`.
362  friend RangeBaseT;
363  };
364 
365  ElementTypeRange getElementTypes() const;
366 
367  bool hasOffset() const;
368 
369  uint64_t getMemberOffset(unsigned) const;
370 
371  // Returns in `memberDecorations` the Decorations (apart from Offset)
372  // associated with all members of the StructType.
373  void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
374  &memberDecorations) const;
375 
376  // Returns in `decorationsInfo` all the Decorations (apart from Offset)
377  // associated with the `i`-th member of the StructType.
378  void getMemberDecorations(
379  unsigned i,
     // NOTE(review): listing line 380 (the rest of this parameter list) is absent
     // from this view.
381 
382  /// Sets the contents of an incomplete identified StructType. This method must
383  /// be called only for identified StructTypes and it must be called only once
384  /// per instance. Otherwise, failure() is returned.
     // NOTE(review): listing line 385 (the return type of trySetBody) is absent
     // from this view.
386  trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
387  ArrayRef<MemberDecorationInfo> memberDecorations = {});
388 
389  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
390  Optional<StorageClass> storage = llvm::None);
391  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
392  Optional<StorageClass> storage = llvm::None);
393 };
394 
/// Returns an llvm::hash_code for the given StructType::MemberDecorationInfo.
/// Only declared here; the definition is outside this view.
395 llvm::hash_code
396 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
397 
398 // SPIR-V cooperative matrix type
    // NOTE(review): listing line 399 (presumably the `class
    // CooperativeMatrixNVType` head) is absent from this view; the base-class
    // clause below continues it.
400  : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
401  detail::CooperativeMatrixTypeStorage> {
402 public:
403  using Base::Base;
404 
     /// Returns a cooperative matrix type built from the given element type,
     /// scope, row count and column count.
405  static CooperativeMatrixNVType get(Type elementType, Scope scope,
406  unsigned rows, unsigned columns);
407  Type getElementType() const;
408 
409  /// Return the scope of the cooperative matrix.
410  Scope getScope() const;
411  /// Return the number of rows of the matrix.
412  unsigned getRows() const;
413  /// Return the number of columns of the matrix.
414  unsigned getColumns() const;
415 
416  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
417  Optional<StorageClass> storage = llvm::None);
418  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
419  Optional<StorageClass> storage = llvm::None);
420 };
421 
422 // SPIR-V matrix type
423 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
424  detail::MatrixTypeStorage> {
425 public:
426  using Base::Base;
427 
     /// Returns a matrix type built from the given column type and column count.
428  static MatrixType get(Type columnType, uint32_t columnCount);
429 
     // NOTE(review): listing lines 430 and 433 are absent from this view; the two
     // lines below are the tails of declarations whose heads are not visible.
431  Type columnType, uint32_t columnCount);
432 
434  Type columnType, uint32_t columnCount);
435 
436  /// Returns true if the matrix elements are vectors of float elements.
437  static bool isValidColumnType(Type columnType);
438 
439  Type getColumnType() const;
440 
441  /// Returns the number of rows.
442  unsigned getNumRows() const;
443 
444  /// Returns the number of columns.
445  unsigned getNumColumns() const;
446 
447  /// Returns total number of elements (rows*columns).
448  unsigned getNumElements() const;
449 
450  /// Returns the elements' type (i.e., single element type).
451  Type getElementType() const;
452 
453  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
454  Optional<StorageClass> storage = llvm::None);
455  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
456  Optional<StorageClass> storage = llvm::None);
457 };
458 
459 } // namespace spirv
460 } // namespace mlir
461 
462 #endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
Include the generated interface declarations.
MemberDecorationInfo(uint32_t index, uint32_t hasValue, Decoration decoration, uint32_t decorationValue)
Definition: SPIRVTypes.h:293
constexpr Type()
Definition: Types.h:84
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:311
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:687
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Range class for element types.
Definition: SPIRVTypes.h:346
bool operator==(const MemberDecorationInfo &other) const
Definition: SPIRVTypes.h:298
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:751
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:163
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
SPIR-V struct type.
Definition: SPIRVTypes.h:278
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
bool operator<(const MemberDecorationInfo &other) const
Definition: SPIRVTypes.h:304
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)