MLIR  16.0.0git
SPIRVTypes.h
Go to the documentation of this file.
//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
15 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/TypeSupport.h"
21 #include "mlir/IR/Types.h"
22 
23 #include <tuple>
24 
25 namespace mlir {
26 namespace spirv {
27 
namespace detail {
// Storage classes backing the SPIR-V types declared below. Only forward
// declarations are visible here; the definitions live elsewhere.
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;

} // namespace detail
40 
41 // Base SPIR-V type for providing availability queries.
42 class SPIRVType : public Type {
43 public:
44  using Type::Type;
45 
46  static bool classof(Type type);
47 
48  bool isScalarOrVector();
49 
50  /// The extension requirements for each type are following the
51  /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
52  /// convention.
54 
55  /// Appends to `extensions` the extensions needed for this type to appear in
56  /// the given `storage` class. This method does not guarantee the uniqueness
57  /// of extensions; the same extension may be appended multiple times.
58  void getExtensions(ExtensionArrayRefVector &extensions,
59  Optional<StorageClass> storage = llvm::None);
60 
61  /// The capability requirements for each type are following the
62  /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
63  /// convention.
65 
66  /// Appends to `capabilities` the capabilities needed for this type to appear
67  /// in the given `storage` class. This method does not guarantee the
68  /// uniqueness of capabilities; the same capability may be appended multiple
69  /// times.
70  void getCapabilities(CapabilityArrayRefVector &capabilities,
71  Optional<StorageClass> storage = llvm::None);
72 
73  /// Returns the size in bytes for each type. If no size can be calculated,
74  /// returns `llvm::None`. Note that if the type has explicit layout, it is
75  /// also taken into account in calculation.
77 };
78 
79 // SPIR-V scalar type: bool type, integer type, floating point type.
80 class ScalarType : public SPIRVType {
81 public:
82  using SPIRVType::SPIRVType;
83 
84  static bool classof(Type type);
85 
86  /// Returns true if the given integer type is valid for the SPIR-V dialect.
87  static bool isValid(FloatType);
88  /// Returns true if the given float type is valid for the SPIR-V dialect.
89  static bool isValid(IntegerType);
90 
91  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
92  Optional<StorageClass> storage = llvm::None);
93  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
94  Optional<StorageClass> storage = llvm::None);
95 
97 };
98 
99 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
100 class CompositeType : public SPIRVType {
101 public:
102  using SPIRVType::SPIRVType;
103 
104  static bool classof(Type type);
105 
106  /// Returns true if the given vector type is valid for the SPIR-V dialect.
107  static bool isValid(VectorType);
108 
109  /// Return the number of elements of the type. This should only be called if
110  /// hasCompileTimeKnownNumElements is true.
111  unsigned getNumElements() const;
112 
113  Type getElementType(unsigned) const;
114 
115  /// Return true if the number of elements is known at compile time and is not
116  /// implementation dependent.
117  bool hasCompileTimeKnownNumElements() const;
118 
119  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
120  Optional<StorageClass> storage = llvm::None);
121  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
122  Optional<StorageClass> storage = llvm::None);
123 
125 };
126 
127 // SPIR-V array type
128 class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
129  detail::ArrayTypeStorage> {
130 public:
131  using Base::Base;
132 
133  static ArrayType get(Type elementType, unsigned elementCount);
134 
135  /// Returns an array type with the given stride in bytes.
136  static ArrayType get(Type elementType, unsigned elementCount,
137  unsigned stride);
138 
139  unsigned getNumElements() const;
140 
141  Type getElementType() const;
142 
143  /// Returns the array stride in bytes. 0 means no stride decorated on this
144  /// type.
145  unsigned getArrayStride() const;
146 
147  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
148  Optional<StorageClass> storage = llvm::None);
149  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
150  Optional<StorageClass> storage = llvm::None);
151 
152  /// Returns the array size in bytes. Since array type may have an explicit
153  /// stride declaration (in bytes), we also include it in the calculation.
155 };
156 
157 // SPIR-V image type
159  : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
160 public:
161  using Base::Base;
162 
163  static ImageType
164  get(Type elementType, Dim dim,
165  ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
166  ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
167  ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
168  ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
169  ImageFormat format = ImageFormat::Unknown) {
170  return ImageType::get(
171  std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
172  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
173  elementType, dim, depth, arrayed, samplingInfo, samplerUse,
174  format));
175  }
176 
177  static ImageType
178  get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
179  ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
180 
181  Type getElementType() const;
182  Dim getDim() const;
183  ImageDepthInfo getDepthInfo() const;
184  ImageArrayedInfo getArrayedInfo() const;
185  ImageSamplingInfo getSamplingInfo() const;
186  ImageSamplerUseInfo getSamplerUseInfo() const;
187  ImageFormat getImageFormat() const;
188  // TODO: Add support for Access qualifier
189 
190  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
191  Optional<StorageClass> storage = llvm::None);
192  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
193  Optional<StorageClass> storage = llvm::None);
194 };
195 
196 // SPIR-V pointer type
197 class PointerType : public Type::TypeBase<PointerType, SPIRVType,
198  detail::PointerTypeStorage> {
199 public:
200  using Base::Base;
201 
202  static PointerType get(Type pointeeType, StorageClass storageClass);
203 
204  Type getPointeeType() const;
205 
206  StorageClass getStorageClass() const;
207 
208  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
209  Optional<StorageClass> storage = llvm::None);
210  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
211  Optional<StorageClass> storage = llvm::None);
212 };
213 
214 // SPIR-V run-time array type
216  : public Type::TypeBase<RuntimeArrayType, SPIRVType,
217  detail::RuntimeArrayTypeStorage> {
218 public:
219  using Base::Base;
220 
221  static RuntimeArrayType get(Type elementType);
222 
223  /// Returns a runtime array type with the given stride in bytes.
224  static RuntimeArrayType get(Type elementType, unsigned stride);
225 
226  Type getElementType() const;
227 
228  /// Returns the array stride in bytes. 0 means no stride decorated on this
229  /// type.
230  unsigned getArrayStride() const;
231 
232  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
233  Optional<StorageClass> storage = llvm::None);
234  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
235  Optional<StorageClass> storage = llvm::None);
236 };
237 
238 // SPIR-V sampled image type
240  : public Type::TypeBase<SampledImageType, SPIRVType,
241  detail::SampledImageTypeStorage> {
242 public:
243  using Base::Base;
244 
245  static SampledImageType get(Type imageType);
246 
247  static SampledImageType
248  getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
249 
251  Type imageType);
252 
253  Type getImageType() const;
254 
255  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
256  Optional<spirv::StorageClass> storage = llvm::None);
257  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
258  Optional<spirv::StorageClass> storage = llvm::None);
259 };
260 
261 /// SPIR-V struct type. Two kinds of struct types are supported:
262 /// - Literal: a literal struct type is uniqued by its fields (types + offset
263 /// info + decoration info).
264 /// - Identified: an indentified struct type is uniqued by its string identifier
265 /// (name). This is useful in representing recursive structs. For example, the
266 /// following C struct:
267 ///
268 /// struct A {
269 /// A* next;
270 /// };
271 ///
272 /// would be represented in MLIR as:
273 ///
274 /// !spirv.struct<A, (!spirv.ptr<!spirv.struct<A>, Generic>)>
275 ///
276 /// In the above, expressing recursive struct types is accomplished by giving a
277 /// recursive struct a unique identified and using that identifier in the struct
278 /// definition for recursive references.
280  : public Type::TypeBase<StructType, CompositeType,
281  detail::StructTypeStorage, TypeTrait::IsMutable> {
282 public:
283  using Base::Base;
284 
285  // Type for specifying the offset of the struct members
286  using OffsetInfo = uint32_t;
287 
288  // Type for specifying the decoration(s) on struct members
290  uint32_t memberIndex : 31;
291  uint32_t hasValue : 1;
292  Decoration decoration;
293  uint32_t decorationValue;
294 
295  MemberDecorationInfo(uint32_t index, uint32_t hasValue,
296  Decoration decoration, uint32_t decorationValue)
297  : memberIndex(index), hasValue(hasValue), decoration(decoration),
298  decorationValue(decorationValue) {}
299 
300  bool operator==(const MemberDecorationInfo &other) const {
301  return (this->memberIndex == other.memberIndex) &&
302  (this->decoration == other.decoration) &&
303  (this->decorationValue == other.decorationValue);
304  }
305 
306  bool operator<(const MemberDecorationInfo &other) const {
307  return this->memberIndex < other.memberIndex ||
308  (this->memberIndex == other.memberIndex &&
309  static_cast<uint32_t>(this->decoration) <
310  static_cast<uint32_t>(other.decoration));
311  }
312  };
313 
314  /// Construct a literal StructType with at least one member.
315  static StructType get(ArrayRef<Type> memberTypes,
316  ArrayRef<OffsetInfo> offsetInfo = {},
317  ArrayRef<MemberDecorationInfo> memberDecorations = {});
318 
319  /// Construct an identified StructType. This creates a StructType whose body
320  /// (member types, offset info, and decorations) is not set yet. A call to
321  /// StructType::trySetBody(...) must follow when the StructType contents are
322  /// available (e.g. parsed or deserialized).
323  ///
324  /// Note: If another thread creates (or had already created) a struct with the
325  /// same identifier, that struct will be returned as a result.
326  static StructType getIdentified(MLIRContext *context, StringRef identifier);
327 
328  /// Construct a (possibly identified) StructType with no members.
329  ///
330  /// Note: this method might fail in a multi-threaded setup if another thread
331  /// created an identified struct with the same identifier but with different
332  /// contents before returning. In which case, an empty (default-constructed)
333  /// StructType is returned.
334  static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
335 
336  /// For literal structs, return an empty string.
337  /// For identified structs, return the struct's identifier.
338  StringRef getIdentifier() const;
339 
340  /// Returns true if the StructType is identified.
341  bool isIdentified() const;
342 
343  unsigned getNumElements() const;
344 
345  Type getElementType(unsigned) const;
346 
347  /// Range class for element types.
349  : public ::llvm::detail::indexed_accessor_range_base<
350  ElementTypeRange, const Type *, Type, Type, Type> {
351  private:
352  using RangeBaseT::RangeBaseT;
353 
354  /// See `llvm::detail::indexed_accessor_range_base` for details.
355  static const Type *offset_base(const Type *object, ptrdiff_t index) {
356  return object + index;
357  }
358  /// See `llvm::detail::indexed_accessor_range_base` for details.
359  static Type dereference_iterator(const Type *object, ptrdiff_t index) {
360  return object[index];
361  }
362 
363  /// Allow base class access to `offset_base` and `dereference_iterator`.
364  friend RangeBaseT;
365  };
366 
367  ElementTypeRange getElementTypes() const;
368 
369  bool hasOffset() const;
370 
371  uint64_t getMemberOffset(unsigned) const;
372 
373  // Returns in `memberDecorations` the Decorations (apart from Offset)
374  // associated with all members of the StructType.
375  void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
376  &memberDecorations) const;
377 
378  // Returns in `decorationsInfo` all the Decorations (apart from Offset)
379  // associated with the `i`-th member of the StructType.
380  void getMemberDecorations(
381  unsigned i,
383 
384  /// Sets the contents of an incomplete identified StructType. This method must
385  /// be called only for identified StructTypes and it must be called only once
386  /// per instance. Otherwise, failure() is returned.
388  trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
389  ArrayRef<MemberDecorationInfo> memberDecorations = {});
390 
391  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
392  Optional<StorageClass> storage = llvm::None);
393  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
394  Optional<StorageClass> storage = llvm::None);
395 };
396 
397 llvm::hash_code
398 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
399 
400 // SPIR-V cooperative matrix type
402  : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
403  detail::CooperativeMatrixTypeStorage> {
404 public:
405  using Base::Base;
406 
407  static CooperativeMatrixNVType get(Type elementType, Scope scope,
408  unsigned rows, unsigned columns);
409  Type getElementType() const;
410 
411  /// Return the scope of the cooperative matrix.
412  Scope getScope() const;
413  /// return the number of rows of the matrix.
414  unsigned getRows() const;
415  /// return the number of columns of the matrix.
416  unsigned getColumns() const;
417 
418  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
419  Optional<StorageClass> storage = llvm::None);
420  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
421  Optional<StorageClass> storage = llvm::None);
422 };
423 
424 // SPIR-V joint matrix type
426  : public Type::TypeBase<JointMatrixINTELType, CompositeType,
427  detail::JointMatrixTypeStorage> {
428 public:
429  using Base::Base;
430 
431  static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
432  unsigned columns, MatrixLayout matrixLayout);
433  Type getElementType() const;
434 
435  /// Return the scope of the joint matrix.
436  Scope getScope() const;
437  /// return the number of rows of the matrix.
438  unsigned getRows() const;
439  /// return the number of columns of the matrix.
440  unsigned getColumns() const;
441 
442  /// return the layout of the matrix
443  MatrixLayout getMatrixLayout() const;
444 
445  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
446  Optional<StorageClass> storage = llvm::None);
447  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
448  Optional<StorageClass> storage = llvm::None);
449 };
450 
451 // SPIR-V matrix type
452 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
453  detail::MatrixTypeStorage> {
454 public:
455  using Base::Base;
456 
457  static MatrixType get(Type columnType, uint32_t columnCount);
458 
460  Type columnType, uint32_t columnCount);
461 
463  Type columnType, uint32_t columnCount);
464 
465  /// Returns true if the matrix elements are vectors of float elements.
466  static bool isValidColumnType(Type columnType);
467 
468  Type getColumnType() const;
469 
470  /// Returns the number of rows.
471  unsigned getNumRows() const;
472 
473  /// Returns the number of columns.
474  unsigned getNumColumns() const;
475 
476  /// Returns total number of elements (rows*columns).
477  unsigned getNumElements() const;
478 
479  /// Returns the elements' type (i.e, single element type).
480  Type getElementType() const;
481 
482  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
483  Optional<StorageClass> storage = llvm::None);
484  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
485  Optional<StorageClass> storage = llvm::None);
486 };
487 
488 } // namespace spirv
489 } // namespace mlir
490 
491 #endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
Include the generated interface declarations.
MemberDecorationInfo(uint32_t index, uint32_t hasValue, Decoration decoration, uint32_t decorationValue)
Definition: SPIRVTypes.h:295
constexpr Type()
Definition: Types.h:86
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:694
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Range class for element types.
Definition: SPIRVTypes.h:348
bool operator==(const MemberDecorationInfo &other) const
Definition: SPIRVTypes.h:300
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1174
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:164
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
SPIR-V struct type.
Definition: SPIRVTypes.h:279
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
bool operator<(const MemberDecorationInfo &other) const
Definition: SPIRVTypes.h:306
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)