MLIR 22.0.0git
SPIRVTypes.h
Go to the documentation of this file.
1//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
15
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/Location.h"
20#include "mlir/IR/TypeSupport.h"
21#include "mlir/IR/Types.h"
22
23#include <cstdint>
24#include <tuple>
25
26namespace mlir {
27namespace spirv {
28
namespace detail {
// Storage classes backing the SPIR-V dialect types declared in this header.
// They are only forward-declared here; each is named as the storage parameter
// of one of the Type::TypeBase instantiations below.
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
struct TensorArmTypeStorage;

} // namespace detail
41
42// Base SPIR-V type for providing availability queries.
43class SPIRVType : public Type {
44public:
45 using Type::Type;
46
47 static bool classof(Type type);
48
49 bool isScalarOrVector();
50
51 /// The extension requirements for each type are following the
52 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
53 /// convention.
55
56 /// Appends to `extensions` the extensions needed for this type to appear in
57 /// the given `storage` class. This method does not guarantee the uniqueness
58 /// of extensions; the same extension may be appended multiple times.
60 std::optional<StorageClass> storage = std::nullopt);
61
62 /// The capability requirements for each type are following the
63 /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
64 /// convention.
66
67 /// Appends to `capabilities` the capabilities needed for this type to appear
68 /// in the given `storage` class. This method does not guarantee the
69 /// uniqueness of capabilities; the same capability may be appended multiple
70 /// times.
72 std::optional<StorageClass> storage = std::nullopt);
73
74 /// Returns the size in bytes for each type. If no size can be calculated,
75 /// returns `std::nullopt`. Note that if the type has explicit layout, it is
76 /// also taken into account in calculation.
77 std::optional<int64_t> getSizeInBytes();
78};
79
80// SPIR-V scalar type: bool type, integer type, floating point type.
81class ScalarType : public SPIRVType {
82public:
83 using SPIRVType::SPIRVType;
84
85 static bool classof(Type type);
86
87 /// Returns true if the given integer type is valid for the SPIR-V dialect.
88 static bool isValid(FloatType);
89 /// Returns true if the given float type is valid for the SPIR-V dialect.
90 static bool isValid(IntegerType);
91};
92
93// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
94// StructType, or SPIR-V TensorArmType.
95class CompositeType : public SPIRVType {
96public:
97 using SPIRVType::SPIRVType;
98
99 static bool classof(Type type);
100
101 /// Returns true if the given vector type is valid for the SPIR-V dialect.
102 static bool isValid(VectorType);
103
104 /// Return the number of elements of the type. This should only be called if
105 /// hasCompileTimeKnownNumElements is true.
106 unsigned getNumElements() const;
107
108 Type getElementType(unsigned) const;
109
110 /// Return true if the number of elements is known at compile time and is not
111 /// implementation dependent.
113};
114
115// SPIR-V array type
116class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
117 detail::ArrayTypeStorage> {
118public:
119 using Base::Base;
120
121 static constexpr StringLiteral name = "spirv.array";
122
123 static ArrayType get(Type elementType, unsigned elementCount);
124
125 /// Returns an array type with the given stride in bytes.
126 static ArrayType get(Type elementType, unsigned elementCount,
127 unsigned stride);
128
129 unsigned getNumElements() const;
130
131 Type getElementType() const;
132
133 /// Returns the array stride in bytes. 0 means no stride decorated on this
134 /// type.
135 unsigned getArrayStride() const;
136};
137
138// SPIR-V image type
140 : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
141public:
142 using Base::Base;
143
144 static constexpr StringLiteral name = "spirv.image";
145
146 static ImageType
147 get(Type elementType, Dim dim,
148 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
149 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
150 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
151 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
152 ImageFormat format = ImageFormat::Unknown) {
153 return ImageType::get(
154 std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
155 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
156 elementType, dim, depth, arrayed, samplingInfo, samplerUse,
157 format));
158 }
159
160 static ImageType
161 get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
162 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
163
164 Type getElementType() const;
165 Dim getDim() const;
166 ImageDepthInfo getDepthInfo() const;
167 ImageArrayedInfo getArrayedInfo() const;
168 ImageSamplingInfo getSamplingInfo() const;
169 ImageSamplerUseInfo getSamplerUseInfo() const;
170 ImageFormat getImageFormat() const;
171 // TODO: Add support for Access qualifier
172};
173
174// SPIR-V pointer type
175class PointerType : public Type::TypeBase<PointerType, SPIRVType,
176 detail::PointerTypeStorage> {
177public:
178 using Base::Base;
179
180 static constexpr StringLiteral name = "spirv.pointer";
181
182 static PointerType get(Type pointeeType, StorageClass storageClass);
183
184 Type getPointeeType() const;
185
186 StorageClass getStorageClass() const;
187};
188
189// SPIR-V run-time array type
191 : public Type::TypeBase<RuntimeArrayType, SPIRVType,
192 detail::RuntimeArrayTypeStorage> {
193public:
194 using Base::Base;
195
196 static constexpr StringLiteral name = "spirv.rtarray";
197
198 static RuntimeArrayType get(Type elementType);
199
200 /// Returns a runtime array type with the given stride in bytes.
201 static RuntimeArrayType get(Type elementType, unsigned stride);
202
203 Type getElementType() const;
204
205 /// Returns the array stride in bytes. 0 means no stride decorated on this
206 /// type.
207 unsigned getArrayStride() const;
208};
209
210// SPIR-V sampled image type
212 : public Type::TypeBase<SampledImageType, SPIRVType,
213 detail::SampledImageTypeStorage> {
214public:
215 using Base::Base;
216
217 static constexpr StringLiteral name = "spirv.sampled_image";
218
219 static SampledImageType get(Type imageType);
220
221 static SampledImageType
223
224 static LogicalResult
226 Type imageType);
227
228 Type getImageType() const;
229};
230
231/// SPIR-V struct type. Two kinds of struct types are supported:
232/// - Literal: a literal struct type is uniqued by its fields (types + offset
233/// info + decoration info).
234/// - Identified: an indentified struct type is uniqued by its string identifier
235/// (name). This is useful in representing recursive structs. For example, the
236/// following C struct:
237///
238/// struct A {
239/// A* next;
240/// };
241///
242/// would be represented in MLIR as:
243///
244/// !spirv.struct<A, (!spirv.ptr<!spirv.struct<A>, Generic>)>
245///
246/// In the above, expressing recursive struct types is accomplished by giving a
247/// recursive struct a unique identified and using that identifier in the struct
248/// definition for recursive references.
250 : public Type::TypeBase<StructType, CompositeType,
251 detail::StructTypeStorage, TypeTrait::IsMutable> {
252public:
253 using Base::Base;
254
255 // Type for specifying the offset of the struct members
256 using OffsetInfo = uint32_t;
257
258 static constexpr StringLiteral name = "spirv.struct";
259
260 // Type for specifying the decoration(s) on struct members.
261 // If `decorationValue` is UnitAttr then decoration has no
262 // value.
264 uint32_t memberIndex;
265 Decoration decoration;
267
272
274 const MemberDecorationInfo &rhs) {
275 return lhs.memberIndex == rhs.memberIndex &&
276 lhs.decoration == rhs.decoration &&
277 lhs.decorationValue == rhs.decorationValue;
278 }
279
281 const MemberDecorationInfo &rhs) {
282 return std::tuple(lhs.memberIndex, llvm::to_underlying(lhs.decoration)) <
283 std::tuple(rhs.memberIndex, llvm::to_underlying(rhs.decoration));
284 }
285
286 bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
287 };
288
289 // Type for specifying the decoration(s) on the struct itself.
291 Decoration decoration;
293
296
298 const StructDecorationInfo &rhs) {
299 return lhs.decoration == rhs.decoration &&
300 lhs.decorationValue == rhs.decorationValue;
301 }
302
304 const StructDecorationInfo &rhs) {
305 return llvm::to_underlying(lhs.decoration) <
306 llvm::to_underlying(rhs.decoration);
307 }
308
309 bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
310 };
311
312 /// Construct a literal StructType with at least one member.
313 static StructType get(ArrayRef<Type> memberTypes,
314 ArrayRef<OffsetInfo> offsetInfo = {},
315 ArrayRef<MemberDecorationInfo> memberDecorations = {},
316 ArrayRef<StructDecorationInfo> structDecorations = {});
317
318 /// Construct an identified StructType. This creates a StructType whose body
319 /// (member types, offset info, and decorations) is not set yet. A call to
320 /// StructType::trySetBody(...) must follow when the StructType contents are
321 /// available (e.g. parsed or deserialized).
322 ///
323 /// Note: If another thread creates (or had already created) a struct with the
324 /// same identifier, that struct will be returned as a result.
325 static StructType getIdentified(MLIRContext *context, StringRef identifier);
326
327 /// Construct a (possibly identified) StructType with no members.
328 ///
329 /// Note: this method might fail in a multi-threaded setup if another thread
330 /// created an identified struct with the same identifier but with different
331 /// contents before returning. In which case, an empty (default-constructed)
332 /// StructType is returned.
333 static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
334
335 /// For literal structs, return an empty string.
336 /// For identified structs, return the struct's identifier.
337 StringRef getIdentifier() const;
338
339 /// Returns true if the StructType is identified.
340 bool isIdentified() const;
341
342 unsigned getNumElements() const;
343
344 Type getElementType(unsigned) const;
345
347
348 bool hasOffset() const;
349
350 /// Returns true if the struct has a specified decoration.
351 bool hasDecoration(spirv::Decoration decoration) const;
352
353 uint64_t getMemberOffset(unsigned) const;
354
355 // Returns in `memberDecorations` the Decorations (apart from Offset)
356 // associated with all members of the StructType.
357 void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
358 &memberDecorations) const;
359
360 // Returns in `decorationsInfo` all the Decorations (apart from Offset)
361 // associated with the `i`-th member of the StructType.
363 unsigned i,
364 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
365
366 // Returns in `structDecorations` the Decorations associated with the
367 // StructType.
368 void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
369 &structDecorations) const;
370
371 /// Sets the contents of an incomplete identified StructType. This method must
372 /// be called only for identified StructTypes and it must be called only once
373 /// per instance. Otherwise, failure() is returned.
374 LogicalResult
375 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
376 ArrayRef<MemberDecorationInfo> memberDecorations = {},
377 ArrayRef<StructDecorationInfo> structDecorations = {});
378};
379
380llvm::hash_code
381hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
382
383llvm::hash_code
384hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
385
386// SPIR-V KHR cooperative matrix type
388 : public Type::TypeBase<CooperativeMatrixType, CompositeType,
389 detail::CooperativeMatrixTypeStorage,
390 ShapedType::Trait> {
391public:
392 using Base::Base;
393
394 static constexpr StringLiteral name = "spirv.coopmatrix";
395
396 static CooperativeMatrixType get(Type elementType, uint32_t rows,
397 uint32_t columns, Scope scope,
398 CooperativeMatrixUseKHR use);
399 Type getElementType() const;
400
401 /// Returns the scope of the matrix.
402 Scope getScope() const;
403 /// Returns the number of rows of the matrix.
404 uint32_t getRows() const;
405 /// Returns the number of columns of the matrix.
406 uint32_t getColumns() const;
407 /// Returns the use parameter of the cooperative matrix.
408 CooperativeMatrixUseKHR getUse() const;
409
410 operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
411
413
414 bool hasRank() const { return true; }
415
417 Type elementType) const {
418 if (!shape)
419 return get(elementType, getRows(), getColumns(), getScope(), getUse());
420
421 assert(shape.value().size() == 2);
422 return get(elementType, shape.value()[0], shape.value()[1], getScope(),
423 getUse());
424 }
425};
426
427// SPIR-V matrix type
428class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
429 detail::MatrixTypeStorage> {
430public:
431 using Base::Base;
432
433 static constexpr StringLiteral name = "spirv.matrix";
434
435 static MatrixType get(Type columnType, uint32_t columnCount);
436
438 Type columnType, uint32_t columnCount);
439
440 static LogicalResult
442 Type columnType, uint32_t columnCount);
443
444 /// Returns true if the matrix elements are vectors of float elements.
445 static bool isValidColumnType(Type columnType);
446
447 Type getColumnType() const;
448
449 /// Returns the number of rows.
450 unsigned getNumRows() const;
451
452 /// Returns the number of columns.
453 unsigned getNumColumns() const;
454
455 /// Returns total number of elements (rows*columns).
456 unsigned getNumElements() const;
457
458 /// Returns the elements' type (i.e, single element type).
459 Type getElementType() const;
460};
461
462/// SPIR-V TensorARM Type
464 : public Type::TypeBase<TensorArmType, CompositeType,
465 detail::TensorArmTypeStorage, ShapedType::Trait> {
466public:
467 using Base::Base;
468
469 using ShapedTypeTraits = ShapedType::Trait<TensorArmType>;
470 using ShapedTypeTraits::getDimSize;
471 using ShapedTypeTraits::getDynamicDimIndex;
472 using ShapedTypeTraits::getElementTypeBitWidth;
473 using ShapedTypeTraits::getNumDynamicDims;
474 using ShapedTypeTraits::getNumElements;
475 using ShapedTypeTraits::getRank;
476 using ShapedTypeTraits::hasStaticShape;
477 using ShapedTypeTraits::isDynamicDim;
478
479 static constexpr StringLiteral name = "spirv.arm.tensor";
480
481 // TensorArm supports minimum rank of 1, hence an empty shape here means
482 // unranked.
483 static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
485 Type elementType) const;
486
487 static LogicalResult
489 ArrayRef<int64_t> shape, Type elementType);
490
491 Type getElementType() const;
493 bool hasRank() const { return !getShape().empty(); }
494 operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
495};
496
497} // namespace spirv
498} // namespace mlir
499
500#endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
lhs
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
constexpr Type()=default
detail::StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > TypeBase
Utility class for implementing types.
Definition Types.h:79
StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > Base
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
static constexpr StringLiteral name
Definition SPIRVTypes.h:121
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
CooperativeMatrixType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
Definition SPIRVTypes.h:416
static constexpr StringLiteral name
Definition SPIRVTypes.h:394
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static constexpr StringLiteral name
Definition SPIRVTypes.h:144
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e, single element type).
static constexpr StringLiteral name
Definition SPIRVTypes.h:433
unsigned getNumRows() const
Returns the number of rows.
static constexpr StringLiteral name
Definition SPIRVTypes.h:180
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
static constexpr StringLiteral name
Definition SPIRVTypes.h:196
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
constexpr Type()=default
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Capability::B) AND (Cap...
Definition SPIRVTypes.h:65
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static constexpr StringLiteral name
Definition SPIRVTypes.h:217
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
static SampledImageType get(Type imageType)
static bool classof(Type type)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
bool hasDecoration(spirv::Decoration decoration) const
Returns true if the struct has a specified decoration.
unsigned getNumElements() const
Type getElementType(unsigned) const
static constexpr StringLiteral name
Definition SPIRVTypes.h:258
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:465
static constexpr StringLiteral name
Definition SPIRVTypes.h:479
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType)
ShapedType::Trait< TensorArmType > ShapedTypeTraits
Definition SPIRVTypes.h:469
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
TensorArmType cloneWith(std::optional< ArrayRef< int64_t > > shape, Type elementType) const
ArrayRef< int64_t > getShape() const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
friend bool operator==(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
Definition SPIRVTypes.h:273
MemberDecorationInfo(uint32_t index, Decoration decoration, Attribute decorationValue)
Definition SPIRVTypes.h:268
friend bool operator<(const MemberDecorationInfo &lhs, const MemberDecorationInfo &rhs)
Definition SPIRVTypes.h:280
StructDecorationInfo(Decoration decoration, Attribute decorationValue)
Definition SPIRVTypes.h:294
friend bool operator<(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
Definition SPIRVTypes.h:303
friend bool operator==(const StructDecorationInfo &lhs, const StructDecorationInfo &rhs)
Definition SPIRVTypes.h:297
Type storage for SPIR-V structure types: