MLIR 23.0.0git
XeGPUWgToSgDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
25#include <optional>
26
27namespace mlir {
28namespace xegpu {
29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31} // namespace xegpu
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Retrieve the RangeAttr if it is specified.
39static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
40 Operation *parent = op->getParentOfType<scf::IfOp>();
41 while (parent) {
42 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->getAttr("sg_id_range")))
44 return attr;
45 parent = parent->getParentOfType<scf::IfOp>();
46 }
47 return {};
48}
49
/// Computes the per-subgroup data shape and the number of distribution
/// units each subgroup must process for a value of `shape` distributed
/// under `layout`.
/// Returns {sgShape, count} where count = product(shape) / product(distUnit);
/// count > 1 means each subgroup handles several tiles round-robin.
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  // NOTE(review): the declaration of `sgShape` (presumably initialized from
  // `shape`) is missing from this excerpt — verify against upstream.
  // Only workgroup-level layouts participate in distribution; otherwise the
  // shape is returned unchanged with count == 1.
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    // Prefer an explicit sg_data spec; otherwise derive it as
    // shape / sg_layout (element-wise ratio).
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp distUnit to the original shape to handle cases where data is
    // shared among subgroups, which may cause distUnit to exceed the original
    // shape.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}
71
/// Utility helper for deriving a list of offsets for each sub-TensorDescs
/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
/// associated distribute layout attribute, the shape, subgroup id and the
/// original offsets of the op.
/// Fails (match failure) when the op carries no offsets, has no
/// workgroup-level layout, or an sg_id_range does not match the layout.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               // NOTE(review): the tail of this signature (the `offsetsList`
               // output parameter and opening brace) is missing from this
               // excerpt — verify against upstream.
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // not applicable to ops without offsets operands.
  if (origOffsets.empty())
    return failure();

  // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr();
  // the matrix ops carry the layout attribute directly.
  xegpu::DistributeLayoutAttr layout;
  if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
                std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
    layout = op.getLayoutAttr();
  } else {
    layout = op.getDescLayoutAttr();
  }

  // not applicable to ops without workgroup layout attributes
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // verify and adjust the sgId if the range specifier is present
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    // verify the RangeAttr against the layout attribute
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    // adjust the sgId if necessary: rebase it so it is relative to the
    // start of the range.
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
  // descriptors to be accessed, based on the layout information.
  ArrayRef<int64_t> wgShape = op.getDataShape();
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Compute the final global offsets for each accessed sub-tensor
  // or sub-memory descriptor.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    // NOTE(review): the line declaring `newOffsets` (presumably an
    // xegpu::addWithRightAligned call) is missing from this excerpt —
    // verify against upstream.
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }

  // callback(offsetsList);
  return success();
}
142
143/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
144/// from a workgroup descriptor. It replaces the offsets and sizes with
145/// appropriate values for the subgroup.
146/// It uses round-robin assignment to distribute the work to the subgroups.
/// Following create_nd_desc operation:
148/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
149/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
150/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
151/// is converted to 9 subgroup level operations based on the sg_layout &
152/// sg_data:
153/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
154/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
155/// lane_data = [1, 1]>>
156///
157/// The sg_layout and sg_data attributes are dropped after the pass as they are
158/// no longer needed.
159///
160/// 24x24 matrix distribution example:
161/// sg_layout = [4, 4], sg_data = [2, 2]
162/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
163/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
164///
165/// +------------------------+
166/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
167/// |-----+-----+-----|
168/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
169/// |-----+-----+-----|
170/// | 8x8 | 8x8 | 8x8 |
171/// +------------------------+
172///
173/// Each 8x8 tile is further subdivided among subgroups:
174/// +------------------------+
175/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
176/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
177/// | 2x2 2x2 2x2 2x2 |
178/// | 2x2 2x2 2x2 2x2 |
179/// +------------------------+
180///
181/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
182/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
183
184/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
185/// pattern and all the other ops just follow.
186/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
187/// ops in the pass.
188struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
189 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
190
191 LogicalResult
192 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter) const override {
194 SmallVector<SmallVector<OpFoldResult>> offsetsList;
195 if (failed(genOffsetsList(rewriter, op, offsetsList)))
196 return failure();
197
198 MLIRContext *ctx = op.getContext();
199 xegpu::TensorDescType tdescTy = op.getType();
200 ArrayRef<int64_t> wgShape = tdescTy.getShape();
201 Type elemTy = tdescTy.getElementType();
202 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
203 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
204 auto newTdescTy =
205 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
206 layout.dropSgLayoutAndData());
207
208 SmallVector<Value> newOps;
209 for (auto offsets : offsetsList) {
210 auto newOp = xegpu::CreateNdDescOp::create(
211 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
212 op.getMixedSizes(), op.getMixedStrides());
213
214 newOps.push_back(newOp);
215 }
216 rewriter.replaceOpWithMultiple(op, {newOps});
217
218 return success();
219 }
220};
221
222// This pattern transforms the CreateNdDescOp without offsets to create a
223// subgroup descriptor from a workgroup descriptor
224struct WgToSgCreateNdOpNoOffset
225 : public OpConversionPattern<xegpu::CreateNdDescOp> {
226 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
227
228 LogicalResult
229 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter) const override {
231
232 // Check no offsets are specified.
233 if (!op.getMixedOffsets().empty())
234 return failure();
235
236 Location loc = op.getLoc();
237 MLIRContext *ctx = op.getContext();
238 xegpu::TensorDescType tdescTy = op.getType();
239 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
240 if (!layout || !layout.isForWorkgroup())
241 return failure();
242
243 Type elemTy = tdescTy.getElementType();
244 ArrayRef<int64_t> wgShape = tdescTy.getShape();
245
246 SmallVector<int64_t> sgShape;
247 int count;
248 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
249 xegpu::TensorDescType newTdescTy =
250 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
251 layout.dropSgLayoutAndData());
252
253 SmallVector<Value> newCreateNdOps(count);
254 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
255 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
256 op.getSource(), op.getMixedSizes(),
257 op.getMixedStrides());
258 });
259
260 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
261 return success();
262 }
263};
264
265/// This pattern transforms the LoadNdOp to load subgroup data.
266struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
267 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
268 LogicalResult
269 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
270 ConversionPatternRewriter &rewriter) const override {
271 if (!op.getMixedOffsets().empty())
272 return failure();
273
274 SmallVector<Value> newLoadOps;
275 for (auto src : adaptor.getTensorDesc()) {
276 xegpu::TensorDescType tdescTy =
277 dyn_cast<xegpu::TensorDescType>(src.getType());
278 ArrayRef<int64_t> srcShape = tdescTy.getShape();
279 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
280 auto newLoadOp = xegpu::LoadNdOp::create(
281 rewriter, op.getLoc(), newResTy, src,
282 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
283 newLoadOps.push_back(newLoadOp);
284 }
285 rewriter.replaceOpWithMultiple(op, {newLoadOps});
286 return mlir::success();
287 }
288};
289
290/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
291/// It creates a StoreNdOp op to store the updated values to the new subgroup
292/// src tensor descriptors.
293struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
294 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
295 LogicalResult
296 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter) const override {
298 if (!op.getMixedOffsets().empty())
299 return failure();
300
301 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
302 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
303 op.getL2HintAttr(), op.getL3HintAttr());
304
305 rewriter.eraseOp(op);
306 return success();
307 }
308};
309
310// This pattern transforms the LoadNdOp with explicit offsets to load
311// subgroup data.
312struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
313 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
314 LogicalResult
315 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter) const override {
317
318 SmallVector<SmallVector<OpFoldResult>> offsetsList;
319 if (failed(genOffsetsList(rewriter, op, offsetsList)))
320 return failure();
321
322 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
323 if (layout)
324 layout = layout.dropSgLayoutAndData();
325 SmallVector<Value> newOps;
326 for (auto [tdesc, offsets] :
327 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
328 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
329 VectorType newResTy =
330 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
331 auto newOp = xegpu::LoadNdOp::create(
332 rewriter, op.getLoc(), newResTy, tdesc, offsets,
333 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
334 op.getL2HintAttr(), op.getL3HintAttr(), layout);
335 newOps.push_back(newOp);
336 }
337 rewriter.replaceOpWithMultiple(op, {newOps});
338
339 return success();
340 }
341};
342
343// This pattern transforms the StoreNdOp with explicit offsets to store
344// subgroup data.
345struct WgToSgStoreNdOpWithOffset
346 : public OpConversionPattern<xegpu::StoreNdOp> {
347 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
348 LogicalResult
349 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351 SmallVector<SmallVector<OpFoldResult>> offsetsList;
352 if (failed(genOffsetsList(rewriter, op, offsetsList)))
353 return failure();
354
355 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
356 if (layout)
357 layout = layout.dropSgLayoutAndData();
358 for (auto [v, tdesc, offsets] :
359 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
360 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
361 op.getL1HintAttr(), op.getL2HintAttr(),
362 op.getL3HintAttr(), layout);
363 }
364 rewriter.eraseOp(op);
365
366 return success();
367 }
368};
369
370// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
371// subgroup data.
372struct WgToSgPrefetchNdOpWithOffset
373 : public OpConversionPattern<xegpu::PrefetchNdOp> {
374 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
375 LogicalResult
376 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
377 ConversionPatternRewriter &rewriter) const override {
378 SmallVector<SmallVector<OpFoldResult>> offsetsList;
379 if (failed(genOffsetsList(rewriter, op, offsetsList)))
380 return failure();
381
382 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
383 if (layout)
384 layout = layout.dropSgLayoutAndData();
385 for (auto [tdesc, offsets] :
386 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
387 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
388 op.getL1HintAttr(), op.getL2HintAttr(),
389 op.getL3HintAttr(), layout);
390 }
391 rewriter.eraseOp(op);
392
393 return success();
394 }
395};
396
397/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
398/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
399/// offsets of the new subgroup src tensor descriptors.
400struct WgToSgUpdateNdOffsetOp
401 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
402 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
403 LogicalResult
404 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406 llvm::SmallVector<Value> newUpdateTileOffsetOps;
407 for (auto tDesc : adaptor.getTensorDesc()) {
408 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
409 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
410 op.getConstOffsets());
411 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
412 }
413
414 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
415 return success();
416 }
417};
418
419/// This pattern transforms the DpasOp to work at subgroup level.
420struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
421 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
422 LogicalResult
423 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter) const override {
425 Location loc = op.getLoc();
426 VectorType resultTy = op.getResult().getType();
427 if (resultTy.getRank() != 2)
428 return failure();
429
430 auto layoutCd = op.getLayoutCdAttr();
431 auto layoutA = op.getLayoutAAttr();
432 auto layoutB = op.getLayoutBAttr();
433 if (!layoutCd || !layoutA || !layoutB)
434 return failure();
435 size_t i = 0;
436 SmallVector<Value> newDpasOps;
437 for (auto aVec : adaptor.getLhs()) {
438 for (auto bVec : adaptor.getRhs()) {
439
440 llvm::SmallVector<Value> operands({aVec, bVec});
441 Value tmpC;
442 if (op.getAcc()) {
443 tmpC = adaptor.getAcc()[i++];
444 operands.push_back(tmpC);
445 }
446
447 ArrayRef<int64_t> aVecShape =
448 llvm::cast<VectorType>(aVec.getType()).getShape();
449 ArrayRef<int64_t> bVecShape =
450 llvm::cast<VectorType>(bVec.getType()).getShape();
451 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
452 resultTy.getElementType());
453 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
454 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
455 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
456 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
457
458 newDpasOps.push_back(newDpasOp);
459 }
460 }
461 rewriter.replaceOpWithMultiple(op, {newDpasOps});
462 return success();
463 }
464};
465
466/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
467struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
468 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
469 LogicalResult
470 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
471 ConversionPatternRewriter &rewriter) const override {
472
473 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
474 if ((offsetSize != 0) || op.getConstOffsetsAttr())
475 return failure();
476
477 for (auto src : adaptor.getTensorDesc())
478 xegpu::PrefetchNdOp::create(
479 rewriter, op.getLoc(), TypeRange(), src,
480 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
481 rewriter.eraseOp(op);
482 return success();
483 }
484};
485
486/// This pattern transforms vector.broadcast ops to work at subgroup level.
487struct WgToSgVectorBroadcastOp
488 : public OpConversionPattern<vector::BroadcastOp> {
489 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
490
491 LogicalResult
492 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
493 ConversionPatternRewriter &rewriter) const override {
494
495 VectorType resultType = op.getResult().getType();
496 ArrayRef<int64_t> wgShape = resultType.getShape();
497
498 xegpu::DistributeLayoutAttr layout =
499 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
500 if (!layout || !layout.isForWorkgroup())
501 return failure();
502
503 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
504 VectorType newResultType =
505 VectorType::get(sgShape, resultType.getElementType());
506
507 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
508 return failure();
509
510 SmallVector<Value> newBroadcastOps;
511 for (auto operand : adaptor.getOperands().front()) {
512 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
513 newResultType, operand);
514
515 newBroadcastOps.push_back(newBroadcast.getResult());
516 }
517 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
518 return success();
519 }
520};
// This pattern transforms elementwise ops to work at subgroup level:
// the op is re-created once per distributed operand variant, with the
// result type shrunk to the per-subgroup shape.
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with elementwise trait and single result.
    // NOTE(review): the `if (...)` guard line itself is missing from this
    // excerpt (only its `return failure();` body is visible) — verify
    // against upstream.
      return failure();
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");
    ArrayRef<int64_t> wgShape = resultType.getShape();
    // Distribution only applies to results annotated with a workgroup-level
    // layout.
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // Every 1:N-converted operand must provide the same number of variants;
    // the i-th new op takes the i-th variant of each operand.
    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();
    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      // Re-create the same op generically (name + attrs preserved) with the
      // subgroup-level operands and result type.
      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      state.addAttributes(op->getAttrs());
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
575
576// clang-format off
577// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
578// If input_layout and target_layout have identical sg_layout and sg_data,
579// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
580// dropped. For example:
581// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
582// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
583// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
584// becomes:
585// #a = #xegpu.layout<inst_data = [16, 16]>
586// #b = #xegpu.layout<inst_data = [8, 16]>
587// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
588// (vector<16x16xf32> is determined by sg_data = [16, 16])
589//
590// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
591// For example:
592// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
593// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
594// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
595// is lowered to:
596// #a = #xegpu.layout<inst_data = [16, 16]>
597// #b = #xegpu.layout<inst_data = [8, 16]>
598// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
599// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
600// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
601// clang-format on
// Lowers ConvertLayoutOp: when input/target layouts agree on subgroup
// distribution the op is rewritten in place with sg fields dropped;
// otherwise data is staged through SLM (store + barrier + load) to
// redistribute it across subgroups. See the block comment above.
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();
    auto inputLayout = op.getInputLayout();
    auto targetLayout = op.getTargetLayout();

    if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
        !targetLayout.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    SmallVector<int64_t> inputSgLayout =
        inputLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
    SmallVector<int64_t> targetSgLayout =
        targetLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();

    // Fast path: if sg_layout and sg_data are identical, no SLM needed
    // NOTE(review): the second argument of isCompatibleWith (and the closing
    // of this condition) is missing from this excerpt — verify upstream.
    if (inputLayout.isCompatibleWith(targetLayout,
      inputLayout = inputLayout.dropSgLayoutAndData();
      targetLayout = targetLayout.dropSgLayoutAndData();

      // If anything remains of the layouts (e.g. inst_data), keep a
      // subgroup-level convert_layout; otherwise forward the sources as-is.
      SmallVector<Value> newOps(adaptor.getSource());
      if (inputLayout && targetLayout) {
        for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
          auto newOp = xegpu::ConvertLayoutOp::create(
              rewriter, loc, src.getType(), src, inputLayout, targetLayout);
          newOps[i] = newOp;
        }
      }
      rewriter.replaceOpWithMultiple(op, {newOps});
      return success();
    }

    // SLM path: layouts differ, need cross-subgroup data redistribution
    Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();

    SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);

    // Calculate SLM size requirements
    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    auto slmSize = computeProduct(slmShape) * bytesPerElement;

    // Allocate SLM (memory space 3 = shared local memory).
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
                                               elemTy, nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(), nullptr);

    // STORE PHASE: Each subgroup stores in SLM using input layout
    auto storeCoords = inputLayout.computeDistributedCoords(
        rewriter, loc, sgId.getResult(), wgShape);
    if (failed(storeCoords))
      return failure();

    // Store to SLM
    for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
      SmallVector<OpFoldResult> storeMatrixOffsets;
      for (Value coord : coords) {
        storeMatrixOffsets.push_back(coord);
      }
      xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
                                   storeMatrixOffsets, nullptr /*layout*/);
    }

    // All subgroups must finish writing before any reads back.
    gpu::BarrierOp::create(rewriter, loc);

    // LOAD PHASE: Each target subgroup loads from SLM using target layout
    auto loadCoords = targetLayout.computeDistributedCoords(
        rewriter, loc, sgId.getResult(), wgShape);
    if (failed(loadCoords))
      return failure();

    VectorType loadType = VectorType::get(targetSgData, elemTy);

    // Load vectors from SLM
    SmallVector<Value> finalResults;
    for (auto coords : *loadCoords) {
      SmallVector<OpFoldResult> loadMatrixOffsets;
      for (Value coord : coords) {
        loadMatrixOffsets.push_back(coord);
      }
      auto loadOp = xegpu::LoadMatrixOp::create(
          rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
          targetLayout.dropSgLayoutAndData());

      finalResults.push_back(loadOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {finalResults});
    return success();
  }
};
712
713// Handles UnrealizedConversionCastOp generated during
714// SCFStructuralTypeConversions (step 1). This op may appear as either a
715// target or source materialization for Vector values, e.g.:
716// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
717// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
718// it could be either 1:N or N:1 cast. In both cases, the pattern
719// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
720// for example, the following scf::forOp
721// ```
722// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
723// %n = use(%arg1): vector<128x128xf16>
724// scf.yield %n : vector<128x128xf16>
725// }
726// ```
727// Could be converted to:
728// ```
729// %1 = unrealized_conversion_cast %0
730// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
731// %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2)
732// -> (vector<16x16xf16>, vector<16x16xf16) {
733// %m = unrealized_conversion_cast %arg1, %arg2
734// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
735// %n = use(%m): vector<128x128xf16>
736// %b = unrealized_conversion_cast %n
737// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
738// scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
739// }
740// %cast = unrealized_conversion_cast %for:2
741// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
742// ```
743// TODO: remove it when context-aware type converter is ready.
744struct UnrealizedConversionCastOpPattern
745 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
746 using OpConversionPattern<
747 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
748
749 mlir::LogicalResult
750 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
751 ConversionPatternRewriter &rewriter) const override {
752 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
753
754 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
755 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
756
757 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
758 !llvm::all_equal(ValueRange(inputs).getTypes()))
759 return failure();
760
761 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
762 // It is generated by source materialization (e.g., inits to scf forOp).
763 // The input values provided by the adaptor should already be distributed,
764 // and their types should correspond exactly to the result types of the
765 // operation.
766 if (op.getNumOperands() == 1 &&
767 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
768 rewriter.replaceOp(op, inputs);
769 return success();
770 }
771
772 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
773 // It is generated by target materialization (e.g., arguments/results
774 // of scf forOp). All input values must have the same vector type, and
775 // their shape must be evenly divisible by the output vector's shape
776 // (determined by the nature of the workgroup to subgroup distribution).
777 // TODO: it is not safe to do such forward, since such N:1 cast could be
778 // from others.
779 if (op.getNumResults() == 1 &&
780 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
781 rewriter.replaceOpWithMultiple(op, {inputs});
782 return success();
783 }
784
785 return mlir::failure();
786 }
787};
788
// This pattern distributes arith.constant op into subgroup-level constants.
// Splat constants are simply re-typed to the subgroup shape; non-splat index
// constants with affine (constant-stride) values are rebuilt per subgroup as
// "base tile + broadcast offset".
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only dense vector constants participate in distribution; scalar or
    // non-dense constants are left untouched.
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    // The result must carry a workgroup-level layout to be distributable.
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // Derive the per-subgroup shape from the workgroup shape and layout.
    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    if (vecAttr.isSplat()) {
      // Splat: single value for all subgroups
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) { // if the entire vector is shared by all
                                     // subgroups, don't distribute
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      rewriter.replaceOp(op, newConstOp);
      return success();
    } else {
      // Non-splat constant
      // Only supports 1D & 2D
      // TODO: support other cases that require SLM access
      if (!eltType.isIndex())
        return rewriter.notifyMatchFailure(
            op, "Unsupported element type for non-splat constant op.");

      if (wgShape.size() > 2)
        return rewriter.notifyMatchFailure(
            op, "Only 1D & 2D vector constant supported");

      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
      int64_t rowStride = 0, colStride = 0;
      // A 1-D constant is treated as a single row.
      int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
      int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

      // Compute colStride and rowStride, and check for constant strides.
      if (cols > 1) {
        colStride = cast<IntegerAttr>(values[1]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }
      if (rows > 1) {
        rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }

      // Verify the constant forms an affine grid: every adjacent pair along
      // a row (resp. column) must differ by exactly colStride (resp.
      // rowStride), otherwise per-subgroup reconstruction by offset is
      // impossible.
      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
          int64_t idx = r * cols + c;
          // Check column stride
          if (c > 0 && cols > 1) {
            int64_t prevIdx = r * cols + (c - 1);
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != colStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant column stride in constant op.");
          }
          // Check row stride
          if (r > 0 && rows > 1) {
            int64_t prevIdx = (r - 1) * cols + c;
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != rowStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant row stride in constant op.");
          }
        }
      }

      // Create a constant for the base tile.
      // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
      // For 1D case, extract the first sgShape[0] elements.
      SmallVector<Attribute> baseTileValues;
      int baseTileCols = sgShape[sgShape.size() - 1];
      int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
      for (int64_t r = 0; r < baseTileRows; ++r) {
        for (int64_t c = 0; c < baseTileCols; ++c) {
          baseTileValues.push_back(values[r * cols + c]);
        }
      }

      auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
                                             baseTileValues);
      auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

      // Get subgroup id
      Value sgId =
          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
      auto sgOffsets =
          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
      if (failed(sgOffsets))
        return failure();

      // Strides ordered to match the offset coordinates produced above:
      // [row, col] for 2-D, [col] only for 1-D.
      SmallVector<Value, 2> strideConsts;
      strideConsts.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, colStride));
      if (rows > 1)
        strideConsts.insert(
            strideConsts.begin(),
            arith::ConstantIndexOp::create(rewriter, loc, rowStride));

      // One distributed constant per set of subgroup coordinates: base tile
      // plus the broadcast, linearized element offset of that subgroup.
      SmallVector<Value> newConstOps;
      for (auto offsets : *sgOffsets) {
        // Multiply offset with stride, broadcast it and add to baseConstVec
        Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
        for (size_t i = 0; i < strideConsts.size(); ++i) {
          Value mul =
              arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                    offsets[i], strideConsts[i]);
          mulOffset = arith::AddIOp::create(
              rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
        }
        // Broadcast to baseConstVec size
        auto bcastOffset = vector::BroadcastOp::create(
            rewriter, loc, baseConstVec.getType(), mulOffset);
        auto finalConst =
            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
        newConstOps.push_back(finalConst);
      }
      rewriter.replaceOpWithMultiple(op, {newConstOps});
      return success();
    }
  }
};
934
935// This pattern transforms the LoadGatherOp with explicit offsets to load
936// subgroup data
937struct WgToSgLoadGatherOpWithOffset
938 : public OpConversionPattern<xegpu::LoadGatherOp> {
939 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
940 LogicalResult
941 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter) const override {
943
944 if (!op.getOffsets())
945 return failure();
946
947 Location loc = op.getLoc();
948 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
949 if (!resultType)
950 return failure();
951 ArrayRef<int64_t> wgShape = resultType.getShape();
952
953 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
954
955 if (!layout || !layout.isForWorkgroup())
956 return failure();
957
958 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
959
960 // The offsets need to be distributed
961 auto offsetsVecType =
962 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
963 auto maskVecType =
964 dyn_cast<VectorType>(adaptor.getMask().front().getType());
965 if (!offsetsVecType || !maskVecType ||
966 offsetsVecType.getShape() != maskVecType.getShape()) {
967 return rewriter.notifyMatchFailure(op,
968 "offsets have not been distributed");
969 }
970
971 SmallVector<Value> newLoadOps;
972 auto chunkSizeAttr =
973 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
974 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
975 for (auto [offsets, mask] :
976 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
977 auto newLayout = layout.dropSgLayoutAndData();
978 auto newLoadOp = xegpu::LoadGatherOp::create(
979 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
980 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
981 newLayout);
982 newLoadOps.push_back(newLoadOp);
983 }
984 rewriter.replaceOpWithMultiple(op, {newLoadOps});
985 return success();
986 }
987};
988
989// This pattern transforms the StoreScatterOp with explicit offsets to store
990// subgroup data
991struct WgToSgStoreScatterOpWithOffset
992 : public OpConversionPattern<xegpu::StoreScatterOp> {
993 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
994 LogicalResult
995 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
996 ConversionPatternRewriter &rewriter) const override {
997
998 if (!op.getOffsets())
999 return failure();
1000
1001 Location loc = op.getLoc();
1002 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1003 if (!valueType)
1004 return failure();
1005
1006 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1007
1008 if (!layout || !layout.isForWorkgroup())
1009 return failure();
1010
1011 // The offsets need to be distributed
1012 auto offsetsVecType =
1013 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1014 auto maskVecType =
1015 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1016 if (!offsetsVecType || !maskVecType ||
1017 offsetsVecType.getShape() != maskVecType.getShape()) {
1018 return rewriter.notifyMatchFailure(op,
1019 "offsets have not been distributed");
1020 }
1021
1022 auto chunkSizeOpt = op.getChunkSize();
1023 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
1024 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1025 for (auto [val, offs, mask] : llvm::zip(
1026 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1027 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1028 mask, chunkSizeAttr, op.getL1HintAttr(),
1029 op.getL2HintAttr(), op.getL3HintAttr(),
1030 layout.dropSgLayoutAndData());
1031 }
1032 rewriter.eraseOp(op);
1033 return success();
1034 }
1035};
1036
1037struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
1038 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1039 LogicalResult
1040 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1041 ConversionPatternRewriter &rewriter) const override {
1042
1043 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1044 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1045 return failure();
1046
1047 ArrayRef<int64_t> wgShape = op.getDataShape();
1048 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1049 assert(valueTy && "the value type must be vector type!");
1050 Type elemTy = valueTy.getElementType();
1051
1052 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1053 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1054 VectorType newResTy = VectorType::get(sgShape, elemTy);
1055 SmallVector<Value> newOps;
1056 for (auto offsets : offsetsList) {
1057 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1058 op.getMemDesc(), offsets,
1059 layout.dropSgLayoutAndData());
1060 newOps.push_back(newOp);
1061 }
1062 rewriter.replaceOpWithMultiple(op, {newOps});
1063
1064 return success();
1065 }
1066};
1067
1068struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1069 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1070 LogicalResult
1071 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter) const override {
1073
1074 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1075 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1076 return failure();
1077
1078 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1079 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1080 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1081 offsets, layout.dropSgLayoutAndData());
1082 rewriter.eraseOp(op);
1083 return success();
1084 }
1085};
1086
1087// This pattern distributes the vector.step ops to work at subgroup level
1088struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1089 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1090 LogicalResult
1091 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter) const override {
1093 xegpu::DistributeLayoutAttr layout =
1094 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1095 if (!layout || !layout.isForWorkgroup())
1096 return failure();
1097
1098 Location loc = op.getLoc();
1099 VectorType type = op.getResult().getType();
1100 auto wgShape = type.getShape();
1101 std::optional<SmallVector<int64_t>> sgShape =
1102 getSgShapeAndCount(wgShape, layout).first;
1103 if (!sgShape)
1104 return failure();
1105
1106 Value sgId =
1107 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1108 auto sgOffsets =
1109 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1110 if (failed(sgOffsets))
1111 return failure();
1112
1113 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1114 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1115 SmallVector<Value> newOps;
1116 for (auto offsets : *sgOffsets) {
1117 // Broadcast the offset scalar to a vector & add to the base steps
1118 auto bcastOffset =
1119 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1120 auto finalSteps =
1121 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1122 newOps.push_back(finalSteps);
1123 }
1124
1125 rewriter.replaceOpWithMultiple(op, {newOps});
1126 return success();
1127 }
1128};
1129
1130// This pattern transforms vector.shape_cast ops to work at subgroup level.
1131struct WgToSgVectorShapeCastOp
1132 : public OpConversionPattern<vector::ShapeCastOp> {
1133 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1134
1135 LogicalResult
1136 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1137 ConversionPatternRewriter &rewriter) const override {
1138
1139 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1140 if (!resultType)
1141 return failure();
1142
1143 ArrayRef<int64_t> wgShape = resultType.getShape();
1144 xegpu::DistributeLayoutAttr layout =
1145 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1146 if (!layout || !layout.isForWorkgroup())
1147 return failure();
1148
1149 // Check that srcShape and destShape, if they differ, only differ by
1150 // expand of unit dimensions.
1151 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1152 if (!srcType)
1153 return failure();
1154
1155 ArrayRef<int64_t> srcShape = srcType.getShape();
1156
1157 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1158 SmallVector<int64_t> expandedUnitDims;
1159 if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
1160 xegpu::DistributeLayoutAttr sourceLayout =
1161 xegpu::getTemporaryLayout(op->getOpOperand(0));
1162
1163 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1164 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1165 return isa<vector::BroadcastOp>(user);
1166 });
1167 };
1168
1169 if (!usedByBroadcastOp(op))
1170 return rewriter.notifyMatchFailure(
1171 op, "ShapeCast ops that expand unit dimensions and are used by "
1172 "non-broadcast operations are not supported.");
1173
1174 if (!sourceLayout.isSliceOf(layout))
1175 return rewriter.notifyMatchFailure(
1176 op, "The ShapeCast op only expands dimensions, the result layout "
1177 "must be a slice of the input layout, or vice versa.");
1178 layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
1179 layoutToDistribute =
1180 layoutToDistribute.setUnitDimLayout(expandedUnitDims);
1181 }
1182
1183 SmallVector<int64_t> sgShape =
1184 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1185 VectorType newResultType =
1186 VectorType::get(sgShape, resultType.getElementType());
1187
1188 SmallVector<Value> newShapeCastOps;
1189 for (auto src : adaptor.getSource()) {
1190 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1191 newResultType, src);
1192 newShapeCastOps.push_back(newShapeCast.getResult());
1193 }
1194
1195 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1196 return success();
1197 }
1198};
1199
1200static Value createAccumulator(ConversionPatternRewriter &rewriter,
1201 Location loc, VectorType type,
1202 vector::CombiningKind kind) {
1203 Type elemTy = type.getElementType();
1204
1205 switch (kind) {
1206 case vector::CombiningKind::ADD:
1207 case vector::CombiningKind::XOR:
1208 case vector::CombiningKind::OR:
1209 return arith::ConstantOp::create(
1210 rewriter, loc, type,
1211 DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
1212
1213 case vector::CombiningKind::MUL:
1214 case vector::CombiningKind::AND:
1215 return arith::ConstantOp::create(
1216 rewriter, loc, type,
1217 DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy)));
1218
1219 case vector::CombiningKind::MINSI:
1220 // Use max signed int value for signed integer min
1221 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1222 auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
1223 return arith::ConstantOp::create(
1224 rewriter, loc, type,
1226 rewriter.getIntegerAttr(elemTy, maxVal)));
1227 }
1228 return nullptr;
1229
1230 case vector::CombiningKind::MINUI:
1231 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1232 auto maxVal = APInt::getMaxValue(intTy.getWidth());
1233 return arith::ConstantOp::create(
1234 rewriter, loc, type,
1236 rewriter.getIntegerAttr(elemTy, maxVal)));
1237 }
1238 return nullptr;
1239
1240 case vector::CombiningKind::MAXSI:
1241 if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
1242 auto minVal = APInt::getSignedMinValue(intTy.getWidth());
1243 return arith::ConstantOp::create(
1244 rewriter, loc, type,
1246 rewriter.getIntegerAttr(elemTy, minVal)));
1247 }
1248 return nullptr;
1249
1250 case vector::CombiningKind::MAXUI:
1251 return arith::ConstantOp::create(
1252 rewriter, loc, type,
1253 DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
1254
1255 case vector::CombiningKind::MINNUMF:
1256 case vector::CombiningKind::MINIMUMF:
1257 // Use +infinity for float min operations
1258 if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
1259 auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
1260 return arith::ConstantOp::create(
1261 rewriter, loc, type,
1262 DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf)));
1263 }
1264 return nullptr;
1265
1266 case vector::CombiningKind::MAXNUMF:
1267 case vector::CombiningKind::MAXIMUMF:
1268 // Use -infinity for float max operations
1269 if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
1270 auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
1271 return arith::ConstantOp::create(
1272 rewriter, loc, type,
1273 DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf)));
1274 }
1275 return nullptr;
1276 }
1277 return nullptr;
1278}
1279
1280/// This pattern transforms vector.multi_dim_reduction operations from
1281/// workgroup-level to subgroup-level execution with support for multiple
1282/// reduction dimensions.
1283///
1284/// Steps include:
1285/// 1. LOCAL REDUCTION :
1286/// - Each subgroup performs local reduction on its data slice
1287/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1288/// phase
1289///
1290/// 2. CROSS-SUBGROUP :
1291/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1292/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1293/// - If not needed, adds original accumulator and returns local results
1294///
1295/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1296/// a) SLM Layout Design:
1297/// - Rows: subgroups participating in reduction (product of sg_layout in
1298/// reduction dims)
1299/// - Cols: total result elements across non-reduction dimensions
1300///
1301/// b) Store Phase:
1302/// - Each subgroup stores its local reduction result to SLM
1303/// - Row offset: linearized index of subgroup in reduction dimensions
1304/// - Col offset: linearized index of subgroup in non-reduction dimensions
1305///
1306/// c) Load and Final Reduction Phase:
1307/// - Each subgroup loads a column of data (all reduction participants for
1308/// its position)
1309/// - Performs final reduction along the loaded dimension
1310/// - Adds original accumulator to get final result
1311///
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto originalSrcShape = srcType.getShape();
    auto originalDstShape = dstType.getShape();
    int srcVecRank = originalSrcShape.size();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    // Get sg_layout and sg_data from the parent layout
    SmallVector<int64_t> sgLayout;
    SmallVector<int64_t> sgData;
    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
    } else
      return rewriter.notifyMatchFailure(
          op, "Reduction should have SliceAttr layout");

    Type elemTy = dstType.getElementType();

    // Step 1: perform local subgroup reductions with ZERO accumulator
    // (the combining kind's neutral element), so the original accumulator
    // is folded in exactly once at the end and never double-counted.
    SmallVector<Value> localReductions;
    SmallVector<int64_t> sgDstShape =
        getSgShapeAndCount(originalDstShape, layout).first;
    auto sgSrcs = adaptor.getSource();
    auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
    SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
                                    sgSrcType.getShape().end());

    VectorType newDstType = VectorType::get(sgDstShape, elemTy);
    for (auto sgSrc : sgSrcs) {
      // Create ZERO accumulator for local reduction
      auto neutralLocalAcc =
          createAccumulator(rewriter, loc, newDstType, op.getKind());
      // Local reduction with ZERO accumulator
      auto localReduce = vector::MultiDimReductionOp::create(
          rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
          reductionDims);
      localReductions.push_back(localReduce.getResult());
    }

    // Check if cross-subgroup reduction is needed for any reduction dimension
    SmallVector<int64_t> crossSgReductionDims;
    for (int64_t reductionDim : reductionDims) {
      // Cross-subgroup work is required only when multiple subgroups split
      // this dimension and each holds a strict subset of it.
      bool needsCrossSubgroupReduction =
          (sgLayout[reductionDim] > 1) &&
          (sgData[reductionDim] < originalSrcShape[reductionDim]);

      if (needsCrossSubgroupReduction) {
        crossSgReductionDims.push_back(reductionDim);
      }
    }

    // If no cross-subgroup reduction needed, add accumulator and return
    if (crossSgReductionDims.empty()) {
      SmallVector<Value> results;
      for (auto localResult : localReductions) {
        auto finalResult = vector::makeArithReduction(
            rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
        results.push_back(finalResult);
      }
      rewriter.replaceOpWithMultiple(op, {results});
      return success();
    }

    // Step 2: cross-subgroup reduction using SLM.
    // Each subgroup's partial result collapses the reduction dims to 1.
    auto slmStoreDataShape = sgSrcShape;
    for (int64_t dim : reductionDims)
      slmStoreDataShape[dim] = 1;
    VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
    Value slmStoreData = vector::ShapeCastOp::create(
        rewriter, loc, slmStoreDataType, localReductions[0]);

    SmallVector<int64_t> slmShape(originalSrcShape.begin(),
                                  originalSrcShape.end());
    // for reduction dimension, SLM stores partial results from each subgroup
    for (int64_t dim : reductionDims)
      slmShape[dim] = sgLayout[dim];

    // Allocate SLM
    // NOTE(review): assumes the element bit-width is a multiple of 8;
    // sub-byte element types would yield a truncated byte size — confirm.
    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    auto slmSize = computeProduct(slmShape) * bytesPerElement;
    // Memory space 3 is the GPU workgroup-shared (local) address space.
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
                                               elemTy, nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    // if localReductions have more than 1 result, not support
    // NOTE(review): this bail-out occurs after ops have been created; the
    // conversion rewriter rolls them back on match failure, but checking
    // before any creation would be cleaner.
    if (localReductions.size() > 1) {
      return rewriter.notifyMatchFailure(
          op,
          "Multiple local reductions not supported in current implementation.");
    }

    // Step 4: Store local results to SLM
    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(), nullptr);

    // Convert sgLayout to Values for delinearizeIndex
    SmallVector<Value> sgLayoutValues;
    for (int64_t dim : sgLayout)
      sgLayoutValues.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, dim));

    // Recover this subgroup's multi-dimensional coordinates from the linear
    // subgroup id.
    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
                                                sgLayoutValues);
    if (failed(sgIdsResult))
      return failure();
    SmallVector<Value> sgIds = *sgIdsResult;

    // Per-dimension SLM offsets: reduction dims use the given stride
    // (1 = one slot per subgroup when storing, 0 = read all partials when
    // loading); non-reduction dims step by the subgroup's data extent.
    auto getSlmOffsets = [&](int64_t reductionDimStride) {
      SmallVector<OpFoldResult> offsets;
      offsets.reserve(srcVecRank);
      for (int i = 0; i < srcVecRank; ++i) {
        Value dimVal = sgIds[i];
        int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
                                   ? reductionDimStride
                                   : sgSrcShape[i];
        Value strideVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
        Value offsetVal =
            arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
        offsets.push_back(offsetVal);
      }
      return offsets;
    };

    SmallVector<OpFoldResult> slmStoreOffsets =
        getSlmOffsets(/*reductionDimStride=*/1);

    xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
                                 memDesc.getResult(), slmStoreOffsets,
                                 /*layout=*/nullptr);

    // All partial results must be visible before any subgroup reads them.
    gpu::BarrierOp::create(rewriter, loc);

    // Step 5: Load from SLM for final reduction
    SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
    for (int64_t dim : reductionDims)
      slmLoadDataShape[dim] = slmShape[dim];

    SmallVector<OpFoldResult> slmLoadOffsets =
        getSlmOffsets(/*reductionDimStride=*/0);

    VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
    auto slmLoadOp = xegpu::LoadMatrixOp::create(
        rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
        /*layout=*/nullptr);

    // Step 6: Perform final reduction with ZERO accumulator
    auto neutralFinalAcc =
        createAccumulator(rewriter, loc, newDstType, op.getKind());

    auto finalReduce = vector::MultiDimReductionOp::create(
        rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
        neutralFinalAcc, reductionDims);

    // Step 7: Add the original accumulator at the end
    auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
                                                  finalReduce.getResult(),
                                                  adaptor.getAcc()[0]);

    rewriter.replaceOp(op, finalResult);
    return success();
  }
};
1499
1500// This pattern transforms vector.transpose ops to work at subgroup level.
1501struct WgToSgVectorTransposeOp
1502 : public OpConversionPattern<vector::TransposeOp> {
1503 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1504
1505 LogicalResult
1506 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1507 ConversionPatternRewriter &rewriter) const override {
1508 VectorType resultType = op.getResultVectorType();
1509
1510 ArrayRef<int64_t> wgShape = resultType.getShape();
1511 xegpu::DistributeLayoutAttr layout =
1512 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1513 if (!layout || !layout.isForWorkgroup())
1514 return failure();
1515 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1516 xegpu::DistributeLayoutAttr sourceLayout =
1517 xegpu::getDistributeLayoutAttr(op.getVector());
1518 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1519 return failure();
1520
1521 SmallVector<int64_t> sourceSgLayout =
1522 sourceLayout.getEffectiveSgLayoutAsInt();
1523 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1524
1525 ArrayRef<int64_t> permutation = op.getPermutation();
1526 size_t permutationSize = permutation.size();
1527 if (sourceSgLayout.size() != permutationSize ||
1528 resultSgLayout.size() != permutationSize) {
1529 return rewriter.notifyMatchFailure(
1530 op, "Layouts and permutation must have the same rank");
1531 }
1532
1533 // Check that sgLayout, sgData & order are properly transposed for source
1534 // and result
1535 if (!layout.isTransposeOf(sourceLayout, permutation,
1536 xegpu::LayoutKind::Subgroup))
1537 return rewriter.notifyMatchFailure(
1538 op, "Result layout is not a valid transpose of source layout "
1539 "according to permutation");
1540
1541 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1542 VectorType newResultType =
1543 VectorType::get(sgShape, resultType.getElementType());
1544
1545 SmallVector<Value> newTransposeOps;
1546 for (auto src : adaptor.getVector()) {
1547 auto newTranspose = vector::TransposeOp::create(
1548 rewriter, op.getLoc(), newResultType, src, permutation);
1549 newTransposeOps.push_back(newTranspose.getResult());
1550 }
1551 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1552 return success();
1553 }
1554};
1555
1556// Distribute vector mask ops to work at subgroup level.
1557template <typename MaskOpType>
1558struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1559 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1560
1561 LogicalResult matchAndRewrite(
1562 MaskOpType op,
1563 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1564 ConversionPatternRewriter &rewriter) const override {
1565 xegpu::DistributeLayoutAttr layout =
1566 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1567 if (!layout || !layout.isForWorkgroup())
1568 return failure();
1569
1570 Location loc = op.getLoc();
1571 VectorType type = op.getResult().getType();
1572 auto wgShape = type.getShape();
1573
1574 SmallVector<Value> wgMaskDimSizes;
1575 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1576 for (int64_t maskSize : op.getMaskDimSizes()) {
1577 wgMaskDimSizes.push_back(
1578 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1579 }
1580 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1581 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1582 }
1583
1584 Value sgId =
1585 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1586 auto sgOffsets =
1587 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1588 if (failed(sgOffsets))
1589 return failure();
1590
1591 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1592 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1593
1594 // In each dimension, each subgroup computes its local mask size as:
1595 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
1596 SmallVector<Value> newCreateMaskOps;
1597 for (auto offsetSet : *sgOffsets) {
1598 SmallVector<Value> maskOperands;
1599
1600 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1601 Value dimSizeVal =
1602 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1603 Value offset = offsetSet[i];
1604 Value adjustedMaskSize =
1605 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1606 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1607 Value nonNegative =
1608 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1609 Value sgMaskSize =
1610 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1611 maskOperands.push_back(sgMaskSize);
1612 }
1613
1614 auto newCreateMaskOp =
1615 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1616 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1617 }
1618
1619 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1620 return success();
1621 }
1622};
1623
1624using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1625using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1626} // namespace
1627
1628namespace mlir {
1629namespace xegpu {
1631 patterns
1632 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1633 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1634 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1635 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1636 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1637 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1638 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1639 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1640 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1641 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1642 patterns.getContext());
1643}
1644} // namespace xegpu
1645} // namespace mlir
1646
namespace {
// Pass driver for the workgroup-to-subgroup distribution: wires the
// conversion patterns defined above into a dialect conversion over the
// operation (see runOnOperation below).
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
} // namespace
1653
1654void XeGPUWgToSgDistributePass::runOnOperation() {
1655
1656 Operation *op = getOperation();
1658 signalPassFailure();
1659 return;
1660 }
1661
1662 // Track existing UnrealizedConversionCastOps
1663 SmallVector<Operation *> existingCastOps;
1664 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1665 existingCastOps.push_back(castOp.getOperation());
1666 });
1667
1668 {
1669 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1670 // VectorType operands. This first converts such operands to
1671 // RankedTensorType, propagates the layout attribute into the encoding
1672 // attribute, and finally converts the RankedTensorType to VectorType based
1673 // on the encoding.
1674
1675 TypeConverter converter;
1676 converter.addConversion([&](Type type) -> Type { return type; });
1677 converter.addConversion(
1678 [&](RankedTensorType type,
1679 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1680 // Only convert RankedTensorTypes that carry an XeGPU layout encoding.
1681 // Plain tensors (e.g. tensor<?xi32>) have no XeGPU encoding and must
1682 // not be converted: VectorType does not support dynamic dimensions.
1683 auto encoding =
1684 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
1685 if (!encoding)
1686 return std::nullopt;
1687
1688 Type elemTy = type.getElementType();
1689 ArrayRef<int64_t> shape = type.getShape();
1690
1691 int count;
1692 SmallVector<int64_t> subShape;
1693 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1694
1695 auto newTy = VectorType::get(subShape, elemTy);
1696 result.append(count, newTy);
1697 return success();
1698 });
1699
1701 converter);
1702 }
1703
1704 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1705 // as well as XeGPU, Arith, and Vector operations.
1706 MLIRContext *ctx = &getContext();
1707 RewritePatternSet patterns(ctx);
1708 ConversionTarget target(*ctx);
1709 TypeConverter converter;
1710 converter.addConversion([&](Type type) -> Type { return type; });
1711 converter.addConversion(
1712 [&](xegpu::TensorDescType type,
1713 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1714 Type elemTy = type.getElementType();
1715 ArrayRef<int64_t> shape = type.getShape();
1716
1717 int count;
1718 SmallVector<int64_t> subShape;
1719 xegpu::LayoutAttr layout = type.getLayoutAttr();
1720 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1721
1722 if (layout)
1723 layout = layout.dropSgLayoutAndData();
1724
1725 auto newTy = xegpu::TensorDescType::get(
1726 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1727 result.append(count, newTy);
1728 return success();
1729 });
1730
1731 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1732 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1733 return createOp.getType();
1734 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1735 return loadOp.getTensorDescType();
1736 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1737 return storeOp.getTensorDescType();
1738 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1739 return updateOp.getType();
1740 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1741 return prefetchOp.getTensorDescType();
1742 return xegpu::TensorDescType();
1743 };
1744
1745 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1746 return !layout || !layout.isForWorkgroup();
1747 };
1748
1749 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1750 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1751 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1752 auto tdescTy = getTensorDescType(op);
1753 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1754 return isLegal(layout);
1755 });
1756
1757 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1758 auto layout = op.getLayoutCdAttr();
1759 return isLegal(layout);
1760 });
1761
1762 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1763 [=](xegpu::LoadMatrixOp op) -> bool {
1764 return isLegal(op.getLayoutAttr());
1765 });
1766
1767 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1768 [=](xegpu::StoreMatrixOp op) -> bool {
1769 return isLegal(op.getLayoutAttr());
1770 });
1771
1772 target.addDynamicallyLegalOp<arith::ConstantOp>(
1773 [=](arith::ConstantOp op) -> bool {
1774 auto vecType = dyn_cast<VectorType>(op.getType());
1775 if (!vecType)
1776 return true;
1777
1778 auto layout =
1779 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1780 return isLegal(layout);
1781 });
1782
1783 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1784 vector::TransposeOp, vector::BroadcastOp,
1785 vector::MultiDimReductionOp,
1786 vector::ConstantMaskOp, vector::CreateMaskOp>(
1787 [=](Operation *op) -> bool {
1788 // Check for either a SliceAttr or LayoutAttr on the result.
1789 auto layout =
1790 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1791 return isLegal(layout);
1792 });
1793
1794 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1795 [=](xegpu::LoadGatherOp op) -> bool {
1796 auto layout = op.getLayoutAttr();
1797 return isLegal(layout);
1798 });
1799
1800 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1801 [=](xegpu::StoreScatterOp op) -> bool {
1802 auto layout = op.getLayoutAttr();
1803 return isLegal(layout);
1804 });
1805
1806 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1807 [=](xegpu::ConvertLayoutOp op) -> bool {
1808 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1809 });
1810
1811 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1812 [=](Operation *op) -> std::optional<bool> {
1813 // Only handle elementwise mappable ops
1815 return true;
1816
1817 VectorType resultType =
1818 dyn_cast<VectorType>(op->getResult(0).getType());
1819 if (!resultType)
1820 return true;
1821
1822 // Check if all operands are vectors of the same shape
1823 // TODO: Support other types.
1824 for (Value operand : op->getOperands()) {
1825 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1826 if (!operandType || operandType.getShape() != resultType.getShape()) {
1827 return true;
1828 }
1829 }
1830
1831 xegpu::DistributeLayoutAttr layout =
1833 return isLegal(layout);
1834 });
1835
1836 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1837 [=](UnrealizedConversionCastOp op) {
1838 return llvm::is_contained(existingCastOps, op.getOperation());
1839 });
1840
1841 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1842
1844 target);
1846 if (failed(
1847 applyPartialConversion(getOperation(), target, std::move(patterns))))
1848 return signalPassFailure();
1849
1850 // Remove layout attributes from SCF ops
1851 getOperation()->walk([](Operation *op) {
1852 if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
1853 return;
1854
1855 SmallVector<StringAttr> attrsToRemove;
1856 for (auto namedAttr : op->getDiscardableAttrs()) {
1857 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1858 attrsToRemove.push_back(namedAttr.getName());
1859 }
1860 for (auto attrName : attrsToRemove)
1861 op->removeDiscardableAttr(attrName);
1862 });
1863}
return success()
b getContext())
#define mul(a, b)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:563
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:515
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:501
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)