MLIR 23.0.0git
XeGPUWgToSgDistribute.cpp
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
25#include <optional>
26
27namespace mlir {
28namespace xegpu {
29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31} // namespace xegpu
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Retrieve the RangeAttr if it is specified.
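// For illustration (schematic IR; the exact attribute syntax may differ), the
// range specifier typically sits on an enclosing scf.if that guards the work
// of a subset of subgroups, e.g.:
//
//   scf.if %is_producer {
//     ... ops distributed over subgroups [0, 16) ...
//   } {sg_id_range = #xegpu.range<[0, 16]>}
//
// The helper walks outward through nested scf.if ops and returns the first
// such attribute, so the closest enclosing range wins.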
39static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
40 Operation *parent = op->getParentOfType<scf::IfOp>();
41 while (parent) {
42 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->getAttr("sg_id_range")))
44 return attr;
45 parent = parent->getParentOfType<scf::IfOp>();
46 }
47 return {};
48}
49
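// Computes the per-subgroup tile shape and the number of tiles each subgroup
// covers in round-robin order. Worked example (illustrative numbers): for
// shape = [128, 128] with sg_layout = [4, 4] and sg_data = [16, 16], each
// subgroup owns a 16x16 tile, one distribution unit spans 64x64 elements, and
// count = (128*128) / (64*64) = 4. If sg_data is absent, sgShape defaults to
// shape[i] / sg_layout[i] ([32, 32] here) and count becomes 1.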
50static std::pair<SmallVector<int64_t>, int>
51getSgShapeAndCount(ArrayRef<int64_t> shape,
52 xegpu::DistributeLayoutAttr layout) {
53 int count = 1;
54  SmallVector<int64_t> sgShape = llvm::to_vector(shape);
55  if (layout && layout.isForWorkgroup()) {
56 SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
57 if (!layout.getEffectiveSgDataAsInt().empty())
58 sgShape = layout.getEffectiveSgDataAsInt();
59 else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
60 sgShape = *maybeDerivedSgData;
61 SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
62 // Clamp distUnit to the original shape to handle cases where data is
63 // shared among subgroups, which may cause distUnit to exceed the original
64 // shape.
65 for (size_t i = 0; i < distUnit.size(); ++i)
66 distUnit[i] = std::min(shape[i], distUnit[i]);
67 count = computeProduct(shape) / computeProduct(distUnit);
68 }
69 return std::make_pair(sgShape, count);
70}
71
72/// Utility helper for deriving the list of offsets for each sub-TensorDesc
73/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
74/// the associated distribute layout attribute, the shape, the subgroup id,
75/// and the original offsets of the op.
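///
/// For example (illustrative values): for a load whose original offsets are
/// [%off0, %off1], a workgroup shape of 128x128, sg_layout = [4, 4] and
/// sg_data = [32, 32], the helper yields one offset pair per accessed tile,
/// each of the form [%off0 + dy, %off1 + dx], where (dy, dx) are the
/// subgroup-relative coordinates computed from the subgroup id.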
76template <
77 typename OpType,
78 typename = std::enable_if_t<llvm::is_one_of<
79 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
80 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
81static LogicalResult
82genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
83               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
84 Location loc = op.getLoc();
85 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
86  // Not applicable to ops without offset operands.
87 if (origOffsets.empty())
88 return failure();
89
90  // LoadMatrix/StoreMatrix ops carry the layout directly; Nd ops use getDescLayoutAttr().
91 xegpu::DistributeLayoutAttr layout;
92 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
93 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
94 layout = op.getLayoutAttr();
95 } else {
96 layout = op.getDescLayoutAttr();
97 }
98
99  // Not applicable to ops without a workgroup-level layout attribute.
100 if (!layout || !layout.isForWorkgroup())
101 return failure();
102
103 Value sgId =
104 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
105
106 // verify and adjust the sgId if the range specifier is present
107 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
108 if (sgIdRange) {
109 int64_t startOfRange = sgIdRange.getStart().getInt();
110 int64_t endOfRange = sgIdRange.getEnd().getInt();
111 // verify the RangeAttr against the layout attribute
112 if (layout.getNumSubgroups() != endOfRange - startOfRange)
113 return rewriter.notifyMatchFailure(
114 op, "sg_layout size must match the sg_id_range");
115 // adjust the sgId if necessary
116 if (startOfRange > 0) {
117 Value startOfRangeVal =
118 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
119 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
120 }
121 }
122
123 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
124 // descriptors to be accessed, based on the layout information.
125 ArrayRef<int64_t> wgShape = op.getDataShape();
126 auto maybeDescOffsets =
127 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
128 if (failed(maybeDescOffsets))
129 return failure();
130
131 // Compute the final global offsets for each accessed sub-tensor
132 // or sub-memory descriptor.
133 for (const auto &sgOffsets : *maybeDescOffsets) {
134    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
135        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
136 offsetsList.push_back(std::move(newOffsets));
137 }
138
139 // callback(offsetsList);
140 return success();
141}
142
143/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
144/// from a workgroup descriptor. It replaces the offsets and sizes with
145/// appropriate values for the subgroup.
146/// It uses round-robin assignment to distribute the work to the subgroups.
147/// The following create_nd_desc operation:
148/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
149/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
150/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
151/// is converted to 9 subgroup level operations based on the sg_layout &
152/// sg_data:
153/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
154/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
155/// lane_data = [1, 1]>>
156///
157/// The sg_layout and sg_data attributes are dropped after the pass as they are
158/// no longer needed.
159///
160/// 24x24 matrix distribution example:
161/// sg_layout = [4, 4], sg_data = [2, 2]
162/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
163/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
164///
165/// +------------------------+
166/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
167/// |-----+-----+-----|
168/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
169/// |-----+-----+-----|
170/// | 8x8 | 8x8 | 8x8 |
171/// +------------------------+
172///
173/// Each 8x8 tile is further subdivided among subgroups:
174/// +------------------------+
175/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
176/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
177/// | 2x2 2x2 2x2 2x2 |
178/// | 2x2 2x2 2x2 2x2 |
179/// +------------------------+
180///
181/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
182/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
183
184/// The pass currently keeps the entire distribution logic in the
185/// WgToSgCreateNdOp pattern; all the other ops simply follow it.
186/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
187/// ops in the pass.
188struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
189 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
190
191 LogicalResult
192 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter) const override {
194 SmallVector<SmallVector<OpFoldResult>> offsetsList;
195 if (failed(genOffsetsList(rewriter, op, offsetsList)))
196 return failure();
197
198 MLIRContext *ctx = op.getContext();
199 xegpu::TensorDescType tdescTy = op.getType();
200 ArrayRef<int64_t> wgShape = tdescTy.getShape();
201 Type elemTy = tdescTy.getElementType();
202 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
203 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
204 auto newTdescTy =
205 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
206 layout.dropSgLayoutAndData());
207
208 SmallVector<Value> newOps;
209 for (auto offsets : offsetsList) {
210 auto newOp = xegpu::CreateNdDescOp::create(
211 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
212 op.getMixedSizes(), op.getMixedStrides());
213
214 newOps.push_back(newOp);
215 }
216 rewriter.replaceOpWithMultiple(op, {newOps});
217
218 return success();
219 }
220};
221
222// This pattern transforms the CreateNdDescOp without offsets to create a
223// subgroup descriptor from a workgroup descriptor
224struct WgToSgCreateNdOpNoOffset
225 : public OpConversionPattern<xegpu::CreateNdDescOp> {
226 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
227
228 LogicalResult
229 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter) const override {
231
232 // Check no offsets are specified.
233 if (!op.getMixedOffsets().empty())
234 return failure();
235
236 Location loc = op.getLoc();
237 MLIRContext *ctx = op.getContext();
238 xegpu::TensorDescType tdescTy = op.getType();
239 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
240 if (!layout || !layout.isForWorkgroup())
241 return failure();
242
243 Type elemTy = tdescTy.getElementType();
244 ArrayRef<int64_t> wgShape = tdescTy.getShape();
245
246 SmallVector<int64_t> sgShape;
247 int count;
248 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
249 xegpu::TensorDescType newTdescTy =
250 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
251 layout.dropSgLayoutAndData());
252
253 SmallVector<Value> newCreateNdOps(count);
254 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
255 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
256 op.getSource(), op.getMixedSizes(),
257 op.getMixedStrides());
258 });
259
260 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
261 return success();
262 }
263};
264
265/// This pattern transforms the LoadNdOp to load subgroup data.
266struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
267 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
268 LogicalResult
269 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
270 ConversionPatternRewriter &rewriter) const override {
271 if (!op.getMixedOffsets().empty())
272 return failure();
273
274 SmallVector<Value> newLoadOps;
275 for (auto src : adaptor.getTensorDesc()) {
276 xegpu::TensorDescType tdescTy =
277 dyn_cast<xegpu::TensorDescType>(src.getType());
278 ArrayRef<int64_t> srcShape = tdescTy.getShape();
279 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
280 auto newLoadOp = xegpu::LoadNdOp::create(
281 rewriter, op.getLoc(), newResTy, src,
282 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
283 newLoadOps.push_back(newLoadOp);
284 }
285 rewriter.replaceOpWithMultiple(op, {newLoadOps});
286 return mlir::success();
287 }
288};
289
290/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
291/// It creates a StoreNdOp to store the updated values to the new subgroup
292/// source tensor descriptors.
293struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
294 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
295 LogicalResult
296 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter) const override {
298 if (!op.getMixedOffsets().empty())
299 return failure();
300
301 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
302 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
303 op.getL2HintAttr(), op.getL3HintAttr());
304
305 rewriter.eraseOp(op);
306 return success();
307 }
308};
309
310// This pattern transforms the LoadNdOp with explicit offsets to load
311// subgroup data.
312struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
313 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
314 LogicalResult
315 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter) const override {
317
318 SmallVector<SmallVector<OpFoldResult>> offsetsList;
319 if (failed(genOffsetsList(rewriter, op, offsetsList)))
320 return failure();
321
322 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
323 if (layout)
324 layout = layout.dropSgLayoutAndData();
325 SmallVector<Value> newOps;
326 for (auto [tdesc, offsets] :
327 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
328 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
329 VectorType newResTy =
330 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
331 auto newOp = xegpu::LoadNdOp::create(
332 rewriter, op.getLoc(), newResTy, tdesc, offsets,
333 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
334 op.getL2HintAttr(), op.getL3HintAttr(), layout);
335 newOps.push_back(newOp);
336 }
337 rewriter.replaceOpWithMultiple(op, {newOps});
338
339 return success();
340 }
341};
342
343// This pattern transforms the StoreNdOp with explicit offsets to store
344// subgroup data.
345struct WgToSgStoreNdOpWithOffset
346 : public OpConversionPattern<xegpu::StoreNdOp> {
347 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
348 LogicalResult
349 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351 SmallVector<SmallVector<OpFoldResult>> offsetsList;
352 if (failed(genOffsetsList(rewriter, op, offsetsList)))
353 return failure();
354
355 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
356 if (layout)
357 layout = layout.dropSgLayoutAndData();
358 for (auto [v, tdesc, offsets] :
359 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
360 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
361 op.getL1HintAttr(), op.getL2HintAttr(),
362 op.getL3HintAttr(), layout);
363 }
364 rewriter.eraseOp(op);
365
366 return success();
367 }
368};
369
370// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
371// subgroup data.
372struct WgToSgPrefetchNdOpWithOffset
373 : public OpConversionPattern<xegpu::PrefetchNdOp> {
374 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
375 LogicalResult
376 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
377 ConversionPatternRewriter &rewriter) const override {
378 SmallVector<SmallVector<OpFoldResult>> offsetsList;
379 if (failed(genOffsetsList(rewriter, op, offsetsList)))
380 return failure();
381
382 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
383 if (layout)
384 layout = layout.dropSgLayoutAndData();
385 for (auto [tdesc, offsets] :
386 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
387 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
388 op.getL1HintAttr(), op.getL2HintAttr(),
389 op.getL3HintAttr(), layout);
390 }
391 rewriter.eraseOp(op);
392
393 return success();
394 }
395};
396
397/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
398/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
399/// offsets of the new subgroup source tensor descriptors.
400struct WgToSgUpdateNdOffsetOp
401 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
402 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
403 LogicalResult
404 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406 llvm::SmallVector<Value> newUpdateTileOffsetOps;
407 for (auto tDesc : adaptor.getTensorDesc()) {
408 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
409 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
410 op.getConstOffsets());
411 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
412 }
413
414 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
415 return success();
416 }
417};
418
419/// This pattern transforms the DpasOp to work at subgroup level.
420struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
421 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
422 LogicalResult
423 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter) const override {
425 Location loc = op.getLoc();
426 VectorType resultTy = op.getResult().getType();
427 if (resultTy.getRank() != 2)
428 return failure();
429
430 auto layoutCd = op.getLayoutCdAttr();
431 auto layoutA = op.getLayoutAAttr();
432 auto layoutB = op.getLayoutBAttr();
433 if (!layoutCd || !layoutA || !layoutB)
434 return failure();
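    // Illustrative shapes: if the LHS is distributed into two 32x16 fragments
    // and the RHS into two 16x32 fragments, the nested loops below emit
    // 2 x 2 = 4 subgroup-level dpas ops with 32x32 results; accumulator
    // fragments, when present, are consumed in the same order via index i.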
435 size_t i = 0;
436 SmallVector<Value> newDpasOps;
437 for (auto aVec : adaptor.getLhs()) {
438 for (auto bVec : adaptor.getRhs()) {
439
440 llvm::SmallVector<Value> operands({aVec, bVec});
441 Value tmpC;
442 if (op.getAcc()) {
443 tmpC = adaptor.getAcc()[i++];
444 operands.push_back(tmpC);
445 }
446
447 ArrayRef<int64_t> aVecShape =
448 llvm::cast<VectorType>(aVec.getType()).getShape();
449 ArrayRef<int64_t> bVecShape =
450 llvm::cast<VectorType>(bVec.getType()).getShape();
451 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
452 resultTy.getElementType());
453 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
454 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
455 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
456 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
457
458 newDpasOps.push_back(newDpasOp);
459 }
460 }
461 rewriter.replaceOpWithMultiple(op, {newDpasOps});
462 return success();
463 }
464};
465
466/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
467struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
468 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
469 LogicalResult
470 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
471 ConversionPatternRewriter &rewriter) const override {
472
473 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
474 if ((offsetSize != 0) || op.getConstOffsetsAttr())
475 return failure();
476
477 for (auto src : adaptor.getTensorDesc())
478 xegpu::PrefetchNdOp::create(
479 rewriter, op.getLoc(), TypeRange(), src,
480 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
481 rewriter.eraseOp(op);
482 return success();
483 }
484};
485
486/// This pattern transforms vector.broadcast ops to work at subgroup level.
487struct WgToSgVectorBroadcastOp
488 : public OpConversionPattern<vector::BroadcastOp> {
489 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
490
491 LogicalResult
492 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
493 ConversionPatternRewriter &rewriter) const override {
494
495 VectorType resultType = op.getResult().getType();
496 ArrayRef<int64_t> wgShape = resultType.getShape();
497
498 xegpu::DistributeLayoutAttr layout =
499 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
500 if (!layout || !layout.isForWorkgroup())
501 return failure();
502
503 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
504 VectorType newResultType =
505 VectorType::get(sgShape, resultType.getElementType());
506
507 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
508 return failure();
509
510 SmallVector<Value> newBroadcastOps;
511 for (auto operand : adaptor.getOperands().front()) {
512 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
513 newResultType, operand);
514
515 newBroadcastOps.push_back(newBroadcast.getResult());
516 }
517 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
518 return success();
519 }
520};
521
522// This pattern transforms elementwise ops to work at subgroup level.
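// For example (illustrative): a workgroup-level
//   %r = arith.addf %a, %b : vector<256x128xf32>
// with sg_layout = [8, 4] and sg_data = [32, 32] is rewritten into one
//   %r_i = arith.addf %a_i, %b_i : vector<32x32xf32>
// per distributed operand pair, with the original attributes carried over.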
523struct WgToSgElementwiseOp : public ConversionPattern {
524 WgToSgElementwiseOp(MLIRContext *ctx)
525 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
526
527 LogicalResult
528 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
529 ConversionPatternRewriter &rewriter) const override {
530 // Only match ops with elementwise trait and single result.
531    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
532      return failure();
533
534 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
535 assert(resultType && "Expected result to be a VectorType");
536
537 ArrayRef<int64_t> wgShape = resultType.getShape();
538
539 xegpu::DistributeLayoutAttr layout =
540 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
541 if (!layout || !layout.isForWorkgroup())
542 return failure();
543
544 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
545
546 size_t numVariants = operands.empty() ? 0 : operands.front().size();
547
548 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
549 return operandVec.size() != numVariants;
550 }))
551 return failure();
552
553 SmallVector<Value> newResults;
554 VectorType newResultType =
555 VectorType::get(sgShape, resultType.getElementType());
556
557 for (size_t i = 0; i < numVariants; ++i) {
558 SmallVector<Value> opOperands;
559 for (auto &operandVec : operands)
560 opOperands.push_back(operandVec[i]);
561
562 OperationState state(op->getLoc(), op->getName());
563 state.addOperands(opOperands);
564 state.addTypes(newResultType);
565 state.addAttributes(op->getAttrs());
566 Operation *newOp = rewriter.create(state);
568 newResults.push_back(newOp->getResult(0));
569 }
570
571 rewriter.replaceOpWithMultiple(op, {newResults});
572 return success();
573 }
574};
575
576// clang-format off
577// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
578// If input_layout and target_layout have identical sg_layout and sg_data,
579// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
580// dropped. For example:
581// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
582// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
583// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
584// becomes:
585// #a = #xegpu.layout<inst_data = [16, 16]>
586// #b = #xegpu.layout<inst_data = [8, 16]>
587// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
588// (vector<16x16xf32> is determined by sg_data = [16, 16])
589//
590// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
591// For example:
592// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
593// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
594// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
595// is lowered to:
596// #a = #xegpu.layout<inst_data = [16, 16]>
597// #b = #xegpu.layout<inst_data = [8, 16]>
598// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
599// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
600// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
601// clang-format on
602struct WgToSgConvertLayoutOp
603 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
604 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
606 LogicalResult
607 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
608 ConversionPatternRewriter &rewriter) const override {
609 Location loc = op.getLoc();
611 VectorType resultType = op.getResult().getType();
612 ArrayRef<int64_t> wgShape = resultType.getShape();
613 auto inputLayout = op.getInputLayout();
614 auto targetLayout = op.getTargetLayout();
616 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
617 !targetLayout.isForWorkgroup())
618 return rewriter.notifyMatchFailure(
619 op, "Input and target layouts must have subgroup layout");
620
621 SmallVector<int64_t> inputSgLayout =
622 inputLayout.getEffectiveSgLayoutAsInt();
623 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
624 SmallVector<int64_t> targetSgLayout =
625 targetLayout.getEffectiveSgLayoutAsInt();
626 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
627
628 // Fast path: if sg_layout and sg_data are identical, no SLM needed
629 if (inputLayout.isCompatibleWith(targetLayout,
631 inputLayout = inputLayout.dropSgLayoutAndData();
632 targetLayout = targetLayout.dropSgLayoutAndData();
633
634 SmallVector<Value> newOps(adaptor.getSource());
635 if (inputLayout && targetLayout) {
636 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
637 auto newOp = xegpu::ConvertLayoutOp::create(
638 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
639 newOps[i] = newOp;
640 }
641 }
642 rewriter.replaceOpWithMultiple(op, {newOps});
643 return success();
644 }
645
646 // SLM path: layouts differ, need cross-subgroup data redistribution
647 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
648
649 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
650
651 // Calculate SLM size requirements
652 auto bitWidth = elemTy.getIntOrFloatBitWidth();
653 auto bytesPerElement = bitWidth / 8;
654 auto slmSize = computeProduct(slmShape) * bytesPerElement;
655
656 // Allocate SLM
657 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
658 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
659
660 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
661 elemTy, nullptr);
662 auto memDesc =
663 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
664
665 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
666 rewriter.getIndexType(), nullptr);
667
668 // STORE PHASE: Each subgroup stores in SLM using input layout
669 auto storeCoords = inputLayout.computeDistributedCoords(
670 rewriter, loc, sgId.getResult(), wgShape);
671 if (failed(storeCoords))
672 return failure();
673
674 // Store to SLM
675 for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
676 SmallVector<OpFoldResult> storeMatrixOffsets;
677 for (Value coord : coords) {
678 storeMatrixOffsets.push_back(coord);
679 }
680 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
681 storeMatrixOffsets, nullptr /*layout*/);
682 }
683
684 gpu::BarrierOp::create(rewriter, loc);
685
686 // LOAD PHASE: Each target subgroup loads from SLM using target layout
687 auto loadCoords = targetLayout.computeDistributedCoords(
688 rewriter, loc, sgId.getResult(), wgShape);
689 if (failed(loadCoords))
690 return failure();
691
692 VectorType loadType = VectorType::get(targetSgData, elemTy);
693
694 // Load vectors from SLM
695 SmallVector<Value> finalResults;
696 for (auto coords : *loadCoords) {
697 SmallVector<OpFoldResult> loadMatrixOffsets;
698 for (Value coord : coords) {
699 loadMatrixOffsets.push_back(coord);
700 }
701 auto loadOp = xegpu::LoadMatrixOp::create(
702 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
703 targetLayout.dropSgLayoutAndData());
704
705 finalResults.push_back(loadOp.getResult());
706 }
707
708 rewriter.replaceOpWithMultiple(op, {finalResults});
709 return success();
710 }
711};
712
713// Handles UnrealizedConversionCastOp generated during
714// SCFStructuralTypeConversions (step 1). This op may appear as either a
715// target or source materialization for Vector values, e.g.:
716// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
717// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
718// It can be either a 1:N or an N:1 cast. In both cases, the pattern
719// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
720// For example, the following scf::forOp
721// ```
722// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
723// %n = use(%arg1): vector<128x128xf16>
724// scf.yield %n : vector<128x128xf16>
725// }
726// ```
727// Could be converted to:
728// ```
729// %1 = unrealized_conversion_cast %0
730// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
731// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
732//     -> (vector<16x16xf16>, vector<16x16xf16>) {
733//   %m = unrealized_conversion_cast %arg1, %arg2
734//       : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
735//   %n = use(%m): vector<128x128xf16>
736//   %b = unrealized_conversion_cast %n
737//       : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
738//   scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
739// }
740// %cast = unrealized_conversion_cast %for:2
741// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
742// ```
743// TODO: remove it when context-aware type converter is ready.
744struct UnrealizedConversionCastOpPattern
745 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
746 using OpConversionPattern<
747 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
748
749 mlir::LogicalResult
750 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
751 ConversionPatternRewriter &rewriter) const override {
752 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
753
754 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
755 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
756
757 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
758 !llvm::all_equal(ValueRange(inputs).getTypes()))
759 return failure();
760
761 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
762 // It is generated by source materialization (e.g., inits to scf forOp).
763 // The input values provided by the adaptor should already be distributed,
764 // and their types should correspond exactly to the result types of the
765 // operation.
766 if (op.getNumOperands() == 1 &&
767 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
768 rewriter.replaceOp(op, inputs);
769 return success();
770 }
771
772 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
773 // It is generated by target materialization (e.g., arguments/results
774 // of scf forOp). All input values must have the same vector type, and
775 // their shape must be evenly divisible by the output vector's shape
776 // (determined by the nature of the workgroup to subgroup distribution).
777 // TODO: it is not safe to do such forward, since such N:1 cast could be
778 // from others.
779 if (op.getNumResults() == 1 &&
780 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
781 rewriter.replaceOpWithMultiple(op, {inputs});
782 return success();
783 }
784
785 return mlir::failure();
786 }
787};
788
789// This pattern distributes arith.constant op into subgroup-level constants
790struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
791 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
792
793 LogicalResult
794 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
795 ConversionPatternRewriter &rewriter) const override {
796 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
797 auto vecType = dyn_cast<VectorType>(op.getType());
798 if (!vecAttr || !vecType)
799 return failure();
800
801 xegpu::DistributeLayoutAttr layout =
802 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
803 if (!layout || !layout.isForWorkgroup())
804 return failure();
805
806 ArrayRef<int64_t> wgShape = vecType.getShape();
807 SmallVector<int64_t> sgShape;
808 int count;
809 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
810
811 auto newType = VectorType::get(sgShape, vecType.getElementType());
812 Location loc = op.getLoc();
813 auto eltType = vecType.getElementType();
814
815 if (vecAttr.isSplat()) {
816 // Splat: single value for all subgroups
817 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
818 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
819 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
820 rewriter.replaceOp(op, cstOp);
821 return success();
822 } else if (sgShape == wgShape) { // if the entire vector is shared by all
823 // subgroups, don't distribute
824 auto newConstOp =
825 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
826 rewriter.replaceOp(op, newConstOp);
827 return success();
828 } else {
829 // Non-splat constant
830 // Only supports 1D & 2D
831 // TODO: support other cases that require SLM access
832 if (!eltType.isIndex())
833 return rewriter.notifyMatchFailure(
834 op, "Unsupported element type for non-splat constant op.");
835
836 if (wgShape.size() > 2)
837 return rewriter.notifyMatchFailure(
838 op, "Only 1D & 2D vector constant supported");
839
840 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
841 int64_t rowStride = 0, colStride = 0;
842 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
843 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
844
845 // Compute colStride and rowStride, and check for constant strides.
846 if (cols > 1) {
847 colStride = cast<IntegerAttr>(values[1]).getInt() -
848 cast<IntegerAttr>(values[0]).getInt();
849 }
850 if (rows > 1) {
851 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
852 cast<IntegerAttr>(values[0]).getInt();
853 }
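      // Worked example (illustrative): for the row-major constant
      //   dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : vector<2x4xindex>
      // colStride = 1 and rowStride = 4. With sg_data = [1, 2] the base tile
      // below is dense<[[0, 1]]>, and a subgroup whose distributed offsets
      // are [r, c] adds r * rowStride + c * colStride to every element of it.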
854
855 for (int64_t r = 0; r < rows; ++r) {
856 for (int64_t c = 0; c < cols; ++c) {
857 int64_t idx = r * cols + c;
858 // Check column stride
859 if (c > 0 && cols > 1) {
860 int64_t prevIdx = r * cols + (c - 1);
861 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
862 cast<IntegerAttr>(values[prevIdx]).getInt();
863 if (diff != colStride)
864 return rewriter.notifyMatchFailure(
865 op, "Non-constant column stride in constant op.");
866 }
867 // Check row stride
868 if (r > 0 && rows > 1) {
869 int64_t prevIdx = (r - 1) * cols + c;
870 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
871 cast<IntegerAttr>(values[prevIdx]).getInt();
872 if (diff != rowStride)
873 return rewriter.notifyMatchFailure(
874 op, "Non-constant row stride in constant op.");
875 }
876 }
877 }
878
879 // Create a constant for the base tile.
880 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
881 // For 1D case, extract the first sgShape[0] elements.
882 SmallVector<Attribute> baseTileValues;
883 int baseTileCols = sgShape[sgShape.size() - 1];
884 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
885 for (int64_t r = 0; r < baseTileRows; ++r) {
886 for (int64_t c = 0; c < baseTileCols; ++c) {
887 baseTileValues.push_back(values[r * cols + c]);
888 }
889 }
890
891 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
892 baseTileValues);
893 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
894
895 // Get subgroup id
896 Value sgId =
897 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
898 auto sgOffsets =
899 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
900 if (failed(sgOffsets))
901 return failure();
902
903 SmallVector<Value, 2> strideConsts;
904 strideConsts.push_back(
905 arith::ConstantIndexOp::create(rewriter, loc, colStride));
906 if (rows > 1)
907 strideConsts.insert(
908 strideConsts.begin(),
909 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
910
911 SmallVector<Value> newConstOps;
912 for (auto offsets : *sgOffsets) {
913 // Multiply offset with stride, broadcast it and add to baseConstVec
914 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
915 for (size_t i = 0; i < strideConsts.size(); ++i) {
916 Value mul =
917 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
918 offsets[i], strideConsts[i]);
919 mulOffset = arith::AddIOp::create(
920 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
921 }
922 // Broadcast to baseConstVec size
923 auto bcastOffset = vector::BroadcastOp::create(
924 rewriter, loc, baseConstVec.getType(), mulOffset);
925 auto finalConst =
926 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
927 newConstOps.push_back(finalConst);
928 }
929 rewriter.replaceOpWithMultiple(op, {newConstOps});
930 return success();
931 }
932 }
933};
934
935// This pattern transforms the LoadGatherOp with explicit offsets to load
936// subgroup data
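// For example (illustrative): a gather producing vector<256xf32> with
// sg_layout = [8] and sg_data = [32] consumes already-distributed
// vector<32xindex> offsets and vector<32xi1> masks and emits one
// subgroup-level load_gather per (offsets, mask) pair.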
937struct WgToSgLoadGatherOpWithOffset
938 : public OpConversionPattern<xegpu::LoadGatherOp> {
939 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
940 LogicalResult
941 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter) const override {
943
944 if (!op.getOffsets())
945 return failure();
946
947 Location loc = op.getLoc();
948 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
949 if (!resultType)
950 return failure();
951 ArrayRef<int64_t> wgShape = resultType.getShape();
952
953 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
954
955 if (!layout || !layout.isForWorkgroup())
956 return failure();
957
958 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
959
960 // The offsets need to be distributed
961 auto offsetsVecType =
962 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
963 auto maskVecType =
964 dyn_cast<VectorType>(adaptor.getMask().front().getType());
965 if (!offsetsVecType || !maskVecType ||
966 offsetsVecType.getShape() != maskVecType.getShape()) {
967 return rewriter.notifyMatchFailure(op,
968 "offsets have not been distributed");
969 }
970
971 SmallVector<Value> newLoadOps;
972 auto chunkSizeAttr =
973 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
974 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
975 for (auto [offsets, mask] :
976 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
977 auto newLayout = layout.dropSgLayoutAndData();
978 auto newLoadOp = xegpu::LoadGatherOp::create(
979 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
980 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
981 newLayout);
982 newLoadOps.push_back(newLoadOp);
983 }
984 rewriter.replaceOpWithMultiple(op, {newLoadOps});
985 return success();
986 }
987};
988
989// This pattern transforms the StoreScatterOp with explicit offsets to store
990// subgroup data
991struct WgToSgStoreScatterOpWithOffset
992 : public OpConversionPattern<xegpu::StoreScatterOp> {
993 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
994 LogicalResult
995 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
996 ConversionPatternRewriter &rewriter) const override {
997
998 if (!op.getOffsets())
999 return failure();
1000
1001 Location loc = op.getLoc();
1002 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1003 if (!valueType)
1004 return failure();
1005
1006 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1007
1008 if (!layout || !layout.isForWorkgroup())
1009 return failure();
1010
1011 // The offsets need to be distributed
1012 auto offsetsVecType =
1013 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1014 auto maskVecType =
1015 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1016 if (!offsetsVecType || !maskVecType ||
1017 offsetsVecType.getShape() != maskVecType.getShape()) {
1018 return rewriter.notifyMatchFailure(op,
1019 "offsets have not been distributed");
1020 }
1021
1022 auto chunkSizeOpt = op.getChunkSize();
1023 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
1024 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1025 for (auto [val, offs, mask] : llvm::zip(
1026 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1027 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1028 mask, chunkSizeAttr, op.getL1HintAttr(),
1029 op.getL2HintAttr(), op.getL3HintAttr(),
1030 layout.dropSgLayoutAndData());
1031 }
1032 rewriter.eraseOp(op);
1033 return success();
1034 }
1035};
1036
1037struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
1038 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1039 LogicalResult
1040 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1041 ConversionPatternRewriter &rewriter) const override {
1042
1043 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1044 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1045 return failure();
1046
1047 ArrayRef<int64_t> wgShape = op.getDataShape();
1048 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1049 assert(valueTy && "the value type must be vector type!");
1050 Type elemTy = valueTy.getElementType();
1051
1052 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1053 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1054 VectorType newResTy = VectorType::get(sgShape, elemTy);
1055 SmallVector<Value> newOps;
1056 for (auto offsets : offsetsList) {
1057 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1058 op.getMemDesc(), offsets,
1059 layout.dropSgLayoutAndData());
1060 newOps.push_back(newOp);
1061 }
1062 rewriter.replaceOpWithMultiple(op, {newOps});
1063
1064 return success();
1065 }
1066};
1067
1068struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1069 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1070 LogicalResult
1071 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter) const override {
1073
1074 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1075 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1076 return failure();
1077
1078 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1079 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1080 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1081 offsets, layout.dropSgLayoutAndData());
1082 rewriter.eraseOp(op);
1083 return success();
1084 }
1085};
1086
1087// This pattern distributes the vector.step ops to work at subgroup level
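// For example (illustrative): a workgroup-level vector.step yielding
// vector<128xindex> with sg_layout = [4] and sg_data = [32] becomes a
// subgroup-level vector.step yielding vector<32xindex>, to which each
// subgroup's starting offset (0, 32, 64 or 96) is broadcast and added.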
1088struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1089 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1090 LogicalResult
1091 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter) const override {
1093 xegpu::DistributeLayoutAttr layout =
1094 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1095 if (!layout || !layout.isForWorkgroup())
1096 return failure();
1097
1098 Location loc = op.getLoc();
1099 VectorType type = op.getResult().getType();
1100 auto wgShape = type.getShape();
1101 std::optional<SmallVector<int64_t>> sgShape =
1102 getSgShapeAndCount(wgShape, layout).first;
1103 if (!sgShape)
1104 return failure();
1105
1106 Value sgId =
1107 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1108 auto sgOffsets =
1109 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1110 if (failed(sgOffsets))
1111 return failure();
1112
1113 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1114 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1115 SmallVector<Value> newOps;
1116 for (auto offsets : *sgOffsets) {
1117 // Broadcast the offset scalar to a vector & add to the base steps
1118 auto bcastOffset =
1119 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1120 auto finalSteps =
1121 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1122 newOps.push_back(finalSteps);
1123 }
1124
1125 rewriter.replaceOpWithMultiple(op, {newOps});
1126 return success();
1127 }
1128};
1129
1130// This pattern transforms vector.shape_cast ops to work at subgroup level.
1131struct WgToSgVectorShapeCastOp
1132 : public OpConversionPattern<vector::ShapeCastOp> {
1133 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1134
1135 LogicalResult
1136 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1137 ConversionPatternRewriter &rewriter) const override {
1138
1139 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1140 if (!resultType)
1141 return failure();
1142
1143 ArrayRef<int64_t> wgShape = resultType.getShape();
1144 xegpu::DistributeLayoutAttr layout =
1145 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1146 if (!layout || !layout.isForWorkgroup())
1147 return failure();
1148
1149 // Check that srcShape and destShape, if they differ, only differ by
1150    // the expansion of unit dimensions.
1151 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1152 if (!srcType)
1153 return failure();
1154
1155 ArrayRef<int64_t> srcShape = srcType.getShape();
1156
1157 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1158 SmallVector<int64_t> expandedUnitDims;
1159 if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
1160 xegpu::DistributeLayoutAttr sourceLayout =
1161 xegpu::getTemporaryLayout(op->getOpOperand(0));
1162
1163 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1164 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1165 return isa<vector::BroadcastOp>(user);
1166 });
1167 };
1168
1169 if (!usedByBroadcastOp(op))
1170 return rewriter.notifyMatchFailure(
1171 op, "ShapeCast ops that expand unit dimensions and are used by "
1172 "non-broadcast operations are not supported.");
1173
1174 if (!sourceLayout.isSliceOf(layout))
1175 return rewriter.notifyMatchFailure(
1176 op, "The ShapeCast op only expands dimensions, the result layout "
1177 "must be a slice of the input layout, or vice versa.");
1178 layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
1179 layoutToDistribute =
1180 layoutToDistribute.setUnitDimLayout(expandedUnitDims);
1181 }
1182
1183 SmallVector<int64_t> sgShape =
1184 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1185 VectorType newResultType =
1186 VectorType::get(sgShape, resultType.getElementType());
1187
1188 SmallVector<Value> newShapeCastOps;
1189 for (auto src : adaptor.getSource()) {
1190 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1191 newResultType, src);
1192 newShapeCastOps.push_back(newShapeCast.getResult());
1193 }
1194
1195 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1196 return success();
1197 }
1198};
1199
1200static Value createAccumulator(ConversionPatternRewriter &rewriter,
1201 Location loc, VectorType type,
1202 vector::CombiningKind kind) {
1203 Type elemTy = type.getElementType();
1204
1205 switch (kind) {
1206 case vector::CombiningKind::ADD:
1207 case vector::CombiningKind::XOR:
1208 case vector::CombiningKind::OR:
1209 return arith::ConstantOp::create(
1210 rewriter, loc, type,
1211 DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
1212
1213 case vector::CombiningKind::MUL:
1214 case vector::CombiningKind::AND:
1215 return arith::ConstantOp::create(
1216 rewriter, loc, type,
1217 DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy)));
1218
1219 case vector::CombiningKind::MINSI:
1220 // Use max signed int value for signed integer min
1221 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1222 auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
1223 return arith::ConstantOp::create(
1224 rewriter, loc, type,
1225          DenseElementsAttr::get(type,
1226                                 rewriter.getIntegerAttr(elemTy, maxVal)));
1227 }
1228 return nullptr;
1229
1230 case vector::CombiningKind::MINUI:
1231 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1232 auto maxVal = APInt::getMaxValue(intTy.getWidth());
1233 return arith::ConstantOp::create(
1234 rewriter, loc, type,
1235          DenseElementsAttr::get(type,
1236                                 rewriter.getIntegerAttr(elemTy, maxVal)));
1237 }
1238 return nullptr;
1239
1240 case vector::CombiningKind::MAXSI:
1241 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1242 auto minVal = APInt::getSignedMinValue(intTy.getWidth());
1243 return arith::ConstantOp::create(
1244 rewriter, loc, type,
1245          DenseElementsAttr::get(type,
1246                                 rewriter.getIntegerAttr(elemTy, minVal)));
1247 }
1248 return nullptr;
1249
1250 case vector::CombiningKind::MAXUI:
1251 return arith::ConstantOp::create(
1252 rewriter, loc, type,
1253 DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
1254
1255 case vector::CombiningKind::MINNUMF:
1256 case vector::CombiningKind::MINIMUMF:
1257 // Use +infinity for float min operations
1258 if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
1259 auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
1260 return arith::ConstantOp::create(
1261 rewriter, loc, type,
1262 DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf)));
1263 }
1264 return nullptr;
1265
1266 case vector::CombiningKind::MAXNUMF:
1267 case vector::CombiningKind::MAXIMUMF:
1268 // Use -infinity for float max operations
1269 if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
1270 auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
1271 return arith::ConstantOp::create(
1272 rewriter, loc, type,
1273 DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf)));
1274 }
1275 return nullptr;
1276 }
1277 return nullptr;
1278}
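// For instance (illustrative), an ADD reduction over vector<16xf32> gets a
// dense<0.0> accumulator, while MINNUMF/MINIMUMF get dense<+inf>; seeding the
// local and SLM reduction steps with these neutral values keeps the caller's
// accumulator from being counted more than once.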
1279
1280/// This function converts multi-dimensional subgroup indices into a single
1281/// linear offset. It's used to calculate memory offsets in SLM for
1282/// cross-subgroup reduction coordination.
1283///
1284/// Parameters:
1285/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
1286/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
1287/// z dims)
1288/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
1289/// 4x8x2 subgroups)
1290///
1291/// It uses row-major linearization formula:
1292/// offset = sum(sgIds[dim] * stride[dim])
1293/// where stride[dim] = product of all sgLayout sizes in dimensions after
1294/// 'dim'
1295///
1296/// Example:
1297/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
1298/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
1299/// - Calculation:
1300/// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
1301/// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
1302/// * linearizedOffset = 1 + 4 = 5
1303///
1304/// This gives us a unique linear index for each combination of subgroup
1305/// positions in the specified dimensions, which is used for SLM row/column
1306/// addressing.
1307static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
1308 Location loc, ArrayRef<Value> sgIds,
1309 ArrayRef<int64_t> dims,
1310 ArrayRef<int64_t> sgLayout) {
1311 Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
1312 int64_t stride = 1;
1313
1314 for (int64_t dim : dims) {
1315 Value dimVal = sgIds[dim];
1316 Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
1317 Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
1318 linearizedOffset =
1319 arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
1320 stride *= sgLayout[dim];
1321 }
1322
1323 return linearizedOffset;
1324}
1325
1326/// This pattern transforms vector.multi_dim_reduction operations from
1327/// workgroup-level to subgroup-level execution with support for multiple
1328/// reduction dimensions.
1329///
1330/// Steps include:
1331/// 1. LOCAL REDUCTION :
1332/// - Each subgroup performs local reduction on its data slice
1333/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1334/// phase
1335///
1336/// 2. CROSS-SUBGROUP :
1337/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1338/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1339/// - If not needed, adds original accumulator and returns local results
1340///
1341/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1342/// a) SLM Layout Design:
1343/// - Rows: subgroups participating in reduction (product of sg_layout in
1344/// reduction dims)
1345/// - Cols: total result elements across non-reduction dimensions
1346///
1347/// b) Store Phase:
1348/// - Each subgroup stores its local reduction result to SLM
1349/// - Row offset: linearized index of subgroup in reduction dimensions
1350/// - Col offset: linearized index of subgroup in non-reduction dimensions
1351///
1352/// c) Load and Final Reduction Phase:
1353/// - Each subgroup loads a column of data (all reduction participants for
1354/// its position)
1355/// - Performs final reduction along the loaded dimension
1356/// - Adds original accumulator to get final result
1357///
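/// Worked example (illustrative): reducing dim 0 of a 256x128 source with
/// sg_layout = [8, 4] and sg_data = [32, 32]. Each subgroup first reduces its
/// 32x32 slice to 32 partial results. Since sg_layout[0] = 8 > 1 and
/// sg_data[0] = 32 < 256, a cross-subgroup phase is required: SLM is shaped
/// [8, 4*32] = [8, 128], each subgroup stores its 32 partials at
/// (row = reduction-dim sg index, col = 32 * non-reduction-dim sg index),
/// and after a barrier each subgroup reloads an 8x32 block, reduces it along
/// dim 0, and finally adds the original accumulator.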
1358struct WgToSgMultiDimReductionOp
1359 : public OpConversionPattern<vector::MultiDimReductionOp> {
1360 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1361
1362 LogicalResult
1363 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1364 ConversionPatternRewriter &rewriter) const override {
1365 Location loc = op.getLoc();
1366
1367 VectorType srcType = op.getSourceVectorType();
1368 VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1369 if (!dstType)
1370 return failure();
1371
1372 auto originalSrcShape = srcType.getShape();
1373 xegpu::DistributeLayoutAttr layout =
1374 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1375 if (!layout || !layout.isForWorkgroup())
1376 return failure();
1377
1378 auto reductionDims = llvm::to_vector(op.getReductionDims());
1379
1380 // Get sg_layout and sg_data from the parent layout
1381 SmallVector<int64_t> sgLayout;
1382 SmallVector<int64_t> sgData;
1383 if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1384 sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
1385 sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
1386 } else
1387 return rewriter.notifyMatchFailure(
1388 op, "Reduction should have SliceAttr layout");
1389
1390 Type elemTy = dstType.getElementType();
1391
1392 // Step 1: perform local subgroup reductions with ZERO accumulator
1393 SmallVector<Value> localReductions;
1394 SmallVector<int64_t> sgShape =
1395 getSgShapeAndCount(originalSrcShape, layout).first;
1396 VectorType newDstType = VectorType::get(sgShape, elemTy);
1397 for (auto sgSrc : adaptor.getSource()) {
1398 // Create ZERO accumulator for local reduction
1399 auto neutralLocalAcc =
1400 createAccumulator(rewriter, loc, newDstType, op.getKind());
1401 // Local reduction with ZERO accumulator
1402 auto localReduce = vector::MultiDimReductionOp::create(
1403 rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
1404 reductionDims);
1405 localReductions.push_back(localReduce.getResult());
1406 }
1407
1408 // Check if cross-subgroup reduction is needed for any reduction dimension
1409 SmallVector<int64_t> crossSgReductionDims;
1410 for (int64_t reductionDim : reductionDims) {
1411 bool needsCrossSubgroupReduction =
1412 (sgLayout[reductionDim] > 1) &&
1413 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1414
1415 if (needsCrossSubgroupReduction) {
1416 crossSgReductionDims.push_back(reductionDim);
1417 }
1418 }
1419
1420 // If no cross-subgroup reduction needed, add accumulator and return
1421 if (crossSgReductionDims.empty()) {
1422 SmallVector<Value> results;
1423 for (auto localResult : localReductions) {
1424 auto finalResult = vector::makeArithReduction(
1425 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1426 results.push_back(finalResult);
1427 }
1428 rewriter.replaceOpWithMultiple(op, {results});
1429 return success();
1430 }
1431
1432 // Step 2: cross-subgroup reduction using SLM
1433
1434 // Calculate total elements in local result
1435 int64_t localElements = computeProduct(sgShape);
1436
1437 // Shape cast for SLM storage - store as [1, localElements]
1438 SmallVector<int64_t> storeShape2D = {1, localElements};
1439 VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
1440 auto storeShapeCast = vector::ShapeCastOp::create(
1441 rewriter, loc, storeType2D, localReductions[0]);
1442 Value storeData = storeShapeCast.getResult();
1443
1444 // Calculate SLM shape - rows for sg's in reduction dims, cols for total
1445 // result elements across all subgroups in non-reduction dimensions
1446 int64_t totalReductionSubgroups = 1;
1447 for (int64_t dim : crossSgReductionDims) {
1448 totalReductionSubgroups *= sgLayout[dim];
1449 }
1450
1451 // Total result elements across all subgroups in non-reduction dimensions
1452 int64_t totalResultElements =
1453 localElements * computeProduct(sgLayout) / totalReductionSubgroups;
1454
1455 SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
1456 totalResultElements};
1457
1458    // Step 3: Allocate SLM
1459 auto bitWidth = elemTy.getIntOrFloatBitWidth();
1460 auto bytesPerElement = bitWidth / 8;
1461 int64_t slmElements = slmShape2D[0] * slmShape2D[1];
1462 auto slmSize = slmElements * bytesPerElement;
1463 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1464 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1465
1466 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
1467 slmShape2D, elemTy, nullptr);
1468 auto memDesc =
1469 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1470
1471 // Step 4: Store local results to SLM
1472 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1473 rewriter.getIndexType(), /*upper_bound=*/nullptr);
1474
1475 // Convert sgLayout to Values for delinearizeIndex
1476 SmallVector<Value> sgLayoutValues;
1477 for (int64_t dim : sgLayout)
1478 sgLayoutValues.push_back(
1479 arith::ConstantIndexOp::create(rewriter, loc, dim));
1480
1481 auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
1482 sgLayoutValues);
1483 if (failed(sgIdsResult))
1484 return failure();
1485 SmallVector<Value> sgIds = *sgIdsResult;
1486
1487 // Row offset: linearize reduction dimension indices
1488 Value rowOffsetStore = linearizeSubgroupIndices(
1489 rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
1490
1491 // Column offset: linearize non-reduction dimension indices
1492 SmallVector<int64_t> nonReductionDims;
1493 for (size_t i = 0; i < sgLayout.size(); ++i) {
1494 if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
1495 nonReductionDims.push_back(static_cast<int64_t>(i));
1496 }
1497 }
1498
1499 Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
1500 nonReductionDims, sgLayout);
1501
1502 Value localElementsVal =
1503 arith::ConstantIndexOp::create(rewriter, loc, localElements);
1504 colOffset =
1505 arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
1506
1507 SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
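// Continuing the hypothetical example (sgLayout = [4, 8], reduction over
// dim 0, localElements = 32): the subgroup with delinearized id (2, 5)
// stores its partial result at row 2 and at column 5 * 32 = 160.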
1508
1509 xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
1510 storeOffsets2D, /*layout=*/nullptr);
1511
1512 gpu::BarrierOp::create(rewriter, loc);
1513
1514 // Step 5: Load from SLM for final reduction
1515 SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
1516 VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
1517
1518 // Load offsets: each subgroup loads the whole column block matching its
1519 // position in the non-reduction dimensions, starting at row 0.
1520 Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
1521
1522 SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};
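// Continuing the hypothetical example: every subgroup whose column offset is
// 160 loads the same [4 x 32] block, i.e. rows 0..3 of columns 160..191,
// which are exactly the four partial results it must combine.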
1523
1524 auto loadOp = xegpu::LoadMatrixOp::create(
1525 rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
1526 /*layout=*/nullptr);
1527
1528 // Step 6: Perform final reduction with ZERO accumulator
1529 SmallVector<int64_t> finalReductionDims = {0};
1530 SmallVector<int64_t> finalResultShape = {localElements};
1531 VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
1532
1533 auto neutralFinalAcc =
1534 createAccumulator(rewriter, loc, finalResultType, op.getKind());
1535
1536 auto finalReduce = vector::MultiDimReductionOp::create(
1537 rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
1538 neutralFinalAcc, finalReductionDims);
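// Continuing the hypothetical example: reducing the loaded vector<4x32xf32>
// over dim 0 yields vector<32xf32>, one combined value per element of the
// local partial result.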
1539
1540 // Step 7: Add the original accumulator at the end
1541 Value originalAcc = adaptor.getAcc()[0];
1542 Value accToAdd = originalAcc;
1543
1544 // Handle shape mismatch by shape casting
1545 if (originalAcc.getType() != finalReduce.getResult().getType()) {
1546 auto originalAccType = cast<VectorType>(originalAcc.getType());
1547 auto finalResultType =
1548 cast<VectorType>(finalReduce.getResult().getType());
1549
1550 // If they have the same number of elements, just shape cast
1551 if (originalAccType.getNumElements() ==
1552 finalResultType.getNumElements()) {
1553 auto shapeCast = vector::ShapeCastOp::create(
1554 rewriter, loc, finalResultType, originalAcc);
1555 accToAdd = shapeCast.getResult();
1556 }
1557 }
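// E.g. (hypothetical types), an original accumulator of type vector<1x32xf32>
// is shape-cast to vector<32xf32> so it can be combined with the final
// reduction result below.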
1558
1559 auto finalResult = vector::makeArithReduction(
1560 rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);
1561
1562 rewriter.replaceOp(op, finalResult);
1563 return success();
1564 }
1565};
1566
1567// This pattern transforms vector.transpose ops to work at subgroup level.
1568struct WgToSgVectorTransposeOp
1569 : public OpConversionPattern<vector::TransposeOp> {
1570 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1571
1572 LogicalResult
1573 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1574 ConversionPatternRewriter &rewriter) const override {
1575 VectorType resultType = op.getResultVectorType();
1576
1577 ArrayRef<int64_t> wgShape = resultType.getShape();
1578 xegpu::DistributeLayoutAttr layout =
1579 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1580 if (!layout || !layout.isForWorkgroup())
1581 return failure();
1582 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1583 xegpu::DistributeLayoutAttr sourceLayout =
1584 xegpu::getDistributeLayoutAttr(op.getVector());
1585 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1586 return failure();
1587
1588 SmallVector<int64_t> sourceSgLayout =
1589 sourceLayout.getEffectiveSgLayoutAsInt();
1590 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1591 DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
1592 DenseI32ArrayAttr resultOrder = layout.getOrder();
1593
1594 if (!sourceOrder || !resultOrder) {
1595 return rewriter.notifyMatchFailure(
1596 op, "Both source and result must have order attributes");
1597 }
1598
1599 ArrayRef<int64_t> permutation = op.getPermutation();
1600 size_t permutationSize = permutation.size();
1601 if (sourceSgLayout.size() != permutationSize ||
1602 resultSgLayout.size() != permutationSize) {
1603 return rewriter.notifyMatchFailure(
1604 op, "Layouts and permutation must have the same rank");
1605 }
1606
1607 // Check that sgLayout, sgData & order are properly transposed for source
1608 // and result
1609 if (!layout.isTransposeOf(sourceLayout, permutation))
1610 return rewriter.notifyMatchFailure(
1611 op, "Result layout is not a valid transpose of source layout "
1612 "according to permutation");
1613
1614 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1615 VectorType newResultType =
1616 VectorType::get(sgShape, resultType.getElementType());
1617 SmallVector<Value> newTransposeOps;
1618 for (auto src : adaptor.getVector()) {
1619 auto newTranspose = vector::TransposeOp::create(
1620 rewriter, op.getLoc(), newResultType, src, permutation);
1621 newTransposeOps.push_back(newTranspose.getResult());
1622 }
1623
1624 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1625 return success();
1626 }
1627};
1628
1629// Distribute vector mask ops to work at subgroup level.
1630template <typename MaskOpType>
1631struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1632 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1633
1634 LogicalResult matchAndRewrite(
1635 MaskOpType op,
1636 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1637 ConversionPatternRewriter &rewriter) const override {
1638 xegpu::DistributeLayoutAttr layout =
1639 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1640 if (!layout || !layout.isForWorkgroup())
1641 return failure();
1642
1643 Location loc = op.getLoc();
1644 VectorType type = op.getResult().getType();
1645 auto wgShape = type.getShape();
1646
1647 SmallVector<Value> wgMaskDimSizes;
1648 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1649 for (int64_t maskSize : op.getMaskDimSizes()) {
1650 wgMaskDimSizes.push_back(
1651 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1652 }
1653 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1654 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1655 }
1656
1657 Value sgId =
1658 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1659 auto sgOffsets =
1660 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1661 if (failed(sgOffsets))
1662 return failure();
1663
1664 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1665 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1666
1667 // In each dimension, each subgroup computes its local mask size as:
1668 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
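// Illustrative example (hypothetical values): with wgMaskDimSize[d] = 100
// and sgShape[d] = 32, a subgroup at offset 64 gets min(max(100 - 64, 0), 32)
// = 32 (fully active), at offset 96 it gets 4, and at offset 128 it gets 0.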
1669 SmallVector<Value> newCreateMaskOps;
1670 for (auto offsetSet : *sgOffsets) {
1671 SmallVector<Value> maskOperands;
1672
1673 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1674 Value dimSizeVal =
1675 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1676 Value offset = offsetSet[i];
1677 Value adjustedMaskSize =
1678 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1679 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1680 Value nonNegative =
1681 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1682 Value sgMaskSize =
1683 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1684 maskOperands.push_back(sgMaskSize);
1685 }
1686
1687 auto newCreateMaskOp =
1688 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1689 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1690 }
1691
1692 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1693 return success();
1694 }
1695};
1696
1697using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1698using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1699} // namespace
1700
1701namespace mlir {
1702namespace xegpu {
1703 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1704 patterns
1705 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1706 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1707 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1708 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1709 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1710 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1711 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1712 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1713 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1714 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1715 patterns.getContext());
1716}
1717} // namespace xegpu
1718} // namespace mlir
1719
1720namespace {
1721struct XeGPUWgToSgDistributePass
1722 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1723 void runOnOperation() override;
1724};
1725} // namespace
1726
1727void XeGPUWgToSgDistributePass::runOnOperation() {
1728
1729 Operation *op = getOperation();
1730 if (!xegpu::recoverTemporaryLayouts(op)) {
1731 signalPassFailure();
1732 return;
1733 }
1734
1735 // Track existing UnrealizedConversionCastOps
1736 SmallVector<Operation *> existingCastOps;
1737 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1738 existingCastOps.push_back(castOp.getOperation());
1739 });
1740
1741 {
1742 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1743 // VectorType operands. This first converts such operands to
1744 // RankedTensorType, propagates the layout attribute into the encoding
1745 // attribute, and finally converts the RankedTensorType to VectorType based
1746 // on the encoding.
1747
1748 TypeConverter converter;
1749 converter.addConversion([&](Type type) -> Type { return type; });
1750 converter.addConversion(
1751 [&](RankedTensorType type,
1752 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1753 Type elemTy = type.getElementType();
1754 ArrayRef<int64_t> shape = type.getShape();
1755
1756 int count;
1757 SmallVector<int64_t> subShape;
1758 std::tie(subShape, count) = getSgShapeAndCount(
1759 shape,
1760 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1761
1762 auto newTy = VectorType::get(subShape, elemTy);
1763 result.append(count, newTy);
1764 return success();
1765 });
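// Illustrative example (hypothetical layout): a tensor<256x128xf32> whose
// encoding carries sg_layout = [8, 4] and sg_data = [32, 32] converts to
// vector<32x32xf32>; when each subgroup covers several distribution units,
// the type is expanded into multiple such vectors (count > 1).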
1766
1767 xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1768 converter);
1769 }
1770
1771 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1772 // as well as XeGPU, Arith, and Vector operations.
1773 MLIRContext *ctx = &getContext();
1774 RewritePatternSet patterns(ctx);
1775 ConversionTarget target(*ctx);
1776 TypeConverter converter;
1777 converter.addConversion([&](Type type) -> Type { return type; });
1778 converter.addConversion(
1779 [&](xegpu::TensorDescType type,
1780 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1781 Type elemTy = type.getElementType();
1782 ArrayRef<int64_t> shape = type.getShape();
1783
1784 int count;
1785 SmallVector<int64_t> subShape;
1786 xegpu::LayoutAttr layout = type.getLayoutAttr();
1787 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1788
1789 if (layout)
1790 layout = layout.dropSgLayoutAndData();
1791
1792 auto newTy = xegpu::TensorDescType::get(
1793 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1794 result.append(count, newTy);
1795 return success();
1796 });
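// Illustrative example (hypothetical layout): a
// !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 2],
// sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
// becomes a !xegpu.tensor_desc<32x32xf32, ...> that keeps only the
// subgroup-internal (lane-level) layout fields.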
1797
1798 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1799 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1800 return createOp.getType();
1801 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1802 return loadOp.getTensorDescType();
1803 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1804 return storeOp.getTensorDescType();
1805 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1806 return updateOp.getType();
1807 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1808 return prefetchOp.getTensorDescType();
1809 return xegpu::TensorDescType();
1810 };
1811
1812 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1813 return !layout || !layout.isForWorkgroup();
1814 };
1815
1816 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1817 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1818 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1819 auto tdescTy = getTensorDescType(op);
1820 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1821 return isLegal(layout);
1822 });
1823
1824 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1825 auto layout = op.getLayoutCdAttr();
1826 return isLegal(layout);
1827 });
1828
1829 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1830 [=](xegpu::LoadMatrixOp op) -> bool {
1831 return isLegal(op.getLayoutAttr());
1832 });
1833
1834 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1835 [=](xegpu::StoreMatrixOp op) -> bool {
1836 return isLegal(op.getLayoutAttr());
1837 });
1838
1839 target.addDynamicallyLegalOp<arith::ConstantOp>(
1840 [=](arith::ConstantOp op) -> bool {
1841 auto vecType = dyn_cast<VectorType>(op.getType());
1842 if (!vecType)
1843 return true;
1844
1845 auto layout =
1846 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1847 return isLegal(layout);
1848 });
1849
1850 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1851 vector::TransposeOp, vector::BroadcastOp,
1852 vector::MultiDimReductionOp,
1853 vector::ConstantMaskOp, vector::CreateMaskOp>(
1854 [=](Operation *op) -> bool {
1855 // Check for either a SliceAttr or LayoutAttr on the result.
1856 auto layout =
1857 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1858 return isLegal(layout);
1859 });
1860
1861 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1862 [=](xegpu::LoadGatherOp op) -> bool {
1863 auto layout = op.getLayoutAttr();
1864 return isLegal(layout);
1865 });
1866
1867 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1868 [=](xegpu::StoreScatterOp op) -> bool {
1869 auto layout = op.getLayoutAttr();
1870 return isLegal(layout);
1871 });
1872
1873 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1874 [=](xegpu::ConvertLayoutOp op) -> bool {
1875 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1876 });
1877
1878 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1879 [=](Operation *op) -> std::optional<bool> {
1880 // Only handle elementwise mappable ops
1881 if (!OpTrait::hasElementwiseMappableTraits(op))
1882 return true;
1883
1884 VectorType resultType =
1885 dyn_cast<VectorType>(op->getResult(0).getType());
1886 if (!resultType)
1887 return true;
1888
1889 // Check if all operands are vectors of the same shape
1890 // TODO: Support other types.
1891 for (Value operand : op->getOperands()) {
1892 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1893 if (!operandType || operandType.getShape() != resultType.getShape()) {
1894 return true;
1895 }
1896 }
1897
1898 xegpu::DistributeLayoutAttr layout =
1899 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1900 return isLegal(layout);
1901 });
1902
1903 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1904 [=](UnrealizedConversionCastOp op) {
1905 return llvm::is_contained(existingCastOps, op.getOperation());
1906 });
1907
1908 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1909
1910 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1911 target);
1912 xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1913 if (failed(
1914 applyPartialConversion(getOperation(), target, std::move(patterns))))
1915 return signalPassFailure();
1916
1917 // Remove layout attributes from SCF ops
1918 getOperation()->walk([](Operation *op) {
1919 if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
1920 return;
1921
1922 SmallVector<StringAttr> attrsToRemove;
1923 for (auto namedAttr : op->getDiscardableAttrs()) {
1924 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1925 attrsToRemove.push_back(namedAttr.getName());
1926 }
1927 for (auto attrName : attrsToRemove)
1928 op->removeDiscardableAttr(attrName);
1929 });
1930}
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:486
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:472
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.