MLIR  22.0.0git
LayoutUtils.cpp
Go to the documentation of this file.
1 //===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to get alignment and layout information
10 // for types in the SPIR-V dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 
17 using namespace mlir;
18 
// Entry point that decorates a struct type with layout information. It seeds
// the bookkeeping values (size 0, alignment 1) and forwards them, together
// with `structType`, to the three-argument decorateType overload; the computed
// size and alignment are not handed back to the caller from here.
// NOTE(review): the signature line for this body is not present in this
// listing — confirm it against the header declaration.
21  Size size = 0;
22  Size alignment = 1;
23  return decorateType(structType, size, alignment);
24 }
25 
// Lays out the members of `structType` one after another and reports the
// resulting struct size and alignment through `size` and `alignment`.
//
// Returns:
//  - `structType` unchanged when it has no members (in which case `size` and
//    `alignment` are left untouched),
//  - a new literal struct built from the decorated member types, the computed
//    offsets and the original member decorations when the struct is not
//    identified,
//  - nullptr when the struct is identified.
// NOTE(review): the leading signature lines and the declarations of
// `offsetInfo` and `memberDecorations` are not present in this listing —
// confirm them against the original file.
29  VulkanLayoutUtils::Size &alignment) {
30  if (structType.getNumElements() == 0) {
31  return structType;
32  }
33 
34  SmallVector<Type, 4> memberTypes;
37 
// Running byte offset of the next member, and the largest member alignment
// seen so far.
38  Size structMemberOffset = 0;
39  Size maxMemberAlignment = 1;
40 
41  for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
42  Size memberSize = 0;
43  Size memberAlignment = 1;
44 
// Decorate the member first; this also yields its size and alignment.
45  Type memberType =
46  decorateType(structType.getElementType(i), memberSize, memberAlignment);
// Round the running offset up to the member's alignment; the rounded value is
// the offset recorded for this member.
47  structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
48  memberTypes.push_back(memberType);
49  offsetInfo.push_back(
50  static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
51  // If the member's size is the max value, it must be the last member and it
52  // must be a runtime array.
53  assert(memberSize != std::numeric_limits<Size>().max() ||
54  (i + 1 == e &&
55  isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
56  // According to the Vulkan spec:
57  // "A structure has a base alignment equal to the largest base alignment of
58  // any of its members."
// NOTE(review): when `memberSize` is the max-value sentinel (trailing runtime
// array), this addition presumably wraps around if `Size` is an unsigned
// type — confirm that the `size` reported for such structs is not relied on.
59  structMemberOffset += memberSize;
60  maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment);
61  }
62 
63  // According to the Vulkan spec:
64  // "The Offset decoration of a member must not place it between the end of a
65  // structure or an array and the next multiple of the alignment of that
66  // structure or array."
// Hence the struct size is the end offset rounded up to the struct alignment.
67  size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
68  alignment = maxMemberAlignment;
// Carry the original member decorations over to the rebuilt struct.
69  structType.getMemberDecorations(memberDecorations);
70 
71  if (!structType.isIdentified())
72  return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
73 
74  // Identified structs are uniqued by identifier so it is not possible
75  // to create 2 structs with the same name but different decorations.
76  return nullptr;
77 }
78 
// Dispatches on the kind of `type` and forwards to the matching decorateType
// overload, reporting the type's size and alignment through `size` and
// `alignment`.
//
// Scalars are returned as-is. Struct, array, vector, matrix and runtime-array
// types are handled by their dedicated overloads. Pointer types are not
// supported yet and yield nullptr. Any other type hits llvm_unreachable.
// NOTE(review): the first signature line is not present in this listing —
// confirm it against the original file.
80  VulkanLayoutUtils::Size &alignment) {
81  if (isa<spirv::ScalarType>(type)) {
82  alignment = getScalarTypeAlignment(type);
83  // Vulkan spec does not specify any padding for a scalar type.
// So a scalar's size is taken to be equal to its alignment.
84  size = alignment;
85  return type;
86  }
87  if (auto structType = dyn_cast<spirv::StructType>(type))
88  return decorateType(structType, size, alignment);
89  if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
90  return decorateType(arrayType, size, alignment);
91  if (auto vectorType = dyn_cast<VectorType>(type))
92  return decorateType(vectorType, size, alignment);
93  if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
94  return decorateType(matrixType, size, alignment);
95  if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
// A runtime array has no static size; the max value of `Size` is used as a
// sentinel. The runtime-array overload only computes the alignment.
96  size = std::numeric_limits<Size>().max();
97  return decorateType(arrayType, alignment);
98  }
99  if (isa<spirv::PointerType>(type)) {
100  // TODO: Add support for `PhysicalStorageBufferAddresses`.
101  return nullptr;
102  }
103  llvm_unreachable("unhandled SPIR-V type");
104 }
105 
// Computes the size and alignment of `vectorType` from its element type and
// returns a vector type with the same element count built on the decorated
// element type.
// NOTE(review): `size` is assigned below but no `size` parameter appears in
// the visible signature; a signature line seems to be missing from this
// listing — confirm against the original file.
106 Type VulkanLayoutUtils::decorateType(VectorType vectorType,
108  VulkanLayoutUtils::Size &alignment) {
109  const unsigned numElements = vectorType.getNumElements();
110  Type elementType = vectorType.getElementType();
111  Size elementSize = 0;
112  Size elementAlignment = 1;
113 
114  Type memberType = decorateType(elementType, elementSize, elementAlignment);
115  // According to the Vulkan spec:
116  // 1. "A two-component vector has a base alignment equal to twice its scalar
117  // alignment."
118  // 2. "A three- or four-component vector has a base alignment equal to four
119  // times its scalar alignment."
// The size is tightly packed: element size times element count.
120  size = elementSize * numElements;
// NOTE(review): every element count other than 2 takes the "four times"
// branch — presumably only 2-, 3- and 4-component vectors reach this point;
// confirm.
121  alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
122  return VectorType::get(numElements, memberType);
123 }
124 
// Computes the size and alignment of `arrayType` and returns an array type
// with the same element count built on the decorated element type.
// NOTE(review): the leading signature lines are not present in this listing —
// confirm them against the original file.
127  VulkanLayoutUtils::Size &alignment) {
128  const unsigned numElements = arrayType.getNumElements();
129  Type elementType = arrayType.getElementType();
130  Size elementSize = 0;
131  Size elementAlignment = 1;
132 
133  Type memberType = decorateType(elementType, elementSize, elementAlignment);
134  // According to the Vulkan spec:
135  // "An array has a base alignment equal to the base alignment of its element
136  // type."
// The total size is the element size times the element count.
137  size = elementSize * numElements;
138  alignment = elementAlignment;
// The element size is passed as the third argument — presumably the array
// stride; confirm against the spirv::ArrayType::get overloads.
139  return spirv::ArrayType::get(memberType, numElements, elementSize);
140 }
141 
// Computes the size and alignment of `matrixType` from its scalar element
// type and returns a matrix type rebuilt from the same column type and column
// count.
// NOTE(review): the leading signature lines are not present in this listing —
// confirm them against the original file.
144  VulkanLayoutUtils::Size &alignment) {
145  const unsigned numColumns = matrixType.getNumColumns();
146  Type columnType = matrixType.getColumnType();
147  unsigned numElements = matrixType.getNumElements();
148  Type elementType = matrixType.getElementType();
149  Size elementSize = 0;
150  Size elementAlignment = 1;
151 
// Only the element's size and alignment are needed here; the type returned by
// this call is intentionally not used.
152  decorateType(elementType, elementSize, elementAlignment);
153  // According to the Vulkan spec:
154  // "A matrix type inherits scalar alignment from the equivalent array
155  // declaration."
// The size is the element size times the total number of elements.
156  size = elementSize * numElements;
157  alignment = elementAlignment;
158  return spirv::MatrixType::get(columnType, numColumns);
159 }
160 
// Decorates the element type of a runtime array and returns a runtime array
// type built on it. Only `alignment` is produced here, and it is the element
// type's alignment because `alignment` is passed straight through to the
// element's decorateType call.
// NOTE(review): the leading signature line is not present in this listing —
// confirm it against the original file.
162  VulkanLayoutUtils::Size &alignment) {
163  Type elementType = arrayType.getElementType();
164  Size elementSize = 0;
165 
166  Type memberType = decorateType(elementType, elementSize, alignment);
// The element size is passed as the second argument — presumably the array
// stride; confirm against the spirv::RuntimeArrayType::get overloads.
167  return spirv::RuntimeArrayType::get(memberType, elementSize);
168 }
169 
// Returns the alignment of `scalarType`, derived from its bit width: 1 for a
// 1-bit type, otherwise the bit width divided by 8 (integer division).
// `scalarType` must be an integer or float type, since the width is queried
// with getIntOrFloatBitWidth().
// NOTE(review): the return-type line is not present in this listing — confirm
// it against the original file.
171 VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
172  // According to the Vulkan spec:
173  // 1. "A scalar of size N has a scalar alignment of N."
174  // 2. "A scalar has a base alignment equal to its scalar alignment."
175  // 3. "A scalar, vector or matrix type has an extended alignment equal to its
176  // base alignment."
177  unsigned bitWidth = scalarType.getIntOrFloatBitWidth();
// A 1-bit type would yield 0 from the division below, so it is special-cased.
178  if (bitWidth == 1)
179  return 1;
180  return bitWidth / 8;
181 }
182 
// Checks whether `type` is legal in terms of Vulkan layout info decoration.
// Only pointers to structs in the Uniform, StorageBuffer, PushConstant and
// PhysicalStorageBuffer storage classes are constrained: the pointee struct
// must either carry offset information or have no members. Every other type
// is reported as legal.
// NOTE(review): the signature line is not present in this listing — confirm
// it against the original file.
184  auto ptrType = dyn_cast<spirv::PointerType>(type);
185  if (!ptrType) {
186  return true;
187  }
188 
189  const spirv::StorageClass storageClass = ptrType.getStorageClass();
190  auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
191  if (!structType) {
192  return true;
193  }
194 
195  switch (storageClass) {
196  case spirv::StorageClass::Uniform:
197  case spirv::StorageClass::StorageBuffer:
198  case spirv::StorageClass::PushConstant:
199  case spirv::StorageClass::PhysicalStorageBuffer:
// Legal when offsets are already present, or when the struct is empty.
200  return structType.hasOffset() || !structType.getNumElements();
201  default:
202  return true;
203  }
204 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
static bool isLegalType(Type type)
Checks whether a type is legal in terms of Vulkan layout info decoration.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:20
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getElementType() const
Returns the elements' type (i.e, single element type).
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:504
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
bool isIdentified() const
Returns true if the StructType is identified.
unsigned getNumElements() const
Type getElementType(unsigned) const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.