MLIR 23.0.0git
XeGPUArrayLengthOptimization.cpp
Go to the documentation of this file.
1//===- XeGPUArrayLengthOptimization.cpp - Array Length Opt -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "llvm/ADT/SmallVector.h"
17
18#define DEBUG_TYPE "xegpu-array-length-optimization"
19
20using namespace mlir;
21
22namespace {
23
24// Fallback subgroup size used when the target uArch cannot be resolved from
25// the op (e.g. standalone unit tests with no chip attribute attached).
26constexpr int64_t DEFAULT_SUBGROUP_SIZE = 16;
27
28/// Return the subgroup size for `op`'s target uArch, falling back to
29/// DEFAULT_SUBGROUP_SIZE if no chip attribute is attached or the chip is not
30/// recognized.
31static int64_t getSubgroupSize(Operation *op) {
32 auto chipStr = xegpu::getChipStr(op);
33 if (!chipStr)
34 return DEFAULT_SUBGROUP_SIZE;
35 const xegpu::uArch::uArch *targetUArch =
36 xegpu::uArch::getUArch(chipStr.value());
37 if (!targetUArch)
38 return DEFAULT_SUBGROUP_SIZE;
39 return targetUArch->getSubgroupSize();
40}
41
/// Number of subgroup-wide blocks that the FCD is split into.
///
/// Returns `fcdSize / subgroupSize` (truncating) when the FCD is strictly
/// wider than one subgroup, and 1 otherwise.
/// TODO: Currently, we are only allowing subgroupSize as our new FCD for LANE
/// level distribution simplicity. But it can be different, and in the future,
/// we can add that support.
static int64_t computeArrayLength(int64_t fcdSize, int64_t subgroupSize) {
  const bool widerThanSubgroup = fcdSize > subgroupSize;
  return widerThanSubgroup ? fcdSize / subgroupSize : 1;
}
51
52/// Check if a 2D `xegpu.create_nd_tdesc` can be optimized into an
53/// array-length-enabled descriptor. Applies only when the FCD is an integer
54/// multiple of the subgroup size larger than the subgroup size itself and the
55/// tensor desc does not already carry an array_length.
56static bool needsOptimization(xegpu::TensorDescType tdescType,
57 int64_t subgroupSize) {
58 auto shape = tdescType.getShape();
59 if (shape.size() != 2)
60 return false;
61
62 int64_t fcd = shape[1];
63 if (fcd % subgroupSize != 0)
64 return false;
65
66 return fcd > subgroupSize && tdescType.getArrayLength() == 1;
67}
68
69/// Returns true if `loadOp` carries a non-identity transpose attribute. A
70/// transpose of `[0, 1]` is the identity and is therefore treated as absent.
71static bool hasNonIdentityTranspose(xegpu::LoadNdOp loadOp) {
72 auto transpose = loadOp.getTranspose();
73 if (!transpose)
74 return false;
75 ArrayRef<int64_t> perm = *transpose;
76 return !(perm.size() == 2 && perm[0] == 0 && perm[1] == 1);
77}
78
79/// Returns true if `tdescType` carries a lane layout that signals a
80/// transpose-intent load (lane_layout = `[SG, 1]`). Such descriptors are
81/// rewritten by the transpose peephole optimization and must not be touched
82/// here, since stacking the array blocks along the non-FCD dimension would
83/// invalidate that rewrite.
84static bool hasTransposeLaneLayout(xegpu::TensorDescType tdescType) {
85 auto layout = tdescType.getLayoutAttr();
86 if (!layout)
87 return false;
88 SmallVector<int64_t> laneLayout = layout.getEffectiveLaneLayoutAsInt();
89 if (laneLayout.size() != 2)
90 return false;
91 return laneLayout[0] != 1 && laneLayout[1] == 1;
92}
93
94/// Rewrite `xegpu.create_nd_tdesc` to fold an array_length attribute into the
95/// resulting tensor descriptor type. Supports static memref, dynamic-shape
96/// memref, and raw-pointer (integer) sources — the memory region described by
97/// `shape`/`strides` is unchanged; only the tensor_desc view is narrowed along
98/// the FCD and tagged with `array_length`. Skipped if any consumer load_nd
99/// carries a non-identity transpose, since stacking the array blocks along the
100/// non-FCD dimension would invalidate that load.
101class OptimizeCreateNdDescOp : public OpRewritePattern<xegpu::CreateNdDescOp> {
102public:
103 using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern;
104
105 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
106 PatternRewriter &rewriter) const override {
107 // sub-byte type is not supported for now.
108 if (op.getType().getElementTypeBitWidth() < 8)
109 return failure();
110 int64_t subgroupSize = getSubgroupSize(op);
111 auto tdescType = op.getType();
112 if (!needsOptimization(tdescType, subgroupSize))
113 return failure();
114
115 // A transpose lane layout marks this descriptor as a candidate for the
116 // separate transpose peephole; stacking the array blocks would break it.
117 if (hasTransposeLaneLayout(tdescType))
118 return failure();
119
120 Value source = op.getSource();
121 if (!isa<MemRefType, IntegerType>(source.getType()))
122 return failure();
123
124 // Bail out if any consumer is a transposing load_nd.
125 for (Operation *user : op.getResult().getUsers()) {
126 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(user))
127 if (hasNonIdentityTranspose(loadOp))
128 return failure();
129 }
130
131 auto shape = tdescType.getShape();
132 int64_t arrayLength = computeArrayLength(shape[1], subgroupSize);
133 SmallVector<int64_t> newShape = {shape[0], shape[1] / arrayLength};
134
135 auto newTdescType = xegpu::TensorDescType::get(
136 newShape, tdescType.getElementType(), arrayLength,
137 tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
138 tdescType.getLayout());
139
140 // The memory region is unchanged; pass through the existing shape/strides.
141 // The general builder recognizes the static-memref case and drops the
142 // redundant attributes.
143 auto newOp = xegpu::CreateNdDescOp::create(
144 rewriter, op.getLoc(), newTdescType, source, op.getMixedSizes(),
145 op.getMixedStrides());
146 rewriter.replaceOp(op, newOp.getResult());
147 return success();
148 }
149};
150
151/// Pattern to rewrite xegpu.load_nd operations
152class OptimizeLoadNdOp : public OpRewritePattern<xegpu::LoadNdOp> {
153public:
154 using OpRewritePattern<xegpu::LoadNdOp>::OpRewritePattern;
155
156 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
157 PatternRewriter &rewriter) const override {
158 auto tdescType = op.getTensorDescType();
159 int64_t arrayLength = tdescType.getArrayLength();
160
161 if (arrayLength <= 1)
162 return failure();
163
164 // Transposing loads are not compatible with the stacked-on-non-FCD layout
165 // that this pass produces.
166 if (hasNonIdentityTranspose(op) || hasTransposeLaneLayout(tdescType))
167 return failure();
168
169 auto origVectorType = op.getType();
170 auto origShape = origVectorType.getShape();
171 if (origShape.size() != 2)
172 return failure();
173
174 // The expected vector shape is: [tdesc_non_FCD * array_length, tdesc_FCD]
175 int64_t expectedNonFCD = tdescType.getShape()[0] * arrayLength;
176 int64_t expectedFCD = tdescType.getShape()[1];
177
178 // If already matches expected shape, skip
179 if (origShape[0] == expectedNonFCD && origShape[1] == expectedFCD)
180 return failure();
181
182 // Compute new vector shape for register layout
183 SmallVector<int64_t> newShape = {expectedNonFCD, expectedFCD};
184 auto newVectorType =
185 VectorType::get(newShape, origVectorType.getElementType());
186
187 // Create new LoadNdOp with updated result type
188 auto newLoadOp = xegpu::LoadNdOp::create(
189 rewriter, op.getLoc(), newVectorType, op.getTensorDesc(),
190 op.getMixedOffsets(), op.getPackedAttr(), op.getTransposeAttr(),
191 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
192 op.getLayoutAttr());
193
194 rewriter.replaceOp(op, newLoadOp.getResult());
195 return success();
196 }
197};
198
199/// Rewrite `vector.extract_strided_slice` offsets so they index into the
200/// stacked register layout produced by `OptimizeLoadNdOp`.
201///
202/// The optimized load places `arrayLength` blocks side-by-side in memory
203/// but stacks them along the non-FCD dimension in registers. Given a
204/// tensor desc of shape `[H, W]` with array_length = A:
205///
206/// memory layout (what the extract offsets refer to): `[H, W * A]`
207/// register layout (what the new load returns): `[H * A, W]`
208///
209/// An extract at memory offset `[r, c]` therefore maps to register offset
210/// `[r + (c / W) * H, 0]` — provided the extract is block-aligned in the
211/// FCD dimension, i.e. `c % W == 0`.
212///
213/// Example (`A = 2`, `H = 32`, `W = 16`):
214///
215/// // before
216/// %v = xegpu.load_nd %t : ... -> vector<32x32xf16>
217/// %e = vector.extract_strided_slice %v
218/// {offsets = [0, 16], sizes = [16, 16], strides = [1, 1]}
219/// : vector<32x32xf16> to vector<16x16xf16>
220///
221/// // after (load rewritten to vector<64x16>, extract offset remapped)
222/// %v = xegpu.load_nd %t : ... -> vector<64x16xf16>
223/// %e = vector.extract_strided_slice %v
224/// {offsets = [32, 0], sizes = [16, 16], strides = [1, 1]}
225/// : vector<64x16xf16> to vector<16x16xf16>
226class UpdateExtractStridedSliceOp
227 : public OpRewritePattern<vector::ExtractStridedSliceOp> {
228public:
229 using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;
230
231 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
232 PatternRewriter &rewriter) const override {
233 auto sourceType = dyn_cast<VectorType>(op.getSource().getType());
234 if (!sourceType || sourceType.getRank() != 2)
235 return failure();
236
237 auto loadOp = op.getSource().getDefiningOp<xegpu::LoadNdOp>();
238 if (!loadOp)
239 return failure();
240
241 auto tdescType = loadOp.getTensorDescType();
242 int64_t arrayLength = tdescType.getArrayLength();
243 if (arrayLength <= 1)
244 return failure();
245
246 auto offsets = op.getOffsets().getValue();
247 auto sizes = op.getSizes().getValue();
248 auto strides = op.getStrides().getValue();
249
250 if (offsets.size() != 2 || sizes.size() != 2 || strides.size() != 2)
251 return failure();
252
253 int64_t origOffset0 = cast<IntegerAttr>(offsets[0]).getInt();
254 int64_t origOffset1 = cast<IntegerAttr>(offsets[1]).getInt();
255
256 int64_t blockHeight = tdescType.getShape()[0];
257 int64_t arrayWidth = tdescType.getShape()[1];
258
259 // Skip extracts that already live entirely inside block 0: their offsets
260 // are identical in the memory and register layouts, so there is nothing
261 // to rewrite.
262 if (origOffset1 < arrayWidth)
263 return failure();
264
265 // The remap is only well-defined when the extract is aligned to an array
266 // block along the FCD.
267 assert(origOffset1 % arrayWidth == 0 &&
268 "extract offset along FCD must be a multiple of the array width");
269
270 int64_t arrayIndex = origOffset1 / arrayWidth;
271 SmallVector<int64_t> newOffsets = {origOffset0 + arrayIndex * blockHeight,
272 /*offset1=*/0};
273
274 auto toInts = [](ArrayAttr arr) {
275 return llvm::to_vector(llvm::map_range(
276 arr, [](Attribute a) { return cast<IntegerAttr>(a).getInt(); }));
277 };
278 SmallVector<int64_t> sliceSizes = toInts(op.getSizes());
279 SmallVector<int64_t> sliceStrides = toInts(op.getStrides());
280
281 auto newOp = vector::ExtractStridedSliceOp::create(
282 rewriter, op.getLoc(), op.getSource(), newOffsets, sliceSizes,
283 sliceStrides);
284
285 rewriter.replaceOp(op, newOp.getResult());
286 return success();
287 }
288};
289
290} // namespace
291
293 RewritePatternSet &patterns) {
294 patterns.add<OptimizeCreateNdDescOp, OptimizeLoadNdOp,
295 UpdateExtractStridedSliceOp>(patterns.getContext());
296}
return success()
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUArrayLengthOptimizationPatterns(RewritePatternSet &patterns)
Appends patterns for array length optimization into patterns.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
virtual int getSubgroupSize() const =0