XeGPUWgToSgDistribute.cpp
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
24#include <optional>
25
26namespace mlir {
27namespace xegpu {
28#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30} // namespace xegpu
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36
37// Retrieve the RangeAttr if it is specified.
38static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39 Operation *parent = op->getParentOfType<scf::IfOp>();
40 while (parent) {
41 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42 parent->getAttr("sg_id_range")))
43 return attr;
44 parent = parent->getParentOfType<scf::IfOp>();
45 }
46 return {};
47}
48
49static std::pair<SmallVector<int64_t>, int>
50getSgShapeAndCount(ArrayRef<int64_t> shape,
51 xegpu::DistributeLayoutAttr layout) {
52 int count = 1;
53 SmallVector<int64_t> sgShape = llvm::to_vector(shape);
54 if (layout && layout.isForWorkgroup()) {
55 SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
56 if (!layout.getEffectiveSgDataAsInt().empty())
57 sgShape = layout.getEffectiveSgDataAsInt();
58 else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59 sgShape = *maybeDerivedSgData;
60 SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61 // Clamp distUnit to the original shape to handle cases where data is
62 // shared among subgroups, which may cause distUnit to exceed the original
63 // shape.
64 for (size_t i = 0; i < distUnit.size(); ++i)
65 distUnit[i] = std::min(shape[i], distUnit[i]);
66 count = computeProduct(shape) / computeProduct(distUnit);
67 }
68 return std::make_pair(sgShape, count);
69}
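// Example for getSgShapeAndCount: with shape = [24, 24], sg_layout = [4, 4]
// and sg_data = [2, 2], sgShape = [2, 2], distUnit = [8, 8] and
// count = (24 * 24) / (8 * 8) = 9, i.e. each subgroup handles 9 tiles in a
// round-robin fashion (see the 24x24 example below).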
70
71/// Utility helper for deriving the list of offsets for each sub-TensorDesc
72/// or sub-MemDesc accessed by the current subgroup (sgId), based on the
73/// associated distribute layout attribute, the shape, the subgroup id, and
74/// the original offsets of the op.
75template <
76 typename OpType,
77 typename = std::enable_if_t<llvm::is_one_of<
78 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80static LogicalResult
81genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82 SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83 Location loc = op.getLoc();
84 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85 // Not applicable to ops without offset operands.
86 if (origOffsets.empty())
87 return failure();
88
89 // For LoadMatrixOp/StoreMatrixOp use getLayoutAttr(); otherwise (e.g. CreateNdDescOp) use getDescLayoutAttr().
90 xegpu::DistributeLayoutAttr layout;
91 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
92 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
93 layout = op.getLayoutAttr();
94 } else {
95 layout = op.getDescLayoutAttr();
96 }
97
98 // not applicable to ops without workgroup layout attributes
99 if (!layout || !layout.isForWorkgroup())
100 return failure();
101
102 Value sgId =
103 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
104
105 // verify and adjust the sgId if the range specifier is present
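// For example, with sg_id_range = [4, 8) (i.e. 4 subgroups) and a matching
// sg_layout of 4 subgroups, a hardware subgroup id of 5 is remapped to
// 5 - 4 = 1 before the distributed coordinates are computed.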
106 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
107 if (sgIdRange) {
108 int64_t startOfRange = sgIdRange.getStart().getInt();
109 int64_t endOfRange = sgIdRange.getEnd().getInt();
110 // verify the RangeAttr against the layout attribute
111 if (layout.getNumSubgroups() != endOfRange - startOfRange)
112 return rewriter.notifyMatchFailure(
113 op, "sg_layout size must match the sg_id_range");
114 // adjust the sgId if necessary
115 if (startOfRange > 0) {
116 Value startOfRangeVal =
117 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
118 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
119 }
120 }
121
122 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
123 // descriptors to be accessed, based on the layout information.
124 ArrayRef<int64_t> wgShape = op.getDataShape();
125 auto maybeDescOffsets =
126 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
127 if (failed(maybeDescOffsets))
128 return failure();
129
130 // Compute the final global offsets for each accessed sub-tensor
131 // or sub-memory descriptor.
132 for (const auto &sgOffsets : *maybeDescOffsets) {
133 SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
134 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
135 offsetsList.push_back(std::move(newOffsets));
136 }
137
138 // callback(offsetsList);
139 return success();
140}
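// Example for genOffsetsList (assuming row-major subgroup ordering within
// sg_layout): for the 24x24 case below with sg_layout = [4, 4],
// sg_data = [2, 2] and original offsets [%off0, %off1], subgroup id 5 maps to
// coordinates (1, 1), i.e. a base offset of [2, 2]; round-robin over the
// [8, 8] distribution units then yields the nine offset pairs
// [%off0 + 2 + 8*i, %off1 + 2 + 8*j] for i, j in {0, 1, 2}.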
141
142/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
143/// from a workgroup descriptor. It replaces the offsets and sizes with
144/// appropriate values for the subgroup.
145/// It uses round-robin assignment to distribute the work to the subgroups.
146/// The following create_nd_desc operation:
147/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
148/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
149/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
150/// is converted to 9 subgroup level operations based on the sg_layout &
151/// sg_data:
152/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
153/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
154/// lane_data = [1, 1]>>
155///
156/// The sg_layout and sg_data attributes are dropped after the pass as they are
157/// no longer needed.
158///
159/// 24x24 matrix distribution example:
160/// sg_layout = [4, 4], sg_data = [2, 2]
161/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
162/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
163///
164/// +------------------------+
165/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
166/// |-----+-----+-----|
167/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
168/// |-----+-----+-----|
169/// | 8x8 | 8x8 | 8x8 |
170/// +------------------------+
171///
172/// Each 8x8 tile is further subdivided among subgroups:
173/// +------------------------+
174/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
175/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
176/// | 2x2 2x2 2x2 2x2 |
177/// | 2x2 2x2 2x2 2x2 |
178/// +------------------------+
179///
180/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
181/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
182
183/// The pass currently keeps the entire distribution logic in the
184/// WgToSgCreateNdOp pattern, and all the other ops simply follow.
185/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
186/// ops in the pass.
187struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
188 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
189
190 LogicalResult
191 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter) const override {
193 SmallVector<SmallVector<OpFoldResult>> offsetsList;
194 if (failed(genOffsetsList(rewriter, op, offsetsList)))
195 return failure();
196
197 MLIRContext *ctx = op.getContext();
198 xegpu::TensorDescType tdescTy = op.getType();
199 ArrayRef<int64_t> wgShape = tdescTy.getShape();
200 Type elemTy = tdescTy.getElementType();
201 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
202 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
203 auto newTdescTy =
204 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
205 layout.dropSgLayoutAndData());
206
207 SmallVector<Value> newOps;
208 for (auto offsets : offsetsList) {
209 auto newOp = xegpu::CreateNdDescOp::create(
210 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
211 op.getMixedSizes(), op.getMixedStrides());
212
213 newOps.push_back(newOp);
214 }
215 rewriter.replaceOpWithMultiple(op, {newOps});
216
217 return success();
218 }
219};
220
221// This pattern transforms the CreateNdDescOp without offsets to create a
222// subgroup descriptor from a workgroup descriptor
223struct WgToSgCreateNdOpNoOffset
224 : public OpConversionPattern<xegpu::CreateNdDescOp> {
225 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
226
227 LogicalResult
228 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override {
230
231 // Check no offsets are specified.
232 if (!op.getMixedOffsets().empty())
233 return failure();
234
235 Location loc = op.getLoc();
236 MLIRContext *ctx = op.getContext();
237 xegpu::TensorDescType tdescTy = op.getType();
238 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
239 if (!layout || !layout.isForWorkgroup())
240 return failure();
241
242 Type elemTy = tdescTy.getElementType();
243 ArrayRef<int64_t> wgShape = tdescTy.getShape();
244
245 SmallVector<int64_t> sgShape;
246 int count;
247 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
248 xegpu::TensorDescType newTdescTy =
249 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
250 layout.dropSgLayoutAndData());
251
252 SmallVector<Value> newCreateNdOps(count);
253 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
254 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
255 op.getSource(), op.getMixedSizes(),
256 op.getMixedStrides());
257 });
258
259 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
260 return success();
261 }
262};
263
264/// This pattern transforms the LoadNdOp to load subgroup data.
265struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
266 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
267 LogicalResult
268 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter) const override {
270 if (!op.getMixedOffsets().empty())
271 return failure();
272
273 SmallVector<Value> newLoadOps;
274 for (auto src : adaptor.getTensorDesc()) {
275 xegpu::TensorDescType tdescTy =
276 dyn_cast<xegpu::TensorDescType>(src.getType());
277 ArrayRef<int64_t> srcShape = tdescTy.getShape();
278 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
279 auto newLoadOp = xegpu::LoadNdOp::create(
280 rewriter, op.getLoc(), newResTy, src,
281 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
282 newLoadOps.push_back(newLoadOp);
283 }
284 rewriter.replaceOpWithMultiple(op, {newLoadOps});
285 return mlir::success();
286 }
287};
288
289/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
290/// It creates a StoreNdOp to store the updated values to the new subgroup
291/// src tensor descriptors.
292struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
293 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
294 LogicalResult
295 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override {
297 if (!op.getMixedOffsets().empty())
298 return failure();
299
300 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
301 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
302 op.getL2HintAttr(), op.getL3HintAttr());
303
304 rewriter.eraseOp(op);
305 return success();
306 }
307};
308
309// This pattern transforms the LoadNdOp with explicit offsets to load
310// subgroup data.
311struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
312 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
313 LogicalResult
314 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
315 ConversionPatternRewriter &rewriter) const override {
316
317 SmallVector<SmallVector<OpFoldResult>> offsetsList;
318 if (failed(genOffsetsList(rewriter, op, offsetsList)))
319 return failure();
320
321 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
322 if (layout)
323 layout = layout.dropSgLayoutAndData();
324 SmallVector<Value> newOps;
325 for (auto [tdesc, offsets] :
326 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
327 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
328 VectorType newResTy =
329 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
330 auto newOp = xegpu::LoadNdOp::create(
331 rewriter, op.getLoc(), newResTy, tdesc, offsets,
332 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
333 op.getL2HintAttr(), op.getL3HintAttr(), layout);
334 newOps.push_back(newOp);
335 }
336 rewriter.replaceOpWithMultiple(op, {newOps});
337
338 return success();
339 }
340};
341
342// This pattern transforms the StoreNdOp with explicit offsets to store
343// subgroup data.
344struct WgToSgStoreNdOpWithOffset
345 : public OpConversionPattern<xegpu::StoreNdOp> {
346 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
347 LogicalResult
348 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 SmallVector<SmallVector<OpFoldResult>> offsetsList;
351 if (failed(genOffsetsList(rewriter, op, offsetsList)))
352 return failure();
353
354 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
355 if (layout)
356 layout = layout.dropSgLayoutAndData();
357 for (auto [v, tdesc, offsets] :
358 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
359 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
360 op.getL1HintAttr(), op.getL2HintAttr(),
361 op.getL3HintAttr(), layout);
362 }
363 rewriter.eraseOp(op);
364
365 return success();
366 }
367};
368
369// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
370// subgroup data.
371struct WgToSgPrefetchNdOpWithOffset
372 : public OpConversionPattern<xegpu::PrefetchNdOp> {
373 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
374 LogicalResult
375 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
376 ConversionPatternRewriter &rewriter) const override {
377 SmallVector<SmallVector<OpFoldResult>> offsetsList;
378 if (failed(genOffsetsList(rewriter, op, offsetsList)))
379 return failure();
380
381 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
382 if (layout)
383 layout = layout.dropSgLayoutAndData();
384 for (auto [tdesc, offsets] :
385 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
386 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
387 op.getL1HintAttr(), op.getL2HintAttr(),
388 op.getL3HintAttr(), layout);
389 }
390 rewriter.eraseOp(op);
391
392 return success();
393 }
394};
395
396/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
397/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
398/// offsets of the new subgroup src tensor descriptors.
399struct WgToSgUpdateNdOffsetOp
400 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
401 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
402 LogicalResult
403 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter) const override {
405 llvm::SmallVector<Value> newUpdateTileOffsetOps;
406 for (auto tDesc : adaptor.getTensorDesc()) {
407 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
408 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
409 op.getConstOffsets());
410 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
411 }
412
413 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
414 return success();
415 }
416};
417
418/// This pattern transforms the DpasOp to work at subgroup level.
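/// For example, if the LHS is distributed into tiles A0, A1 and the RHS into
/// tiles B0, B1, one subgroup-level dpas is created per (A, B) pair
/// (A0xB0, A0xB1, A1xB0, A1xB1), each producing a vector<aRows x bCols>
/// result; accumulator tiles, when present, are consumed in the same order.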
419struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
420 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
421 LogicalResult
422 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
423 ConversionPatternRewriter &rewriter) const override {
424 Location loc = op.getLoc();
425 VectorType resultTy = op.getResult().getType();
426 if (resultTy.getRank() != 2)
427 return failure();
428
429 auto layoutCd = op.getLayoutCdAttr();
430 auto layoutA = op.getLayoutAAttr();
431 auto layoutB = op.getLayoutBAttr();
432 if (!layoutCd || !layoutA || !layoutB)
433 return failure();
434 size_t i = 0;
435 SmallVector<Value> newDpasOps;
436 for (auto aVec : adaptor.getLhs()) {
437 for (auto bVec : adaptor.getRhs()) {
438
439 llvm::SmallVector<Value> operands({aVec, bVec});
440 Value tmpC;
441 if (op.getAcc()) {
442 tmpC = adaptor.getAcc()[i++];
443 operands.push_back(tmpC);
444 }
445
446 ArrayRef<int64_t> aVecShape =
447 llvm::cast<VectorType>(aVec.getType()).getShape();
448 ArrayRef<int64_t> bVecShape =
449 llvm::cast<VectorType>(bVec.getType()).getShape();
450 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
451 resultTy.getElementType());
452 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
453 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
454 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
455 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
456
457 newDpasOps.push_back(newDpasOp);
458 }
459 }
460 rewriter.replaceOpWithMultiple(op, {newDpasOps});
461 return success();
462 }
463};
464
465/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
466struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
467 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
468 LogicalResult
469 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
470 ConversionPatternRewriter &rewriter) const override {
471
472 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
473 if ((offsetSize != 0) || op.getConstOffsetsAttr())
474 return failure();
475
476 for (auto src : adaptor.getTensorDesc())
477 xegpu::PrefetchNdOp::create(
478 rewriter, op.getLoc(), TypeRange(), src,
479 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
480 rewriter.eraseOp(op);
481 return success();
482 }
483};
484
485/// This pattern transforms vector.broadcast ops to work at subgroup level.
486struct WgToSgVectorBroadcastOp
487 : public OpConversionPattern<vector::BroadcastOp> {
488 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
489
490 LogicalResult
491 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
492 ConversionPatternRewriter &rewriter) const override {
493
494 VectorType resultType = op.getResult().getType();
495 ArrayRef<int64_t> wgShape = resultType.getShape();
496
497 xegpu::DistributeLayoutAttr layout =
498 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
499 if (!layout || !layout.isForWorkgroup())
500 return failure();
501
502 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
503 VectorType newResultType =
504 VectorType::get(sgShape, resultType.getElementType());
505
506 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
507 return failure();
508
509 SmallVector<Value> newBroadcastOps;
510 for (auto operand : adaptor.getOperands().front()) {
511 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
512 newResultType, operand);
513 xegpu::setTemporaryLayout(newBroadcast->getResult(0),
514 layout.dropSgLayoutAndData());
515
516 newBroadcastOps.push_back(newBroadcast.getResult());
517 }
518 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
519 return success();
520 }
521};
522
523// This pattern transforms elementwise ops to work at subgroup level.
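// For example, an arith.addf producing vector<256x128xf32> with
// sg_layout = [8, 4] and sg_data = [32, 32] is rewritten into one arith.addf
// per distributed operand tuple, each producing vector<32x32xf32>, with the
// sg_layout/sg_data fields dropped from the result layout.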
524struct WgToSgElementwiseOp : public ConversionPattern {
525 WgToSgElementwiseOp(MLIRContext *ctx)
526 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
527
528 LogicalResult
529 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
530 ConversionPatternRewriter &rewriter) const override {
531 // Only match ops with elementwise trait and single result.
532 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
533 return failure();
534
535 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
536 assert(resultType && "Expected result to be a VectorType");
537
538 ArrayRef<int64_t> wgShape = resultType.getShape();
539
540 xegpu::DistributeLayoutAttr layout =
541 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
542 if (!layout || !layout.isForWorkgroup())
543 return failure();
544
545 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
546
547 size_t numVariants = operands.empty() ? 0 : operands.front().size();
548
549 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
550 return operandVec.size() != numVariants;
551 }))
552 return failure();
553
554 SmallVector<Value> newResults;
555 VectorType newResultType =
556 VectorType::get(sgShape, resultType.getElementType());
557
558 for (size_t i = 0; i < numVariants; ++i) {
559 SmallVector<Value> opOperands;
560 for (auto &operandVec : operands)
561 opOperands.push_back(operandVec[i]);
562
563 OperationState state(op->getLoc(), op->getName());
564 state.addOperands(opOperands);
565 state.addTypes(newResultType);
566 // Copy all attributes, but update "layout_result_0" to drop
567 // sgLayout/sgData
568 state.addAttributes(xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
569 Operation *newOp = rewriter.create(state);
570 newResults.push_back(newOp->getResult(0));
571 }
572
573 rewriter.replaceOpWithMultiple(op, {newResults});
574 return success();
575 }
576};
577
578// clang-format off
579// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
580// If input_layout and target_layout have identical sg_layout and sg_data,
581// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
582// dropped. For example:
583// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
584// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
585// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
586// becomes:
587// #a = #xegpu.layout<inst_data = [16, 16]>
588// #b = #xegpu.layout<inst_data = [8, 16]>
589// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
590// (vector<16x16xf32> is determined by sg_data = [16, 16])
591//
592// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
593// For example:
594// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
595// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
596// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
597// is lowered to:
598// #a = #xegpu.layout<inst_data = [16, 16]>
599// #b = #xegpu.layout<inst_data = [8, 16]>
600// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16xf32>, mem_desc<32x64xf32>
601// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
602// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
603// clang-format on
604struct WgToSgConvertLayoutOp
605 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
606 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
607 LogicalResult
608 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter) const override {
610
611 auto input = op.getInputLayout();
612 auto target = op.getTargetLayout();
613
614 if (!input || !target || !input.isForWorkgroup() ||
615 !target.isForWorkgroup())
616 return rewriter.notifyMatchFailure(
617 op, "Input and target layouts must have subgroup layout");
618
619 SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
620 SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
621 DenseI32ArrayAttr inputOrder = input.getOrder();
622 SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
623 SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
624 DenseI32ArrayAttr targetOrder = target.getOrder();
625
626 // TODO: currently we only support the optimal case, where the input and
627 // output have the same sg_layout and sg_data, so SLM is not involved.
628 if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
629 inputOrder != targetOrder)
630 return failure();
631
632 input = input.dropSgLayoutAndData();
633 target = target.dropSgLayoutAndData();
634
635 SmallVector<Value> newOps(adaptor.getSource());
636 if (input && target) {
637 // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
638 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
639 auto newOp = xegpu::ConvertLayoutOp::create(
640 rewriter, op.getLoc(), src.getType(), src, input, target);
641 newOps[i] = newOp;
642 }
643 }
644 rewriter.replaceOpWithMultiple(op, {newOps});
645 return success();
646 }
647};
648
649// Handles UnrealizedConversionCastOp generated during
650// SCFStructuralTypeConversions (step 1). This op may appear as either a
651// target or source materialization for Vector values, e.g.:
652// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
653// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
654// It could be either a 1:N or an N:1 cast. In both cases, the pattern
655// simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
656// For example, the following scf::forOp
657// ```
658// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
659// %n = use(%arg1): vector<128x128xf16>
660// scf.yield %n : vector<128x128xf16>
661// }
662// ```
663// Could be converted to:
664// ```
665// %1 = unrealized_conversion_cast %0
666// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
667// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
668// -> (vector<16x16xf16>, vector<16x16xf16>) {
669// %m = unrealized_conversion_cast %arg1, %arg2
670// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
671// %n = use(%m): vector<128x128xf16>
672// %b = unrealized_conversion_cast %n
673// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
674// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
675// }
676// %cast = unrealized_conversion_cast %for:2
677// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
678// ```
679// TODO: remove it when context-aware type converter is ready.
680struct UnrealizedConversionCastOpPattern
681 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
682 using OpConversionPattern<
683 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
684
685 mlir::LogicalResult
686 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
687 ConversionPatternRewriter &rewriter) const override {
688 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
689
690 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
691 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
692
693 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
694 !llvm::all_equal(ValueRange(inputs).getTypes()))
695 return failure();
696
697 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
698 // It is generated by source materialization (e.g., inits to scf forOp).
699 // The input values provided by the adaptor should already be distributed,
700 // and their types should correspond exactly to the result types of the
701 // operation.
702 if (op.getNumOperands() == 1 &&
703 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
704 rewriter.replaceOp(op, inputs);
705 return success();
706 }
707
708 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
709 // It is generated by target materialization (e.g., arguments/results
710 // of scf forOp). All input values must have the same vector type, and
711 // their shape must be evenly divisible by the output vector's shape
712 // (determined by the nature of the workgroup to subgroup distribution).
713// TODO: it is not safe to do such forwarding, since such an N:1 cast could
714// also originate from other sources.
715 if (op.getNumResults() == 1 &&
716 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
717 rewriter.replaceOpWithMultiple(op, {inputs});
718 return success();
719 }
720
721 return mlir::failure();
722 }
723};
724
725// This pattern distributes arith.constant op into subgroup-level constants
726struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
727 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
728
729 LogicalResult
730 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
731 ConversionPatternRewriter &rewriter) const override {
732 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
733 auto vecType = dyn_cast<VectorType>(op.getType());
734 if (!vecAttr || !vecType)
735 return failure();
736
737 xegpu::DistributeLayoutAttr layout =
738 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
739 if (!layout || !layout.isForWorkgroup())
740 return failure();
741
742 ArrayRef<int64_t> wgShape = vecType.getShape();
743 SmallVector<int64_t> sgShape;
744 int count;
745 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
746
747 auto newType = VectorType::get(sgShape, vecType.getElementType());
748 Location loc = op.getLoc();
749 auto eltType = vecType.getElementType();
750
751 auto setLayout = [&](Value val) {
752 xegpu::setTemporaryLayout(llvm::dyn_cast<OpResult>(val),
753 layout.dropSgLayoutAndData());
754 };
755
756 if (vecAttr.isSplat()) {
757 // Splat: single value for all subgroups
758 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
759 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
760 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
761 setLayout(cstOp->getResult(0));
762 rewriter.replaceOp(op, cstOp);
763 return success();
764 } else if (sgShape == wgShape) { // if the entire vector is shared by all
765 // subgroups, don't distribute
766 auto newConstOp =
767 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
768 setLayout(newConstOp->getResult(0));
769 rewriter.replaceOp(op, newConstOp);
770 return success();
771 } else {
772 // Non-splat constant
773 // Only supports 1D & 2D
774 // TODO: support other cases that require SLM access
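// For example, a 4x4 index constant [[0, 1, 2, 3], [4, 5, 6, 7], ...] with
// sg_data = [2, 2] gives colStride = 1, rowStride = 4 and the base tile
// [[0, 1], [4, 5]]; a subgroup whose distributed offsets are (2, 2) adds
// 2 * 4 + 2 * 1 = 10 to every element, producing [[10, 11], [14, 15]].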
775 if (!eltType.isIndex())
776 return rewriter.notifyMatchFailure(
777 op, "Unsupported element type for non-splat constant op.");
778
779 if (wgShape.size() > 2)
780 return rewriter.notifyMatchFailure(
781 op, "Only 1D & 2D vector constant supported");
782
783 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
784 int64_t rowStride = 0, colStride = 0;
785 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
786 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
787
788 // Compute colStride and rowStride, and check for constant strides.
789 if (cols > 1) {
790 colStride = cast<IntegerAttr>(values[1]).getInt() -
791 cast<IntegerAttr>(values[0]).getInt();
792 }
793 if (rows > 1) {
794 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
795 cast<IntegerAttr>(values[0]).getInt();
796 }
797
798 for (int64_t r = 0; r < rows; ++r) {
799 for (int64_t c = 0; c < cols; ++c) {
800 int64_t idx = r * cols + c;
801 // Check column stride
802 if (c > 0 && cols > 1) {
803 int64_t prevIdx = r * cols + (c - 1);
804 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
805 cast<IntegerAttr>(values[prevIdx]).getInt();
806 if (diff != colStride)
807 return rewriter.notifyMatchFailure(
808 op, "Non-constant column stride in constant op.");
809 }
810 // Check row stride
811 if (r > 0 && rows > 1) {
812 int64_t prevIdx = (r - 1) * cols + c;
813 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
814 cast<IntegerAttr>(values[prevIdx]).getInt();
815 if (diff != rowStride)
816 return rewriter.notifyMatchFailure(
817 op, "Non-constant row stride in constant op.");
818 }
819 }
820 }
821
822 // Create a constant for the base tile.
823 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
824 // For 1D case, extract the first sgShape[0] elements.
825 SmallVector<Attribute> baseTileValues;
826 int baseTileCols = sgShape[sgShape.size() - 1];
827 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
828 for (int64_t r = 0; r < baseTileRows; ++r) {
829 for (int64_t c = 0; c < baseTileCols; ++c) {
830 baseTileValues.push_back(values[r * cols + c]);
831 }
832 }
833
834 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
835 baseTileValues);
836 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
837
838 // Get subgroup id
839 Value sgId =
840 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
841 auto sgOffsets =
842 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
843 if (failed(sgOffsets))
844 return failure();
845
846 SmallVector<Value, 2> strideConsts;
847 strideConsts.push_back(
848 arith::ConstantIndexOp::create(rewriter, loc, colStride));
849 if (rows > 1)
850 strideConsts.insert(
851 strideConsts.begin(),
852 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
853
854 SmallVector<Value> newConstOps;
855 for (auto offsets : *sgOffsets) {
856 // Multiply offset with stride, broadcast it and add to baseConstVec
857 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
858 for (size_t i = 0; i < strideConsts.size(); ++i) {
859 Value mul =
860 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
861 offsets[i], strideConsts[i]);
862 mulOffset = arith::AddIOp::create(
863 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
864 }
865 // Broadcast to baseConstVec size
866 auto bcastOffset = vector::BroadcastOp::create(
867 rewriter, loc, baseConstVec.getType(), mulOffset);
868 auto finalConst =
869 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
870 setLayout(baseConstVec);
871 setLayout(bcastOffset);
872 setLayout(finalConst);
873 newConstOps.push_back(finalConst);
874 }
875 rewriter.replaceOpWithMultiple(op, {newConstOps});
876 return success();
877 }
878 }
879};
880
881// This pattern transforms the LoadGatherOp with explicit offsets to load
882// subgroup data
883struct WgToSgLoadGatherOpWithOffset
884 : public OpConversionPattern<xegpu::LoadGatherOp> {
885 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
886 LogicalResult
887 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
888 ConversionPatternRewriter &rewriter) const override {
889
890 if (!op.getOffsets())
891 return failure();
892
893 Location loc = op.getLoc();
894 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
895 if (!resultType)
896 return failure();
897 ArrayRef<int64_t> wgShape = resultType.getShape();
898
899 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
900
901 if (!layout || !layout.isForWorkgroup())
902 return failure();
903
904 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
905
906 // The offsets need to be distributed
907 auto offsetsVecType =
908 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
909 auto maskVecType =
910 dyn_cast<VectorType>(adaptor.getMask().front().getType());
911 if (!offsetsVecType || !maskVecType ||
912 offsetsVecType.getShape() != maskVecType.getShape()) {
913 return rewriter.notifyMatchFailure(op,
914 "offsets have not been distributed");
915 }
916
917 SmallVector<Value> newLoadOps;
918 auto chunkSizeAttr =
919 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
920 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
921 for (auto [offsets, mask] :
922 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
923 auto newLayout = layout.dropSgLayoutAndData();
924 auto newLoadOp = xegpu::LoadGatherOp::create(
925 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
926 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
927 newLayout);
928 newLoadOp.setAnchorLayout(newLayout);
929 newLoadOps.push_back(newLoadOp);
930 }
931 rewriter.replaceOpWithMultiple(op, {newLoadOps});
932 return success();
933 }
934};
935
936// This pattern transforms the StoreScatterOp with explicit offsets to store
937// subgroup data
938struct WgToSgStoreScatterOpWithOffset
939 : public OpConversionPattern<xegpu::StoreScatterOp> {
940 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
941 LogicalResult
942 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
943 ConversionPatternRewriter &rewriter) const override {
944
945 if (!op.getOffsets())
946 return failure();
947
948 Location loc = op.getLoc();
949 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
950 if (!valueType)
951 return failure();
952
953 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
954
955 if (!layout || !layout.isForWorkgroup())
956 return failure();
957
958 // The offsets need to be distributed
959 auto offsetsVecType =
960 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
961 auto maskVecType =
962 dyn_cast<VectorType>(adaptor.getMask().front().getType());
963 if (!offsetsVecType || !maskVecType ||
964 offsetsVecType.getShape() != maskVecType.getShape()) {
965 return rewriter.notifyMatchFailure(op,
966 "offsets have not been distributed");
967 }
968
969 auto chunkSizeOpt = op.getChunkSize();
970 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
971 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
972 for (auto [val, offs, mask] : llvm::zip(
973 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
974 auto store = xegpu::StoreScatterOp::create(
975 rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
976 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
977 layout.dropSgLayoutAndData());
978 // Update the layout attribute to drop sg_layout and sg_data.
979 for (OpOperand &operand : store->getOpOperands()) {
980 // Skip for operand one (memref)
981 if (operand.getOperandNumber() == 1)
982 continue;
983 xegpu::setTemporaryLayout(operand, layout.dropSgLayoutAndData());
984 }
985 }
986 rewriter.eraseOp(op);
987 return success();
988 }
989};
990
991struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
992 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
993 LogicalResult
994 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
995 ConversionPatternRewriter &rewriter) const override {
996
997 SmallVector<SmallVector<OpFoldResult>> offsetsList;
998 if (failed(genOffsetsList(rewriter, op, offsetsList)))
999 return failure();
1000
1001 ArrayRef<int64_t> wgShape = op.getDataShape();
1002 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1003 assert(valueTy && "the value type must be vector type!");
1004 Type elemTy = valueTy.getElementType();
1005
1006 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1007 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1008 VectorType newResTy = VectorType::get(sgShape, elemTy);
1009 SmallVector<Value> newOps;
1010 for (auto offsets : offsetsList) {
1011 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1012 op.getMemDesc(), offsets,
1013 layout.dropSgLayoutAndData());
1014 newOps.push_back(newOp);
1015 }
1016 rewriter.replaceOpWithMultiple(op, {newOps});
1017
1018 return success();
1019 }
1020};
1021
1022struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1023 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1024 LogicalResult
1025 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1026 ConversionPatternRewriter &rewriter) const override {
1027
1028 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1029 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1030 return failure();
1031
1032 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1033 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1034 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1035 offsets, layout.dropSgLayoutAndData());
1036 rewriter.eraseOp(op);
1037 return success();
1038 }
1039};
1040
1041// This pattern distributes the vector.step ops to work at subgroup level
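// For example, vector.step : vector<128xindex> with sg_layout = [4] and
// sg_data = [32] becomes a subgroup-level vector.step : vector<32xindex>
// plus a broadcast of the subgroup's start offset (0, 32, 64 or 96).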
1042struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1043 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1044 LogicalResult
1045 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1046 ConversionPatternRewriter &rewriter) const override {
1047 xegpu::DistributeLayoutAttr layout =
1048 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1049 if (!layout || !layout.isForWorkgroup())
1050 return failure();
1051
1052 Location loc = op.getLoc();
1053 VectorType type = op.getResult().getType();
1054 auto wgShape = type.getShape();
1055 std::optional<SmallVector<int64_t>> sgShape =
1056 getSgShapeAndCount(wgShape, layout).first;
1057 if (!sgShape)
1058 return failure();
1059
1060 Value sgId =
1061 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1062 auto sgOffsets =
1063 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1064 if (failed(sgOffsets))
1065 return failure();
1066
1067 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1068 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1069 SmallVector<Value> newOps;
1070 for (auto offsets : *sgOffsets) {
1071 // Broadcast the offset scalar to a vector & add to the base steps
1072 auto bcastOffset =
1073 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1074 auto finalSteps =
1075 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1076 xegpu::setTemporaryLayout(steps->getResult(0),
1077 layout.dropSgLayoutAndData());
1078 xegpu::setTemporaryLayout(bcastOffset->getResult(0),
1079 layout.dropSgLayoutAndData());
1080 xegpu::setTemporaryLayout(finalSteps->getResult(0),
1081 layout.dropSgLayoutAndData());
1082 newOps.push_back(finalSteps);
1083 }
1084
1085 rewriter.replaceOpWithMultiple(op, {newOps});
1086 return success();
1087 }
1088};
1089
1090// This pattern transforms vector.shape_cast ops to work at subgroup level.
1091struct WgToSgVectorShapeCastOp
1092 : public OpConversionPattern<vector::ShapeCastOp> {
1093 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1094
1095 LogicalResult
1096 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1097 ConversionPatternRewriter &rewriter) const override {
1098
1099 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1100 if (!resultType)
1101 return failure();
1102
1103 ArrayRef<int64_t> wgShape = resultType.getShape();
1104 xegpu::DistributeLayoutAttr layout =
1105 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1106 if (!layout || !layout.isForWorkgroup())
1107 return failure();
1108
1109 // Check that srcShape and destShape, if they differ, only differ by the
1110 // expansion of unit dimensions.
1111 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1112 if (!srcType)
1113 return failure();
1114
1115 ArrayRef<int64_t> srcShape = srcType.getShape();
1116 llvm::SetVector<int64_t> expandedUnitDims;
1117
1118 // Check if shapes only differ by expanding unit dimensions (like
1119 // expand_dims)
1120 auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
1121 ArrayRef<int64_t> dst) -> bool {
1122 // All unit dimensions in dst that don't appear in src are the expanded
1123 // unit dimensions
1124 size_t srcIdx = 0;
1125 for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
1126 if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
1127 srcIdx++;
1128 else if (dst[dstIdx] == 1)
1129 expandedUnitDims.insert(dstIdx);
1130 else
1131 return false;
1132 return srcIdx == src.size();
1133 };
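// For example, vector<32x16xf32> -> vector<32x1x16xf32> only inserts a unit
// dimension at index 1, so expandedUnitDims = {1}; a reshape such as
// vector<32x16xf32> -> vector<16x32xf32> is rejected.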
1134 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1135
1136 if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
1137 xegpu::DistributeLayoutAttr sourceLayout =
1138 xegpu::getTemporaryLayout(op->getOpOperand(0));
1139
1140 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1141 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1142 return isa<vector::BroadcastOp>(user);
1143 });
1144 };
1145
1146 if (!usedByBroadcastOp(op))
1147 return rewriter.notifyMatchFailure(
1148 op, "ShapeCast ops that expand unit dimensions and are used by "
1149 "non-broadcast operations are not supported.");
1150
1151 if (!sourceLayout.isSliceOf(layout))
1152 return rewriter.notifyMatchFailure(
1153 op, "The ShapeCast op only expands dimensions, the result layout "
1154 "must be a slice of the input layout, or vice versa.");
1155 layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
1156 layoutToDistribute =
1157 layoutToDistribute.setUnitDimLayout(expandedUnitDims);
1158 }
1159
1160 SmallVector<int64_t> sgShape =
1161 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1162 VectorType newResultType =
1163 VectorType::get(sgShape, resultType.getElementType());
1164
1165 SmallVector<Value> newShapeCastOps;
1166 for (auto src : adaptor.getSource()) {
1167 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1168 newResultType, src);
1169 xegpu::setTemporaryLayout(newShapeCast->getResult(0),
1170 layout.dropSgLayoutAndData());
1171 newShapeCastOps.push_back(newShapeCast.getResult());
1172 }
1173
1174 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1175 return success();
1176 }
1177};
1178
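/// Creates a vector constant holding the neutral element of the given
/// combining kind splatted to `type` (e.g. 0 for ADD/XOR/OR, 1 for MUL/AND,
/// the maximum representable value for MINSI/MINUI, +/- infinity for the
/// float min/max kinds), or returns nullptr for unsupported element types.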
1179static Value createAccumulator(ConversionPatternRewriter &rewriter,
1180 Location loc, VectorType type,
1181 vector::CombiningKind kind) {
1182 Type elemTy = type.getElementType();
1183
1184 switch (kind) {
1185 case vector::CombiningKind::ADD:
1186 case vector::CombiningKind::XOR:
1187 case vector::CombiningKind::OR:
1188 return arith::ConstantOp::create(
1189 rewriter, loc, type,
1190 DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
1191
1192 case vector::CombiningKind::MUL:
1193 case vector::CombiningKind::AND:
1194 return arith::ConstantOp::create(
1195 rewriter, loc, type,
1196 DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy)));
1197
1198 case vector::CombiningKind::MINSI:
1199 // Use max signed int value for signed integer min
1200 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1201 auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
1202 return arith::ConstantOp::create(
1203 rewriter, loc, type,
1204 DenseElementsAttr::get(type,
1205 rewriter.getIntegerAttr(elemTy, maxVal)));
1206 }
1207 return nullptr;
1208
1209 case vector::CombiningKind::MINUI:
1210 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1211 auto maxVal = APInt::getMaxValue(intTy.getWidth());
1212 return arith::ConstantOp::create(
1213 rewriter, loc, type,
1214 DenseElementsAttr::get(type,
1215 rewriter.getIntegerAttr(elemTy, maxVal)));
1216 }
1217 return nullptr;
1218
1219 case vector::CombiningKind::MAXSI:
1220 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1221 auto minVal = APInt::getSignedMinValue(intTy.getWidth());
1222 return arith::ConstantOp::create(
1223 rewriter, loc, type,
1224 DenseElementsAttr::get(type,
1225 rewriter.getIntegerAttr(elemTy, minVal)));
1226 }
1227 return nullptr;
1228
1229 case vector::CombiningKind::MAXUI:
1230 return arith::ConstantOp::create(
1231 rewriter, loc, type,
1232 DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
1233
1234 case vector::CombiningKind::MINNUMF:
1235 case vector::CombiningKind::MINIMUMF:
1236 // Use +infinity for float min operations
1237 if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
1238 auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
1239 return arith::ConstantOp::create(
1240 rewriter, loc, type,
1241 DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf)));
1242 }
1243 return nullptr;
1244
1245 case vector::CombiningKind::MAXNUMF:
1246 case vector::CombiningKind::MAXIMUMF:
1247 // Use -infinity for float max operations
1248 if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
1249 auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
1250 return arith::ConstantOp::create(
1251 rewriter, loc, type,
1252 DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf)));
1253 }
1254 return nullptr;
1255 }
1256 return nullptr;
1257}
1258
1259/// This function converts multi-dimensional subgroup indices into a single
1260/// linear offset. It's used to calculate memory offsets in SLM for
1261/// cross-subgroup reduction coordination.
1262///
1263/// Parameters:
1264/// - sgIds: Multi-dimensional subgroup indices (e.g., [sgId_x, sgId_y, sgId_z])
1265/// - dims: Which dimensions to include in linearization (e.g., [0, 2] for x and
1266/// z dims)
1267/// - sgLayout: Subgroup layout sizes for each dimension (e.g., [4, 8, 2] means
1268/// 4x8x2 subgroups)
1269///
1270/// It uses the linearization formula:
1271/// offset = sum(sgIds[dim] * stride[dim])
1272/// where stride[dim] = product of the sgLayout sizes of the dims listed
1273/// before 'dim' in 'dims' (so the first listed dim varies fastest)
1274///
1275/// Example:
1276/// - sgLayout = [4, 8, 2], dims = [0, 2] (linearize x and z dimensions)
1277/// - sgIds = [1, 3, 1] (subgroup at position x=1, y=3, z=1)
1278/// - Calculation:
1279/// * dim=0: stride=1, term = sgIds[0] * 1 = 1 * 1 = 1
1280/// * dim=2: stride=sgLayout[0]=4, term = sgIds[2] * 4 = 1 * 4 = 4
1281/// * linearizedOffset = 1 + 4 = 5
1282///
1283/// This gives us a unique linear index for each combination of subgroup
1284/// positions in the specified dimensions, which is used for SLM row/column
1285/// addressing.
1286static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
1287 Location loc, ArrayRef<Value> sgIds,
1288 ArrayRef<int64_t> dims,
1289 ArrayRef<int64_t> sgLayout) {
1290 Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
1291 int64_t stride = 1;
1292
1293 for (int64_t dim : dims) {
1294 Value dimVal = sgIds[dim];
1295 Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
1296 Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
1297 linearizedOffset =
1298 arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
1299 stride *= sgLayout[dim];
1300 }
1301
1302 return linearizedOffset;
1303}
1304
1305/// This pattern transforms vector.multi_dim_reduction operations from
1306/// workgroup-level to subgroup-level execution with support for multiple
1307/// reduction dimensions.
1308///
1309/// Steps include:
1310/// 1. LOCAL REDUCTION :
1311/// - Each subgroup performs local reduction on its data slice
1312/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1313/// phase
1314///
1315/// 2. CROSS-SUBGROUP :
1316/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1317/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1318/// - If not needed, adds original accumulator and returns local results
1319///
1320/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1321/// a) SLM Layout Design:
1322/// - Rows: subgroups participating in reduction (product of sg_layout in
1323/// reduction dims)
1324/// - Cols: total result elements across non-reduction dimensions
1325///
1326/// b) Store Phase:
1327/// - Each subgroup stores its local reduction result to SLM
1328/// - Row offset: linearized index of subgroup in reduction dimensions
1329/// - Col offset: linearized index of subgroup in non-reduction dimensions
1330///
1331/// c) Load and Final Reduction Phase:
1332/// - Each subgroup loads a column of data (all reduction participants for
1333/// its position)
1334/// - Performs final reduction along the loaded dimension
1335/// - Adds original accumulator to get final result
1336///
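/// For example, with sg_layout = [4, 8] and a reduction over dim 0 where each
/// subgroup produces `localElements` partial values, the SLM tile is
/// [4 x (8 * localElements)]: the 4 rows correspond to the subgroups that
/// participate in the reduction, and each of the 8 subgroup columns stores its
/// partial result side by side; after a barrier, every subgroup loads its
/// [4 x localElements] column slab and reduces it along dim 0 before adding
/// the original accumulator.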
1337struct WgToSgMultiDimReductionOp
1338 : public OpConversionPattern<vector::MultiDimReductionOp> {
1339 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1340
1341 LogicalResult
1342 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1343 ConversionPatternRewriter &rewriter) const override {
1344 Location loc = op.getLoc();
1345
1346 VectorType srcType = op.getSourceVectorType();
1347 VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1348 if (!dstType)
1349 return failure();
1350
1351 auto originalSrcShape = srcType.getShape();
1352 xegpu::DistributeLayoutAttr layout =
1353 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1354 if (!layout || !layout.isForWorkgroup())
1355 return failure();
1356
1357 auto reductionDims = llvm::to_vector(op.getReductionDims());
1358
1359 // Get sg_layout and sg_data from the parent layout
1360 SmallVector<int64_t> sgLayout;
1361 SmallVector<int64_t> sgData;
1362 if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1363 sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
1364 sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
1365 } else
1366 return rewriter.notifyMatchFailure(
1367 op, "Reduction should have SliceAttr layout");
1368
1369 Type elemTy = dstType.getElementType();
1370
1371 // Step 1: perform local subgroup reductions with ZERO accumulator
1372 SmallVector<Value> localReductions;
1373 SmallVector<int64_t> sgShape =
1374 getSgShapeAndCount(originalSrcShape, layout).first;
1375 VectorType newDstType = VectorType::get(sgShape, elemTy);
1376 for (auto sgSrc : adaptor.getSource()) {
1377 // Create ZERO accumulator for local reduction
1378 auto neutralLocalAcc =
1379 createAccumulator(rewriter, loc, newDstType, op.getKind());
1380 // Local reduction with ZERO accumulator
1381 auto localReduce = vector::MultiDimReductionOp::create(
1382 rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
1383 reductionDims);
1384 localReductions.push_back(localReduce.getResult());
1385 }
1386
1387 // Check if cross-subgroup reduction is needed for any reduction dimension
1388 SmallVector<int64_t> crossSgReductionDims;
1389 for (int64_t reductionDim : reductionDims) {
1390 bool needsCrossSubgroupReduction =
1391 (sgLayout[reductionDim] > 1) &&
1392 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1393
1394 if (needsCrossSubgroupReduction) {
1395 crossSgReductionDims.push_back(reductionDim);
1396 }
1397 }
1398
1399 // If no cross-subgroup reduction needed, add accumulator and return
1400 if (crossSgReductionDims.empty()) {
1401 SmallVector<Value> results;
1402 for (auto localResult : localReductions) {
1403 auto finalResult = vector::makeArithReduction(
1404 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1405 if (auto defOp = finalResult.getDefiningOp())
1406 xegpu::setDistributeLayoutAttr(defOp->getResult(0),
1407 layout.dropSgLayoutAndData());
1408 results.push_back(finalResult);
1409 }
1410 rewriter.replaceOpWithMultiple(op, {results});
1411 return success();
1412 }
1413
1414 // Step 2: cross-subgroup reduction using SLM
1415
1416 // Calculate total elements in local result
1417 int64_t localElements = computeProduct(sgShape);
1418
1419 // Shape cast for SLM storage - store as [1, localElements]
1420 SmallVector<int64_t> storeShape2D = {1, localElements};
1421 VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
1422 auto storeShapeCast = vector::ShapeCastOp::create(
1423 rewriter, loc, storeType2D, localReductions[0]);
1424 Value storeData = storeShapeCast.getResult();
1425
1426 // Calculate the SLM shape: rows for the subgroups in the reduction dims,
1427 // cols for the total result elements across non-reduction dimensions.
1428 int64_t totalReductionSubgroups = 1;
1429 for (int64_t dim : crossSgReductionDims) {
1430 totalReductionSubgroups *= sgLayout[dim];
1431 }
1432
1433 // Total result elements across all subgroups in non-reduction dimensions
1434 int64_t totalResultElements =
1435 localElements * computeProduct(sgLayout) / totalReductionSubgroups;
1436
1437 SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
1438 totalResultElements};
1439
1440 // Allocate SLM
1441 auto bitWidth = elemTy.getIntOrFloatBitWidth();
1442 auto bytesPerElement = bitWidth / 8;
1443 int64_t slmElements = slmShape2D[0] * slmShape2D[1];
1444 auto slmSize = slmElements * bytesPerElement;
1445 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1446 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1447
1448 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
1449 slmShape2D, elemTy, nullptr);
1450 auto memDesc =
1451 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1452
1453 // Step 4: Store local results to SLM
1454 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1455 rewriter.getIndexType(), nullptr);
1456
1457 // Convert sgLayout to Values for delinearizeIndex
1458 SmallVector<Value> sgLayoutValues;
1459 for (int64_t dim : sgLayout)
1460 sgLayoutValues.push_back(
1461 arith::ConstantIndexOp::create(rewriter, loc, dim));
1462
1463 auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
1464 sgLayoutValues);
1465 if (failed(sgIdsResult))
1466 return failure();
1467 SmallVector<Value> sgIds = *sgIdsResult;
1468
1469 // Row offset: linearize reduction dimension indices
1470 Value rowOffsetStore = linearizeSubgroupIndices(
1471 rewriter, loc, sgIds, crossSgReductionDims, sgLayout);
1472
1473 // Column offset: linearize non-reduction dimension indices
1474 SmallVector<int64_t> nonReductionDims;
1475 for (size_t i = 0; i < sgLayout.size(); ++i) {
1476 if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
1477 nonReductionDims.push_back(static_cast<int64_t>(i));
1478 }
1479 }
1480
1481 Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
1482 nonReductionDims, sgLayout);
1483
1484 Value localElementsVal =
1485 arith::ConstantIndexOp::create(rewriter, loc, localElements);
1486 colOffset =
1487 arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
1488
1489 SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
1490
1491 auto storeMatrixLayout = xegpu::SliceAttr::get(
1492 rewriter.getContext(),
1493 xegpu::LayoutAttr::get(rewriter.getContext(), /*sg_layout =*/nullptr,
1494 /*sg_data =*/nullptr,
1495 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
1496 /*lane_data =*/nullptr, /*order =*/nullptr),
1497 dyn_cast<xegpu::SliceAttr>(layout).getDims());
1498 xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
1499 storeOffsets2D, /*layout=*/storeMatrixLayout);
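// Continuing the hypothetical example: the subgroup with linear id 13
// delinearizes over sg_layout = [8, 4] to (3, 1), so it stores its
// vector<1x32xf32> slab at row 3 (reduction-dim index) and column
// 1 * 32 = 32 (non-reduction-dim index scaled by localElements).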
1500
1501 gpu::BarrierOp::create(rewriter, loc);
1502
1503 // Step 5: Load from SLM for final reduction
1504 SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
1505 VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
1506
1507 // Load offsets - each subgroup loads its column based on non-reduction
1508 // position
1509 Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
1510
1511 SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};
1512
1513 auto loadOp = xegpu::LoadMatrixOp::create(
1514 rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
1515 /*layout=*/nullptr);
1516
1517 // Step 6: Perform the final reduction with a neutral accumulator
1518 SmallVector<int64_t> finalReductionDims = {0};
1519 SmallVector<int64_t> finalResultShape = {localElements};
1520 VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
1521
1522 auto neutralFinalAcc =
1523 createAccumulator(rewriter, loc, finalResultType, op.getKind());
1524
1525 auto finalReduce = vector::MultiDimReductionOp::create(
1526 rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
1527 neutralFinalAcc, finalReductionDims);
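// In the hypothetical example each subgroup now loads a vector<8x32xf32>
// slab at offsets (0, colOffset) and reduces it over dim 0 to a
// vector<32xf32>, combining the 8 partial results for its columns.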
1528
1529 // Step 7: Fold in the original accumulator using the reduction kind
1530 Value originalAcc = adaptor.getAcc()[0];
1531 Value accToAdd = originalAcc;
1532
1533 // Handle shape mismatch by shape casting
1534 if (originalAcc.getType() != finalReduce.getResult().getType()) {
1535 auto originalAccType = cast<VectorType>(originalAcc.getType());
1536 auto finalResultType =
1537 cast<VectorType>(finalReduce.getResult().getType());
1538
1539 // If they have the same number of elements, just shape cast
1540 if (originalAccType.getNumElements() ==
1541 finalResultType.getNumElements()) {
1542 auto shapeCast = vector::ShapeCastOp::create(
1543 rewriter, loc, finalResultType, originalAcc);
1544 accToAdd = shapeCast.getResult();
1545 }
1546 }
1547
1548 auto finalResult = vector::makeArithReduction(
1549 rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);
1550
1551 if (auto defOp = finalResult.getDefiningOp())
1552 xegpu::setDistributeLayoutAttr(defOp->getResult(0),
1553 layout.dropSgLayoutAndData());
1554
1555 rewriter.replaceOp(op, finalResult);
1556 return success();
1557 }
1558};
1559
1560// This pattern transforms vector.transpose ops to work at subgroup level.
1561struct WgToSgVectorTransposeOp
1562 : public OpConversionPattern<vector::TransposeOp> {
1563 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1564
1565 LogicalResult
1566 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1567 ConversionPatternRewriter &rewriter) const override {
1568 VectorType resultType = op.getResultVectorType();
1569
1570 ArrayRef<int64_t> wgShape = resultType.getShape();
1571 xegpu::DistributeLayoutAttr layout =
1572 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1573 if (!layout || !layout.isForWorkgroup())
1574 return failure();
1575 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1576 xegpu::DistributeLayoutAttr sourceLayout =
1577 xegpu::getDistributeLayoutAttr(op.getVector());
1578 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1579 return failure();
1580
1581 SmallVector<int64_t> sourceSgLayout =
1582 sourceLayout.getEffectiveSgLayoutAsInt();
1583 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1584 DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
1585 DenseI32ArrayAttr resultOrder = layout.getOrder();
1586
1587 if (!sourceOrder || !resultOrder) {
1588 return rewriter.notifyMatchFailure(
1589 op, "Both source and result must have order attributes");
1590 }
1591
1592 ArrayRef<int64_t> permutation = op.getPermutation();
1593 size_t permutationSize = permutation.size();
1594 if (sourceSgLayout.size() != permutationSize ||
1595 resultSgLayout.size() != permutationSize) {
1596 return rewriter.notifyMatchFailure(
1597 op, "Layouts and permutation must have the same rank");
1598 }
1599
1600 // Check that sgLayout, sgData & order are properly transposed for source
1601 // and result
1602 if (!layout.isTransposeOf(sourceLayout, permutation))
1603 return rewriter.notifyMatchFailure(
1604 op, "Result layout is not a valid transpose of source layout "
1605 "according to permutation");
1606
1607 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1608 VectorType newResultType =
1609 VectorType::get(sgShape, resultType.getElementType());
1610 SmallVector<Value> newTransposeOps;
1611 for (auto src : adaptor.getVector()) {
1612 auto newTranspose = vector::TransposeOp::create(
1613 rewriter, op.getLoc(), newResultType, src, permutation);
1614 xegpu::setTemporaryLayout(newTranspose->getResult(0),
1615 layout.dropSgLayoutAndData());
1616 newTransposeOps.push_back(newTranspose.getResult());
1617 }
1618
1619 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1620 return success();
1621 }
1622};
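// Example (hypothetical layouts, for illustration): a workgroup-level
//   %r = vector.transpose %v, [1, 0] : vector<128x256xf32> to vector<256x128xf32>
// whose source carries sg_layout = [4, 8], sg_data = [32, 32] and whose
// result carries the transposed layout sg_layout = [8, 4], sg_data = [32, 32]
// (with matching transposed order attributes) is rewritten into one
// subgroup-level transpose per distributed fragment:
//   vector.transpose %src, [1, 0] : vector<32x32xf32> to vector<32x32xf32>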
1623
1624// Distribute vector mask ops to work at subgroup level.
1625template <typename MaskOpType>
1626struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1627 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1628
1629 LogicalResult matchAndRewrite(
1630 MaskOpType op,
1631 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1632 ConversionPatternRewriter &rewriter) const override {
1633 xegpu::DistributeLayoutAttr layout =
1634 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1635 if (!layout || !layout.isForWorkgroup())
1636 return failure();
1637
1638 Location loc = op.getLoc();
1639 VectorType type = op.getResult().getType();
1640 auto wgShape = type.getShape();
1641
1642 SmallVector<Value> wgMaskDimSizes;
1643 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1644 for (int64_t maskSize : op.getMaskDimSizes()) {
1645 wgMaskDimSizes.push_back(
1646 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1647 }
1648 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1649 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1650 }
1651
1652 Value sgId =
1653 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1654 auto sgOffsets =
1655 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1656 if (failed(sgOffsets))
1657 return failure();
1658
1659 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1660 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1661
1662 // In each dimension, each subgroup computes its local mask size as:
1663 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
1664 SmallVector<Value> newCreateMaskOps;
1665 for (auto offsetSet : *sgOffsets) {
1666 SmallVector<Value> maskOperands;
1667
1668 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1669 Value dimSizeVal =
1670 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1671 Value offset = offsetSet[i];
1672 Value adjustedMaskSize =
1673 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1674 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1675 Value nonNegative =
1676 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1677 Value sgMaskSize =
1678 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1679 maskOperands.push_back(sgMaskSize);
1680 }
1681
1682 auto newCreateMaskOp =
1683 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1684 xegpu::setTemporaryLayout(newCreateMaskOp->getResult(0),
1685 layout.dropSgLayoutAndData());
1686 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1687 }
1688
1689 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1690 return success();
1691 }
1692};
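// Example (hypothetical shapes, for illustration): a workgroup-level
//   %m = vector.create_mask %c100 : vector<128xi1>
// distributed with sg_layout = [4], sg_data = [32] yields per-subgroup
// vector<32xi1> masks of size min(max(100 - offset, 0), 32): the subgroup
// at offset 0 gets 32, at offset 64 gets 32, and at offset 96 gets 4.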
1693
1694using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1695using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1696} // namespace
1697
1698namespace mlir {
1699namespace xegpu {
1700 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1701 patterns
1702 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1703 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1704 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1705 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1706 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1707 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1708 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1709 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1710 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1711 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1712 patterns.getContext());
1713}
1714} // namespace xegpu
1715} // namespace mlir
1716
1717namespace {
1718struct XeGPUWgToSgDistributePass
1719 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1720 void runOnOperation() override;
1721};
1722} // namespace
1723
1724void XeGPUWgToSgDistributePass::runOnOperation() {
1725
1726 // TODO-LayoutRefactor: unify the local propagation for layout preprocessing
1727 // Operation *op = getOperation();
1728 // if (!xegpu::recoverTemporaryLayouts(op)) {
1729 // signalPassFailure();
1730 // return;
1731 // }
1732
1733 // Track existing UnrealizedConversionCastOps
1734 SmallVector<Operation *> existingCastOps;
1735 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1736 existingCastOps.push_back(castOp.getOperation());
1737 });
1738
1739 {
1740 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1741 // VectorType operands. This first converts such operands to
1742 // RankedTensorType, propagates the layout attribute into the encoding
1743 // attribute, and finally converts the RankedTensorType to VectorType based
1744 // on the encoding.
1745
1746 TypeConverter converter;
1747 converter.addConversion([&](Type type) -> Type { return type; });
1748 converter.addConversion(
1749 [&](RankedTensorType type,
1750 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1751 Type elemTy = type.getElementType();
1752 ArrayRef<int64_t> shape = type.getShape();
1753
1754 int count;
1755 SmallVector<int64_t> subShape;
1756 std::tie(subShape, count) = getSgShapeAndCount(
1757 shape,
1758 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1759
1760 auto newTy = VectorType::get(subShape, elemTy);
1761 result.append(count, newTy);
1762 return success();
1763 });
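// Example (hypothetical types, for illustration): under this conversion a
// tensor<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
// maps to a single vector<32x32xf32> (count = 1), while sg_data = [16, 32]
// would map to two vector<16x32xf32> values (count = 256*128 / (128*128) = 2).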
1764
1765 xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1766 converter);
1767 }
1768
1769 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1770 // as well as XeGPU, Arith, and Vector operations.
1771 MLIRContext *ctx = &getContext();
1772 RewritePatternSet patterns(ctx);
1773 ConversionTarget target(*ctx);
1774 TypeConverter converter;
1775 converter.addConversion([&](Type type) -> Type { return type; });
1776 converter.addConversion(
1777 [&](xegpu::TensorDescType type,
1778 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1779 Type elemTy = type.getElementType();
1780 ArrayRef<int64_t> shape = type.getShape();
1781
1782 int count;
1783 SmallVector<int64_t> subShape;
1784 xegpu::LayoutAttr layout = type.getLayoutAttr();
1785 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1786
1787 if (layout)
1788 layout = layout.dropSgLayoutAndData();
1789
1790 auto newTy = xegpu::TensorDescType::get(
1791 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1792 result.append(count, newTy);
1793 return success();
1794 });
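// Example (hypothetical type, for illustration): a
// !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
//   sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
// converts to a single
// !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16],
//   lane_data = [1, 1]>>, since sg_layout and sg_data are dropped.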
1795
1796 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1797 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1798 return createOp.getType();
1799 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1800 return loadOp.getTensorDescType();
1801 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1802 return storeOp.getTensorDescType();
1803 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1804 return updateOp.getType();
1805 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1806 return prefetchOp.getTensorDescType();
1807 return xegpu::TensorDescType();
1808 };
1809
1810 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1811 return !layout || !layout.isForWorkgroup();
1812 };
1813
1814 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1815 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1816 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1817 auto tdescTy = getTensorDescType(op);
1818 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1819 return isLegal(layout);
1820 });
1821
1822 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1823 auto layout = op.getLayoutCdAttr();
1824 return isLegal(layout);
1825 });
1826
1827 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1828 [=](xegpu::LoadMatrixOp op) -> bool {
1829 return isLegal(op.getLayoutAttr());
1830 });
1831
1832 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1833 [=](xegpu::StoreMatrixOp op) -> bool {
1834 return isLegal(op.getLayoutAttr());
1835 });
1836
1837 target.addDynamicallyLegalOp<arith::ConstantOp>(
1838 [=](arith::ConstantOp op) -> bool {
1839 auto vecType = dyn_cast<VectorType>(op.getType());
1840 if (!vecType)
1841 return true;
1842
1843 auto layout =
1844 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1845 return isLegal(layout);
1846 });
1847
1848 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1849 vector::TransposeOp, vector::BroadcastOp,
1850 vector::MultiDimReductionOp,
1851 vector::ConstantMaskOp, vector::CreateMaskOp>(
1852 [=](Operation *op) -> bool {
1853 // Check for either a SliceAttr or LayoutAttr on the result.
1854 auto layout =
1855 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1856 return isLegal(layout);
1857 });
1858
1859 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1860 [=](xegpu::LoadGatherOp op) -> bool {
1861 auto layout = op.getLayoutAttr();
1862 return isLegal(layout);
1863 });
1864
1865 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1866 [=](xegpu::StoreScatterOp op) -> bool {
1867 auto layout = op.getLayoutAttr();
1868 return isLegal(layout);
1869 });
1870
1871 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1872 [=](xegpu::ConvertLayoutOp op) -> bool {
1873 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1874 });
1875
1876 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1877 [=](Operation *op) -> std::optional<bool> {
1878 // Only handle elementwise mappable ops
1879 if (!OpTrait::hasElementwiseMappableTraits(op))
1880 return true;
1881
1882 VectorType resultType =
1883 dyn_cast<VectorType>(op->getResult(0).getType());
1884 if (!resultType)
1885 return true;
1886
1887 // Check if all operands are vectors of the same shape
1888 // TODO: Support other types.
1889 for (Value operand : op->getOperands()) {
1890 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1891 if (!operandType || operandType.getShape() != resultType.getShape()) {
1892 return true;
1893 }
1894 }
1895
1896 xegpu::DistributeLayoutAttr layout =
1897 xegpu::getTemporaryLayout(op->getResult(0));
1898 return isLegal(layout);
1899 });
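// Example (for illustration): an arith.addf producing vector<256x128xf32>
// whose result still carries a workgroup layout (sg_layout present) is
// reported illegal here and gets distributed by WgToSgElementwiseOp; the
// same op with no layout, or with a subgroup-only layout, stays legal.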
1900
1901 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1902 [=](UnrealizedConversionCastOp op) {
1903 return llvm::is_contained(existingCastOps, op.getOperation());
1904 });
1905
1906 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1907
1908 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1909 target);
1910 xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1911 if (failed(
1912 applyPartialConversion(getOperation(), target, std::move(patterns))))
1913 return signalPassFailure();
1914
1915 // Remove sg_layout and sg_data attributes from the Layout
1916 // attribute for each VectorType result of the operation.
1917 // For Structured Control Flow ops, the layout is simply removed,
1918 // since in the 1:N case the layouts for the new results are missing;
1919 // the layout propagation pass will be activated to recover them.
1920 getOperation()->walk([](Operation *op) {
1921 for (OpResult result : op->getOpResults()) {
1922 std::string name = xegpu::getTemporaryLayoutName(result);
1923 if (auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
1924 op->removeAttr(name);
1925 if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1926 if (auto newLayout = layout.dropSgLayoutAndData())
1927 op->setAttr(name, newLayout);
1928 }
1929 }
1930 }
1931 });
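// Example (hypothetical attribute, for illustration): after this walk a
// result annotated with #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
// inst_data = [8, 16]> keeps only #xegpu.layout<inst_data = [8, 16]>, while
// on scf.for / scf.if / scf.while results the attribute is removed entirely.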
1932}