//===- XeGPUArrayLengthOptimization.cpp - Array Length Opt -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

#define DEBUG_TYPE "xegpu-array-length-optimization"

using namespace mlir;

namespace {

// Fallback subgroup size used when the target uArch cannot be resolved from
// the op (e.g. standalone unit tests with no chip attribute attached).
constexpr int64_t DEFAULT_SUBGROUP_SIZE = 16;

/// Return the subgroup size for `op`'s target uArch, falling back to
/// DEFAULT_SUBGROUP_SIZE if no chip attribute is attached or the chip is not
/// recognized.
static int64_t getSubgroupSize(Operation *op) {
  auto chipStr = xegpu::getChipStr(op);
  if (!chipStr)
    return DEFAULT_SUBGROUP_SIZE;
  const xegpu::uArch::uArch *targetUArch =
      xegpu::uArch::getUArch(chipStr.value());
  if (!targetUArch)
    return DEFAULT_SUBGROUP_SIZE;
  return targetUArch->getSubgroupSize();
}
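// For example, a kernel nested in a gpu.module that carries an XeVM target
// attribute such as `#xevm.target<chip = "pvc">` resolves to that chip's
// uArch and its subgroup size; the chip string here is purely illustrative.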

/// Helper to compute array_length from FCD and subgroup size.
/// TODO: For simplicity of LANE-level distribution, the new FCD is currently
/// fixed to the subgroup size; other FCD choices could be supported in the
/// future.
static int64_t computeArrayLength(int64_t fcdSize, int64_t subgroupSize) {
  if (fcdSize <= subgroupSize)
    return 1;
  return fcdSize / subgroupSize;
}
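// For example, with the default subgroup size of 16:
//   computeArrayLength(8, 16)  == 1   // FCD no wider than a subgroup
//   computeArrayLength(32, 16) == 2
//   computeArrayLength(64, 16) == 4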

/// Check if a 2D `xegpu.create_nd_tdesc` can be optimized into an
/// array-length-enabled descriptor. Applies only when the FCD is an integer
/// multiple of the subgroup size, strictly larger than the subgroup size
/// itself, and the tensor desc does not already carry an array_length.
static bool needsOptimization(xegpu::TensorDescType tdescType,
                              int64_t subgroupSize) {
  auto shape = tdescType.getShape();
  if (shape.size() != 2)
    return false;

  int64_t fcd = shape[1];
  if (fcd % subgroupSize != 0)
    return false;

  return fcd > subgroupSize && tdescType.getArrayLength() == 1;
}
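// With a subgroup size of 16, for instance, a tensor_desc<32x64xf16>
// qualifies (FCD 64 = 4 * 16), while tensor_desc<32x16xf16> (FCD equals the
// subgroup size) and tensor_desc<32x24xf16> (FCD not a multiple of 16) do
// not.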

/// Returns true if `loadOp` carries a non-identity transpose attribute. A
/// transpose of `[0, 1]` is the identity and is therefore treated as absent.
static bool hasNonIdentityTranspose(xegpu::LoadNdOp loadOp) {
  auto transpose = loadOp.getTranspose();
  if (!transpose)
    return false;
  ArrayRef<int64_t> perm = *transpose;
  return !(perm.size() == 2 && perm[0] == 0 && perm[1] == 1);
}
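// E.g. a load_nd carrying `transpose = array<i64: 1, 0>` returns true here,
// while a missing transpose or the identity `array<i64: 0, 1>` returns false.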

/// Returns true if `tdescType` carries a lane layout that signals a
/// transpose-intent load (lane_layout = `[SG, 1]`). Such descriptors are
/// rewritten by the transpose peephole optimization and must not be touched
/// here, since stacking the array blocks along the non-FCD dimension would
/// invalidate that rewrite.
static bool hasTransposeLaneLayout(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return false;
  SmallVector<int64_t> laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] != 1 && laneLayout[1] == 1;
}
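// For instance, a descriptor layout of
// `#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>` matches (lanes
// distributed down the column, i.e. transpose intent), whereas a row-major
// `lane_layout = [1, 16]` does not.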

/// Rewrite `xegpu.create_nd_tdesc` to fold an array_length attribute into the
/// resulting tensor descriptor type. Supports static memref, dynamic-shape
/// memref, and raw-pointer (integer) sources; the memory region described by
/// `shape`/`strides` is unchanged, and only the tensor_desc view is narrowed
/// along the FCD and tagged with `array_length`. Skipped if any consumer
/// load_nd carries a non-identity transpose, since stacking the array blocks
/// along the non-FCD dimension would invalidate that load.
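///
/// Illustrative example (assuming subgroup size 16, so the 32x64 view is
/// narrowed to 32x16 with array_length = 4):
///
///   // before
///   %t = xegpu.create_nd_tdesc %src[0, 0]
///       : memref<32x64xf16> -> !xegpu.tensor_desc<32x64xf16>
///
///   // after
///   %t = xegpu.create_nd_tdesc %src[0, 0]
///       : memref<32x64xf16>
///       -> !xegpu.tensor_desc<32x16xf16,
///            #xegpu.block_tdesc_attr<array_length = 4 : i64>>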
class OptimizeCreateNdDescOp : public OpRewritePattern<xegpu::CreateNdDescOp> {
public:
  using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    int64_t subgroupSize = getSubgroupSize(op);
    auto tdescType = op.getType();
    if (!needsOptimization(tdescType, subgroupSize))
      return failure();

    // A transpose lane layout marks this descriptor as a candidate for the
    // separate transpose peephole; stacking the array blocks would break it.
    if (hasTransposeLaneLayout(tdescType))
      return failure();

    Value source = op.getSource();
    if (!isa<MemRefType, IntegerType>(source.getType()))
      return failure();

    // Bail out if any consumer is a transposing load_nd.
    for (Operation *user : op.getResult().getUsers()) {
      if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(user))
        if (hasNonIdentityTranspose(loadOp))
          return failure();
    }

    auto shape = tdescType.getShape();
    int64_t arrayLength = computeArrayLength(shape[1], subgroupSize);
    SmallVector<int64_t> newShape = {shape[0], shape[1] / arrayLength};

    auto newTdescType = xegpu::TensorDescType::get(
        newShape, tdescType.getElementType(), arrayLength,
        tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
        tdescType.getLayout());

    // The memory region is unchanged; pass through the existing shape/strides.
    // The general builder recognizes the static-memref case and drops the
    // redundant attributes.
    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), newTdescType, source, op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Rewrite `xegpu.load_nd` on an array-length tensor descriptor so that its
/// result vector uses the stacked register layout `[H * A, W]` rather than
/// the flat memory layout `[H, W * A]`.
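///
/// Illustrative example (array_length = 2):
///
///   // before: result still shaped like the flat memory region
///   %v = xegpu.load_nd %t
///       : !xegpu.tensor_desc<32x16xf16,
///            #xegpu.block_tdesc_attr<array_length = 2 : i64>>
///         -> vector<32x32xf16>
///
///   // after: the two blocks are stacked along the non-FCD dimension
///   %v = xegpu.load_nd %t
///       : !xegpu.tensor_desc<32x16xf16,
///            #xegpu.block_tdesc_attr<array_length = 2 : i64>>
///         -> vector<64x16xf16>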
class OptimizeLoadNdOp : public OpRewritePattern<xegpu::LoadNdOp> {
public:
  using OpRewritePattern<xegpu::LoadNdOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    auto tdescType = op.getTensorDescType();
    int64_t arrayLength = tdescType.getArrayLength();

    if (arrayLength <= 1)
      return failure();

    // Transposing loads are not compatible with the stacked-on-non-FCD layout
    // that this pass produces.
    if (hasNonIdentityTranspose(op) || hasTransposeLaneLayout(tdescType))
      return failure();

    auto origVectorType = op.getType();
    auto origShape = origVectorType.getShape();
    if (origShape.size() != 2)
      return failure();

    // The expected vector shape is: [tdesc_non_FCD * array_length, tdesc_FCD]
    int64_t expectedNonFCD = tdescType.getShape()[0] * arrayLength;
    int64_t expectedFCD = tdescType.getShape()[1];

    // Nothing to do if the result already has the expected shape.
    if (origShape[0] == expectedNonFCD && origShape[1] == expectedFCD)
      return failure();

    // Compute the new vector shape for the register layout.
    SmallVector<int64_t> newShape = {expectedNonFCD, expectedFCD};
    auto newVectorType =
        VectorType::get(newShape, origVectorType.getElementType());

    // Create a new LoadNdOp with the updated result type.
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), newVectorType, op.getTensorDesc(),
        op.getMixedOffsets(), op.getPackedAttr(), op.getTransposeAttr(),
        op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
        op.getLayoutAttr());

    rewriter.replaceOp(op, newLoadOp.getResult());
    return success();
  }
};

/// Rewrite `vector.extract_strided_slice` offsets so they index into the
/// stacked register layout produced by `OptimizeLoadNdOp`.
///
/// The optimized load places `arrayLength` blocks side-by-side in memory
/// but stacks them along the non-FCD dimension in registers. Given a
/// tensor desc of shape `[H, W]` with array_length = A:
///
///   memory layout (what the extract offsets refer to): `[H, W * A]`
///   register layout (what the new load returns):       `[H * A, W]`
///
/// An extract at memory offset `[r, c]` therefore maps to register offset
/// `[r + (c / W) * H, 0]`, provided the extract is block-aligned in the
/// FCD dimension, i.e. `c % W == 0`.
///
/// Example (`A = 2`, `H = 32`, `W = 16`):
///
///   // before
///   %v = xegpu.load_nd %t : ... -> vector<32x32xf16>
///   %e = vector.extract_strided_slice %v
///       {offsets = [0, 16], sizes = [16, 16], strides = [1, 1]}
///       : vector<32x32xf16> to vector<16x16xf16>
///
///   // after (load rewritten to vector<64x16>, extract offset remapped)
///   %v = xegpu.load_nd %t : ... -> vector<64x16xf16>
///   %e = vector.extract_strided_slice %v
///       {offsets = [32, 0], sizes = [16, 16], strides = [1, 1]}
///       : vector<64x16xf16> to vector<16x16xf16>
class UpdateExtractStridedSliceOp
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
public:
  using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceType = dyn_cast<VectorType>(op.getSource().getType());
    if (!sourceType || sourceType.getRank() != 2)
      return failure();

    auto loadOp = op.getSource().getDefiningOp<xegpu::LoadNdOp>();
    if (!loadOp)
      return failure();

    auto tdescType = loadOp.getTensorDescType();
    int64_t arrayLength = tdescType.getArrayLength();
    if (arrayLength <= 1)
      return failure();

    auto offsets = op.getOffsets().getValue();
    auto sizes = op.getSizes().getValue();
    auto strides = op.getStrides().getValue();

    if (offsets.size() != 2 || sizes.size() != 2 || strides.size() != 2)
      return failure();

    int64_t origOffset0 = cast<IntegerAttr>(offsets[0]).getInt();
    int64_t origOffset1 = cast<IntegerAttr>(offsets[1]).getInt();

    int64_t blockHeight = tdescType.getShape()[0];
    int64_t arrayWidth = tdescType.getShape()[1];

    // Skip extracts that already live entirely inside block 0: their offsets
    // are identical in the memory and register layouts, so there is nothing
    // to rewrite.
    if (origOffset1 < arrayWidth)
      return failure();

    // The remap is only well-defined when the extract is aligned to an array
    // block along the FCD.
    assert(origOffset1 % arrayWidth == 0 &&
           "extract offset along FCD must be a multiple of the array width");

    int64_t arrayIndex = origOffset1 / arrayWidth;
    SmallVector<int64_t> newOffsets = {origOffset0 + arrayIndex * blockHeight,
                                       /*offset1=*/0};

    auto toInts = [](ArrayAttr arr) {
      return llvm::to_vector(llvm::map_range(
          arr, [](Attribute a) { return cast<IntegerAttr>(a).getInt(); }));
    };
    SmallVector<int64_t> sliceSizes = toInts(op.getSizes());
    SmallVector<int64_t> sliceStrides = toInts(op.getStrides());

    auto newOp = vector::ExtractStridedSliceOp::create(
        rewriter, op.getLoc(), op.getSource(), newOffsets, sliceSizes,
        sliceStrides);

    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

} // namespace

void xegpu::populateXeGPUArrayLengthOptimizationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<OptimizeCreateNdDescOp, OptimizeLoadNdOp,
               UpdateExtractStridedSliceOp>(patterns.getContext());
}
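
// A minimal usage sketch (hypothetical caller, assuming the greedy pattern
// rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(module.getContext());
//   xegpu::populateXeGPUArrayLengthOptimizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(module, std::move(patterns))))
//     signalPassFailure();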