MLIR 23.0.0git
XeGPULayoutImpl.h
Go to the documentation of this file.
1//===- XeGPULayoutImpl.h - Layout utility functions ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
10#define MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
11
17
18namespace mlir {
19
20class VectorType;
21class OpOperand;
22class OpResult;
23class OpBuilder;
24class ValueRange;
25class TypeConverter;
26class OpFoldResult;
27
28namespace xegpu {
29class DistributeLayoutAttr;
30class LayoutAttr;
31class TensorDescType;
32} // namespace xegpu
33
34namespace xegpu {
35
36LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
37 LayoutKind layoutKind, unsigned indexBitWidth,
38 bool printOnly = false);
39
40LogicalResult resolveLayoutConflicts(Operation *target);
41
42/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
43/// OpResult of of the given operation. If the operation contains regions, it is
44/// also applied recursively to the contained operations operation.
45/// TODO: To be replaced by recoverTemporaryLayouts()
46void recoverTemporaryLayoutsDeprecated(Operation *op);
47
48/// Attach layout attributes to all vector-type operands of operations within
49/// the given operation's nested region. Reports an error if any vector operand
50/// lacks a layout attribute.
51bool recoverTemporaryLayouts(Operation *rootOp);
52
53/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
54template <typename T,
55 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
56 std::is_same_v<T, OpResult>>>
57void removeLayoutAttr(const T &operandOrResult);
58
59/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
60/// given operation if they exist. If the operation contains regions, it is also
61/// applied recursively to the contained operations
62void removeLayoutAttrs(Operation *op);
63
64/// Updates the NamedAttribute sequence by dropping sg-layout and
65/// sg-data information from any DistributeLayoutAttr found.
66SmallVector<NamedAttribute>
67dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
68
69/// Updates the NamedAttribute sequence by dropping inst-data information from
70/// any DistributeLayoutAttr found.
71SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
72
73/// Infers the source layout attribute for a broadcast operation given the
74/// result layout attribute, result shape, and source shape.
75DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
76 ArrayRef<int64_t> resShape,
77 ArrayRef<int64_t> srcShape);
78
79/// Infers the source layout attribute for a reduction operation given the
80/// result layout attribute and reduced dims.
81DistributeLayoutAttr
82inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
83 SmallVector<int64_t> reduceDims);
84
85/// Infers the source layout attribute for a reduction operation given the
86/// result layout attribute and reduced dims.
87DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout);
88
89/// Infers the source layout attribute for a transpose operation given the
90/// result layout attribute and permutation.
91DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout,
92 ArrayRef<int64_t> permutation);
93
94/// Infers the source layout attribute for a bitcast operation given the
95/// result layout attribute, result element type bitwidth, and source element
96/// type bitwidth.
97DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
98 int resElemTyBitWidth,
99 int srcElemTyBitWidth);
100
101/// Infers the source layout attribute for a shape cast operation given the
102/// result layout attribute, result shape, and source shape.
103DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
104 ArrayRef<int64_t> resShape,
105 ArrayRef<int64_t> srcShape);
106
107/// Infers the source layout attribute for an insert strided slice operation
108/// given the result layout attribute, result shape, and source shape. Removes
109/// leading dimensions from the result layout to match the source shape size.
110DistributeLayoutAttr
111inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
112 ArrayRef<int64_t> resShape,
113 ArrayRef<int64_t> srcShape);
114
115/// Infers the layout attribute for mask and offset operand for Chunked load
116/// and store, given the anchor layout attribute for the value being load/store.
117DistributeLayoutAttr
118inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
119 int chunkSize);
120
121/// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
122/// the result.
123///
124/// This function first attempts to construct a source layout that, when
125/// sliced along reduction dimensions, produces a result layout compatible
126/// with the consumer's preferred layout. This minimizes data redistribution
127/// overhead. The SliceAttr for the result is then created based on the
128/// derived source layout and the specified reduction dimensions.
130 VectorType srcVectorTy,
131 DistributeLayoutAttr consumerLayout,
132 SmallVector<int64_t> reductionDims,
133 int numSg, const uArch::uArch *uArch);
134
135/// Sets up layout for Reduction operations by creating a SliceAttr for the
136/// result.
137SliceAttr setupReductionResultLayout(LayoutKind layoutKind,
138 VectorType srcVectorTy,
139 const uArch::uArch *uArch);
140
141/// Setup the result layout attribute for a bitcast operation based on element
142/// type bitwidths. This ensures the source layout can always be derived from
143/// the result layout.
144///
145/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
146/// resElemTyBitWidth), the result layout's innermost dimension data sizes
147/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This maintains
148/// the invariant that the source layout can be recovered by adjusting the
149/// result layout based on bitwidth ratio of input vs output.
150DistributeLayoutAttr setupBitCastResultLayout(
151 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
152 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
153
154/// Sets up the result layout for an insert strided slice operation.
155/// Creates a result layout based on the specified layout kind (InstData or
156/// Lane).
158 LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
159 DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
160
161/// Sets up the anchor layout for a load gather operation.
162DistributeLayoutAttr
163setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
164 int chunkSize, DistributeLayoutAttr consumerLayout,
165 const uArch::uArch *uArch);
166
167/// Sets up the anchor layout for load matrix operation.
168DistributeLayoutAttr
169setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
170 DistributeLayoutAttr consumerLayout,
171 const uArch::uArch *uArch);
172
173/// Sets up the anchor layout for a store scatter operation.
174DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
175 VectorType vectorTy,
176 int chunkSize,
177 const uArch::uArch *uArch);
178
179/// Sets up the anchor layout for a store matrix operation.
180DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
181 VectorType vectorTy,
182 const uArch::uArch *uArch);
183
184/// Sets up the anchor layouts for a dpas operands (A, B, and C/D).
185/// The numSg and consumerLayout (optional) are only used by sg layout creation.
186std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
187 DistributeLayoutAttr>>
188setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
189 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
190 const uArch::uArch *uArch);
191
192/// Gets the expected layout for a given consumer operand. This will check if
193/// the owning operation of the consumer operand is one of the special layout
194/// users and determine the expected layout accordingly.
195xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
196
197} // namespace xegpu
198
199} // namespace mlir
200
201#endif // MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
bool recoverTemporaryLayouts(Operation *rootOp)
Attaches layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
void recoverTemporaryLayoutsDeprecated(Operation *op)
[to-be-deprecated] Sets the DistributeLayoutAttr for each OpOperand and OpResult of the given operation...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for the mask and offset operands of chunked load and store,...
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute.
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a multi-reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for the dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.