MLIR 23.0.0git
XeGPULayoutImpl.h
Go to the documentation of this file.
1//===- XeGPULayoutImpl.h - Layout utility functions ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
10#define MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
11
17
18namespace mlir {
19
20class VectorType;
21class OpOperand;
22class OpResult;
23class OpBuilder;
24class ValueRange;
25class TypeConverter;
26class OpFoldResult;
27
28namespace xegpu {
29class DistributeLayoutAttr;
30class LayoutAttr;
31class TensorDescType;
32} // namespace xegpu
33
34namespace xegpu {
35
36LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
37 LayoutKind layoutKind, unsigned indexBitWidth,
38 bool printOnly = false);
39
40LogicalResult resolveLayoutConflicts(Operation *target);
41
42/// Attach layout attributes to all vector-type operands of operations within
43/// the given operation's nested region. Reports an error if any vector operand
44/// lacks a layout attribute.
45bool recoverTemporaryLayouts(Operation *rootOp);
46
47/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
48template <typename T,
49 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
50 std::is_same_v<T, OpResult>>>
51void removeLayoutAttr(const T &operandOrResult);
52
53/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
54/// given operation if they exist. If the operation contains regions, it is also
55/// applied recursively to the contained operations
56void removeLayoutAttrs(Operation *op);
57
58/// Removes the temporary layout attributes for each OpOperand and OpResult of
59/// the given operation. Recursive for contained operations if the given
60/// operation contains regions.
61void removeTemporaryLayoutAttrs(Operation *op);
62
63/// Updates the NamedAttribute sequence by dropping sg-layout and
64/// sg-data information from any DistributeLayoutAttr found.
65SmallVector<NamedAttribute>
66dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
67
68/// Updates the NamedAttribute sequence by dropping inst-data information from
69/// any DistributeLayoutAttr found.
70SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
71
72/// Infers the source layout attribute for a broadcast operation given the
73/// result layout attribute, result shape, and source shape.
74DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
75 ArrayRef<int64_t> resShape,
76 ArrayRef<int64_t> srcShape);
77
78/// Infers the source layout attribute for a reduction operation given the
79/// result layout attribute and reduced dims.
80DistributeLayoutAttr
81inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
82 SmallVector<int64_t> reduceDims);
83
84/// Infers the source layout attribute for a reduction operation given the
85/// result layout attribute and reduced dims.
86DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout);
87
88/// Infers the source layout attribute for a transpose operation given the
89/// result layout attribute and permutation.
90DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout,
91 ArrayRef<int64_t> permutation);
92
93/// Infers the source layout attribute for a bitcast operation given the
94/// result layout attribute, result element type bitwidth, and source element
95/// type bitwidth.
96DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
97 int resElemTyBitWidth,
98 int srcElemTyBitWidth);
99
100/// Infers the source layout attribute for an interleave operation given the
101/// result layout attribute. Interleave doubles the innermost dimension size.
102DistributeLayoutAttr
103inferInterleaveSourceLayout(DistributeLayoutAttr resLayout);
104
105/// Infers the source layout attribute for a deinterleave operation given the
106/// result layout attribute. Deinterleave halves the innermost dimension size.
107DistributeLayoutAttr
108inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout);
109
110/// Infers the source layout attribute for a shape cast operation given the
111/// result layout attribute, result shape, and source shape.
112DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
113 ArrayRef<int64_t> resShape,
114 ArrayRef<int64_t> srcShape);
115
116/// Infers the source layout attribute for an insert strided slice operation
117/// given the result layout attribute, result shape, and source shape. Removes
118/// leading dimensions from the result layout to match the source shape size.
119DistributeLayoutAttr
120inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
121 ArrayRef<int64_t> resShape,
122 ArrayRef<int64_t> srcShape);
123
124/// Infers the source layout attribute for an insert operation.
125/// using same logic as inferInsertStridedSliceSourceLayout
126DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout,
127 ArrayRef<int64_t> resShape,
128 ArrayRef<int64_t> srcShape);
129
130/// Infers the source layout attribute for an extract operation. Adds
131/// leading dimensions to the source layout to match the source shape size.
132DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout,
133 ArrayRef<int64_t> resShape,
134 ArrayRef<int64_t> srcShape);
135
136/// Infers the layout attribute for mask and offset operand for Chunked load
137/// and store, given the anchor layout attribute for the value being load/store.
138DistributeLayoutAttr
139inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
140 int chunkSize);
141
142/// Infers the source layout attribute for an operand using result layout
143/// attribute
144DistributeLayoutAttr
145inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
146
147/// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
148/// the result.
149///
150/// This function first attempts to construct a source layout that, when
151/// sliced along reduction dimensions, produces a result layout compatible
152/// with the consumer's preferred layout. This minimizes data redistribution
153/// overhead. The SliceAttr for the result is then created based on the
154/// derived source layout and the specified reduction dimensions.
156 VectorType srcVectorTy,
157 DistributeLayoutAttr consumerLayout,
158 SmallVector<int64_t> reductionDims,
159 int numSg, const uArch::uArch *uArch);
160
161/// Sets up layout for Reduction operations by creating a SliceAttr for the
162/// result.
163SliceAttr setupReductionResultLayout(LayoutKind layoutKind,
164 VectorType srcVectorTy,
165 const uArch::uArch *uArch);
166
167/// Setup the result layout attribute for a bitcast operation based on element
168/// type bitwidths. This ensures the source layout can always be derived from
169/// the result layout.
170///
171/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
172/// resElemTyBitWidth), the result layout's innermost dimension data sizes
173/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This maintains
174/// the invariant that the source layout can be recovered by adjusting the
175/// result layout based on bitwidth ratio of input vs output.
176DistributeLayoutAttr setupBitCastResultLayout(
177 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
178 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
179
180/// Sets up the result layout for an interleave operation to ensure the source
181/// layout can be safely derived. Interleave doubles the innermost dimension,
182/// so the result layout must ensure that laneData is at least 2 (or a multiple
183/// of 2), and instData must be divisible by innermostDimLaneLayout * 2.
184DistributeLayoutAttr setupInterleaveResultLayout(
185 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
186 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
187
188/// Sets up the result layout for an insert strided slice operation.
189/// Creates a result layout based on the specified layout kind (InstData or
190/// Lane).
192 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
193 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
194
195/// Sets up the anchor layout for a load gather operation.
196DistributeLayoutAttr
197setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
198 int chunkSize, DistributeLayoutAttr consumerLayout,
199 const uArch::uArch *uArch);
200
201/// Sets up the anchor layout for load matrix operation.
202DistributeLayoutAttr
203setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
204 DistributeLayoutAttr consumerLayout,
205 const uArch::uArch *uArch);
206
207/// Sets up the anchor layout for a store scatter operation.
208DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
209 VectorType vectorTy,
210 int chunkSize,
211 const uArch::uArch *uArch);
212
213/// Sets up the anchor layout for a store matrix operation.
214DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
215 VectorType vectorTy,
216 const uArch::uArch *uArch);
217
218/// Sets up the anchor layouts for a dpas operands (A, B, and C/D).
219/// The numSg and consumerLayout (optional) are only used by sg layout creation.
220std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
221 DistributeLayoutAttr>>
222setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
223 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
224 const uArch::uArch *uArch);
225
226/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
227/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
228/// creation. A_scale and B_scale are optional.
229std::optional<
230 std::tuple<DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr,
231 DistributeLayoutAttr, DistributeLayoutAttr>>
232setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
233 VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy,
234 DistributeLayoutAttr consumerLayout, int numSg,
235 const uArch::uArch *uArch);
236
237/// Gets the expected layout for a given consumer operand. This will check if
238/// the owning operation of the consumer operand is one of the special layout
239/// users and determine the expected layout accordingly.
240DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
241
242} // namespace xegpu
243
244} // namespace mlir
245
#endif // MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute.
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.