MLIR 23.0.0git
XeGPULayoutImpl.h
Go to the documentation of this file.
1//===- XeGPULayoutImpl.h - Layout utility functions ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
10#define MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
11
18#include "llvm/ADT/STLFunctionalExtras.h"
19
20namespace mlir {
21
// Forward declarations of core MLIR types referenced by name in this header.
22class VectorType;
23class OpOperand;
24class OpResult;
25class OpBuilder;
26class ValueRange;
27class TypeConverter;
28class OpFoldResult;
29
// Forward declarations of XeGPU attribute and type classes referenced by name
// in this header.
30namespace xegpu {
31class DistributeLayoutAttr;
32class LayoutAttr;
33class TensorDescType;
34} // namespace xegpu
35
36namespace xegpu {
37
/// Runs layout propagation on `target` at the layout hierarchy level selected
/// by `layoutKind` (see LayoutKind).
/// NOTE(review): the roles of `builder` and `indexBitWidth`, the exact effect
/// of `printOnly` (presumably: report the propagated layouts without attaching
/// them to the IR), and the conditions under which failure() is returned are
/// defined by the implementation — confirm there.
38LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
39 LayoutKind layoutKind, unsigned indexBitWidth,
40 bool printOnly = false);
41
/// Resolves layout conflicts found within `target`.
/// NOTE(review): what constitutes a conflict, how it is resolved, and when
/// failure() is returned are defined by the implementation — confirm there.
42LogicalResult resolveLayoutConflicts(Operation *target);
43
44/// Callable returning the propagated layout for a given Value, used by the
45/// layout-propagation helpers below. Because this is an llvm::function_ref,
/// it is a non-owning reference: the referenced callable must outlive every
/// call that receives it, so a GetLayoutFnTy must not be stored beyond that
/// call.
46using GetLayoutFnTy = llvm::function_ref<DistributeLayoutAttr(Value)>;
47
48/// Propagates layouts from a region branch op's region entry block arguments
49/// back to its init operands. The block argument's layout is obtained via
50/// `getLayoutOfValue`; the matching layout is then recorded on each init
51/// operand that flows into that block argument (e.g. scf.for's iter_args
52/// inits), and on tensor descriptor block argument types.
///
/// \param regionOp          region branch op whose entry block arguments
///                          supply the layouts.
/// \param getLayoutOfValue  callback returning the propagated layout of a
///                          value (see GetLayoutFnTy).
/// NOTE(review): the conditions under which failure() is returned are not
/// visible from this header — confirm against the implementation.
53LogicalResult propagateRegionArgsToInits(RegionBranchOpInterface regionOp,
54 GetLayoutFnTy getLayoutOfValue);
55
56/// Attach layout attributes to all vector-type operands of operations within
57/// the given operation's nested region. Reports an error if any vector operand
58/// lacks a layout attribute.
60
61/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
/// The template is constrained with std::enable_if_t, so it can only be
/// instantiated with T = OpOperand or T = OpResult.
62template <typename T,
63 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
64 std::is_same_v<T, OpResult>>>
65void removeLayoutAttr(const T &operandOrResult);
66
67/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
68/// given operation if they exist. If the operation contains regions, it is also
69/// applied recursively to the contained operations.
71
72/// Removes the temporary layout attributes for each OpOperand and OpResult of
73/// the given operation. Recursive for contained operations if the given
74/// operation contains regions.
76
77/// Updates the NamedAttribute sequence by dropping sg-layout and
78/// sg-data information from any DistributeLayoutAttr found.
81
82/// Updates the NamedAttribute sequence by dropping inst-data information from
83/// any DistributeLayoutAttr found.
85
86/// Infers the source layout attribute for a broadcast operation given the
87/// result layout attribute, result shape, and source shape.
/// \param resLayout  layout of the broadcast result.
/// \param resShape   shape of the broadcast result.
/// \param srcShape   shape of the broadcast source.
88DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
89 ArrayRef<int64_t> resShape,
90 ArrayRef<int64_t> srcShape);
91
92/// Infers the source layout attribute for a multi-reduction operation given
93/// the result layout attribute and the reduced dimensions (`reduceDims`).
94DistributeLayoutAttr
95inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
96 SmallVector<int64_t> reduceDims);
97
98/// Infers the source layout attribute for a reduction operation from the
99/// result layout attribute alone. Unlike inferMultiReductionSourceLayout, no
/// reduction dimensions are passed.
100DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout);
101
102/// Infers the source layout attribute for a transpose operation given the
103/// result layout attribute and the transpose `permutation`.
104DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout,
105 ArrayRef<int64_t> permutation);
106
107/// Infers the source layout attribute for a bitcast operation given the
108/// result layout attribute, the result element type bitwidth
109/// (`resElemTyBitWidth`), and the source element type bitwidth
/// (`srcElemTyBitWidth`).
110DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
111 int resElemTyBitWidth,
112 int srcElemTyBitWidth);
113
114/// Infers the source layout attribute for an interleave operation given the
115/// result layout attribute. Interleave doubles the innermost dimension size.
/// \param resLayout  layout of the interleave result.
116DistributeLayoutAttr
117inferInterleaveSourceLayout(DistributeLayoutAttr resLayout);
118
119/// Infers the source layout attribute for a deinterleave operation given the
120/// result layout attribute. Deinterleave halves the innermost dimension size.
/// \param resLayout  layout of the deinterleave result.
121DistributeLayoutAttr
122inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout);
123
124/// Infers the source layout attribute for a shape cast operation given the
125/// result layout attribute, result shape, and source shape.
/// \param resLayout  layout of the shape cast result.
/// \param resShape   shape of the shape cast result.
/// \param srcShape   shape of the shape cast source.
126DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
127 ArrayRef<int64_t> resShape,
128 ArrayRef<int64_t> srcShape);
129
130/// Infers the source layout attribute for an insert strided slice operation
131/// given the result layout attribute, result shape, and source shape. Removes
132/// leading dimensions from the result layout so that its rank matches that
/// of the source shape.
133DistributeLayoutAttr
134inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
135 ArrayRef<int64_t> resShape,
136 ArrayRef<int64_t> srcShape);
137
138/// Infers the source layout attribute for an insert operation, using the
139/// same logic as inferInsertStridedSliceSourceLayout.
140DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout,
141 ArrayRef<int64_t> resShape,
142 ArrayRef<int64_t> srcShape);
143
144/// Infers the source layout attribute for an extract operation. The returned
145/// source layout has leading dimensions added so that its rank matches that
/// of the source shape (`srcShape`).
146DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout,
147 ArrayRef<int64_t> resShape,
148 ArrayRef<int64_t> srcShape);
149
150/// Infers the layout attribute for the mask and offset operands of chunked
151/// load and store operations, given the anchor layout attribute of the value
/// being loaded or stored (`payloadLayout`) and the `chunkSize`.
152DistributeLayoutAttr
153inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
154 int chunkSize);
155
156/// Infers the source layout attribute for an operand using result layout
157/// attribute
158DistributeLayoutAttr
160 DistributeLayoutAttr resLayout);
161
162/// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
163/// the result.
164///
165/// This function first attempts to construct a source layout that, when
166/// sliced along reduction dimensions, produces a result layout compatible
167/// with the consumer's preferred layout. This minimizes data redistribution
168/// overhead. The SliceAttr for the result is then created based on the
169/// derived source layout and the specified reduction dimensions.
171 VectorType srcVectorTy,
172 DistributeLayoutAttr consumerLayout,
173 SmallVector<int64_t> reductionDims,
174 int numSg, const uArch::uArch *uArch);
175
176/// Sets up the layout for reduction operations by creating a SliceAttr for
177/// the result, given the layout kind, the source vector type
/// (`srcVectorTy`), and the target `uArch`.
178SliceAttr setupReductionResultLayout(LayoutKind layoutKind,
179 VectorType srcVectorTy,
180 const uArch::uArch *uArch);
181
182/// Sets up the result layout attribute for a bitcast operation based on the
183/// element type bitwidths. This ensures the source layout can always be
184/// derived from the result layout.
185///
186/// When casting from a narrower to a wider element type (the element bitwidth
187/// of `srcVectorTy` is smaller than that of `resVectorTy`), the result
188/// layout's innermost dimension data sizes (inst_data, lane_data) are scaled
189/// up by the bitwidth ratio. This maintains the invariant that the source
190/// layout can be recovered by adjusting the result layout by the ratio of the
/// source and result element bitwidths.
191DistributeLayoutAttr setupBitCastResultLayout(
192 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
193 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
194
195/// Sets up the result layout for an interleave operation to ensure the source
196/// layout can be safely derived. Interleave doubles the innermost dimension,
197/// so the result layout must ensure that laneData is a multiple of 2 (and
198/// hence at least 2), and instData must be divisible by
/// innermostDimLaneLayout * 2.
199DistributeLayoutAttr setupInterleaveResultLayout(
200 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
201 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
202
203/// Sets up the result layout for an insert strided slice operation.
204/// Creates a result layout based on the specified layout kind (InstData or
205/// Lane).
207 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
208 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
209
210/// Sets up the anchor layout for a load gather operation producing a vector
/// of type `vectorTy` with the given `chunkSize`. `consumerLayout` is
/// presumably the layout preferred by the consumer of the loaded value (see
/// getConsumerLayoutAt) — NOTE(review): confirm how it is weighed against the
/// `uArch` constraints in the implementation.
211DistributeLayoutAttr
212setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
213 int chunkSize, DistributeLayoutAttr consumerLayout,
214 const uArch::uArch *uArch);
215
216/// Sets up the anchor layout for a load matrix operation producing a vector
/// of type `vectorTy`. `consumerLayout` is presumably the layout preferred by
/// the consumer of the loaded value (see getConsumerLayoutAt) — NOTE(review):
/// confirm in the implementation.
217DistributeLayoutAttr
218setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
219 DistributeLayoutAttr consumerLayout,
220 const uArch::uArch *uArch);
221
222/// Sets up the anchor layout for a store scatter operation that stores a
/// vector of type `vectorTy` with the given `chunkSize`.
223DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
224 VectorType vectorTy,
225 int chunkSize,
226 const uArch::uArch *uArch);
227
228/// Sets up the anchor layout for a store matrix operation that stores a
/// vector of type `vectorTy`.
229DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
230 VectorType vectorTy,
231 const uArch::uArch *uArch);
232
233/// Sets up the anchor layouts for the dpas operands (A, B, and C/D).
234/// The numSg and consumerLayout (optional) are only used by sg layout creation.
/// NOTE(review): presumably the tuple is ordered (A, B, C/D), matching the
/// `aTy`, `bTy`, `cdTy` parameters, and std::nullopt is returned when no valid
/// layout can be set up — confirm against the implementation.
235std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
236 DistributeLayoutAttr>>
237setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
238 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
239 const uArch::uArch *uArch);
240
241/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
242/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
243/// creation. A_scale and B_scale are optional.
/// NOTE(review): presumably the tuple is ordered (A, B, C/D, A_scale,
/// B_scale), matching the parameter order, and std::nullopt is returned when
/// no valid layout can be set up — confirm against the implementation.
244std::optional<
245 std::tuple<DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr,
246 DistributeLayoutAttr, DistributeLayoutAttr>>
247setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
248 VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy,
249 DistributeLayoutAttr consumerLayout, int numSg,
250 const uArch::uArch *uArch);
251
252/// Gets the expected layout for a given consumer operand. This checks whether
253/// the owning operation of the consumer operand is one of the special layout
254/// users and determines the expected layout accordingly.
255DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
256
257/// Returns true if `op` is safe and cheap to clone: it has no side effects,
258/// no regions, and all of its operands are themselves trivially
259/// rematerializable (e.g. `vector.step`, splat `arith.constant`, or
260/// `vector.create_mask` whose operands are constants).
262
263} // namespace xegpu
264
265} // namespace mlir
266
267#endif // MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferSourceLayoutFromResultForNonAnchorOp(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load matrix operation.
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
llvm::function_ref< DistributeLayoutAttr(Value)> GetLayoutFnTy
Callable returning the propagated layout for a given Value, used by the layout-propagation helpers be...
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for the mask and offset operands of chunked load and store,...
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation from the result layout attribute alone.
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
LogicalResult propagateRegionArgsToInits(RegionBranchOpInterface regionOp, GetLayoutFnTy getLayoutOfValue)
Propagate layouts from a region branch op's region entry block arguments back to its init operands.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for the dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.