MLIR 23.0.0git
XeGPUWgToSgDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
25#include <optional>
26
27namespace mlir {
28namespace xegpu {
29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31} // namespace xegpu
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Retrieve the RangeAttr if it is specified.
39static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
40 Operation *parent = op->getParentOfType<scf::IfOp>();
41 while (parent) {
42 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->getAttr("sg_id_range")))
44 return attr;
45 parent = parent->getParentOfType<scf::IfOp>();
46 }
47 return {};
48}
49
// Compute the per-subgroup tile shape and the number of distribution units
// (tiles) assigned to each subgroup for a workgroup-level `shape` under
// `layout`. Returns {sgData, count} on success.
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  // NOTE(review): `sgShape` (used in the failure path below) is declared on a
  // source line not visible in this view — presumably initialized from
  // `shape`; confirm against the full source.
  auto distributedShape = layout.computeDistributedShape(
      SmallVector<int64_t>(shape.begin(), shape.end()));
  // If the layout cannot distribute this shape, fall back with count == 1.
  if (failed(distributedShape))
    return std::make_pair(sgShape, count);
  auto sgData = layout.getEffectiveSgDataAsInt();
  // Round-robin: number of sg-sized tiles each subgroup handles.
  count = computeProduct(distributedShape.value()) / computeProduct(sgData);
  return std::make_pair(sgData, count);
}
63
/// Utility helper for deriving a list of offsets for each sub-TensorDesc
/// or sub-MemDesc to be accessed by the current subgroup (sgId) based on the
/// associated distribute layout attribute, the shape, subgroup id and the
/// original offsets of the op. Fails (without error) for ops this lowering
/// does not apply to.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               // NOTE(review): the remainder of this signature (the
               // `offsetsList` output parameter used below) is on a source
               // line not visible in this view.
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // Not applicable to ops without offsets operands.
  if (origOffsets.empty())
    return failure();

  // Matrix ops carry the layout directly; Nd ops expose it via the
  // descriptor's layout attribute.
  xegpu::DistributeLayoutAttr layout;
  if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
                std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
    layout = op.getLayoutAttr();
  } else {
    layout = op.getDescLayoutAttr();
  }

  // Not applicable to ops without workgroup layout attributes.
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // Verify and adjust the sgId if a range specifier is present on an
  // enclosing scf.if.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    // Verify the RangeAttr against the layout attribute.
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    // Rebase the subgroup id to the start of the range if necessary.
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  // Compute the list of subgroup-relative offsets for sub-tensors or
  // sub-memory descriptors to be accessed, based on the layout information.
  ArrayRef<int64_t> wgShape = op.getDataShape();
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Compute the final global offsets for each accessed sub-tensor
  // or sub-memory descriptor.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    // NOTE(review): the declaration opening this statement (of `newOffsets`,
    // presumably combining `sgOffsets` with `origOffsets` right-aligned) is
    // on a source line not visible in this view.
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }

  // callback(offsetsList);
  return success();
}
134
135/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
136/// from a workgroup descriptor. It replaces the offsets and sizes with
137/// appropriate values for the subgroup.
138/// It uses round-robin assignment to distribute the work to the subgroups.
139/// Following create_nd_desc operation:,
140/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
141/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
142/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
143/// is converted to 9 subgroup level operations based on the sg_layout &
144/// sg_data:
145/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
146/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
147/// lane_data = [1, 1]>>
148///
149/// The sg_layout and sg_data attributes are dropped after the pass as they are
150/// no longer needed.
151///
152/// 24x24 matrix distribution example:
153/// sg_layout = [4, 4], sg_data = [2, 2]
154/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
155/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
156///
157/// +------------------------+
158/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
159/// |-----+-----+-----|
160/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
161/// |-----+-----+-----|
162/// | 8x8 | 8x8 | 8x8 |
163/// +------------------------+
164///
165/// Each 8x8 tile is further subdivided among subgroups:
166/// +------------------------+
167/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
168/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
169/// | 2x2 2x2 2x2 2x2 |
170/// | 2x2 2x2 2x2 2x2 |
171/// +------------------------+
172///
173/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
174/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
175
176/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
177/// pattern and all the other ops just follow.
178/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
179/// ops in the pass.
180struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
181 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
182
183 LogicalResult
184 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const override {
186 SmallVector<SmallVector<OpFoldResult>> offsetsList;
187 if (failed(genOffsetsList(rewriter, op, offsetsList)))
188 return failure();
189
190 MLIRContext *ctx = op.getContext();
191 xegpu::TensorDescType tdescTy = op.getType();
192 ArrayRef<int64_t> wgShape = tdescTy.getShape();
193 Type elemTy = tdescTy.getElementType();
194 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
195 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
196 auto newTdescTy =
197 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
198 layout.dropSgLayoutAndData());
199
200 SmallVector<Value> newOps;
201 for (auto offsets : offsetsList) {
202 auto newOp = xegpu::CreateNdDescOp::create(
203 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
204 op.getMixedSizes(), op.getMixedStrides());
205
206 newOps.push_back(newOp);
207 }
208 rewriter.replaceOpWithMultiple(op, {newOps});
209
210 return success();
211 }
212};
213
214// This pattern transforms the CreateNdDescOp without offsets to create a
215// subgroup descriptor from a workgroup descriptor
216struct WgToSgCreateNdOpNoOffset
217 : public OpConversionPattern<xegpu::CreateNdDescOp> {
218 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
219
220 LogicalResult
221 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223
224 // Check no offsets are specified.
225 if (!op.getMixedOffsets().empty())
226 return failure();
227
228 Location loc = op.getLoc();
229 MLIRContext *ctx = op.getContext();
230 xegpu::TensorDescType tdescTy = op.getType();
231 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
232 if (!layout || !layout.isForWorkgroup())
233 return failure();
234
235 Type elemTy = tdescTy.getElementType();
236 ArrayRef<int64_t> wgShape = tdescTy.getShape();
237
238 SmallVector<int64_t> sgShape;
239 int count;
240 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
241 xegpu::TensorDescType newTdescTy =
242 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
243 layout.dropSgLayoutAndData());
244
245 SmallVector<Value> newCreateNdOps(count);
246 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
247 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
248 op.getSource(), op.getMixedSizes(),
249 op.getMixedStrides());
250 });
251
252 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
253 return success();
254 }
255};
256
257/// This pattern transforms the LoadNdOp to load subgroup data.
258struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
259 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
260 LogicalResult
261 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override {
263 if (!op.getMixedOffsets().empty())
264 return failure();
265
266 SmallVector<Value> newLoadOps;
267 for (auto src : adaptor.getTensorDesc()) {
268 xegpu::TensorDescType tdescTy =
269 dyn_cast<xegpu::TensorDescType>(src.getType());
270 ArrayRef<int64_t> srcShape = tdescTy.getShape();
271 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
272 auto newLoadOp = xegpu::LoadNdOp::create(
273 rewriter, op.getLoc(), newResTy, src,
274 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
275 newLoadOps.push_back(newLoadOp);
276 }
277 rewriter.replaceOpWithMultiple(op, {newLoadOps});
278 return mlir::success();
279 }
280};
281
282/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
283/// It creates a StoreNdOp op to store the updated values to the new subgroup
284/// src tensor descriptors.
285struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
286 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
287 LogicalResult
288 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 if (!op.getMixedOffsets().empty())
291 return failure();
292
293 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
294 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
295 op.getL2HintAttr(), op.getL3HintAttr());
296
297 rewriter.eraseOp(op);
298 return success();
299 }
300};
301
302// This pattern transforms the LoadNdOp with explicit offsets to load
303// subgroup data.
304struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
305 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
306 LogicalResult
307 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter) const override {
309
310 SmallVector<SmallVector<OpFoldResult>> offsetsList;
311 if (failed(genOffsetsList(rewriter, op, offsetsList)))
312 return failure();
313
314 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
315 if (layout)
316 layout = layout.dropSgLayoutAndData();
317 SmallVector<Value> newOps;
318 for (auto [tdesc, offsets] :
319 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
320 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
321 VectorType newResTy =
322 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
323 auto newOp = xegpu::LoadNdOp::create(
324 rewriter, op.getLoc(), newResTy, tdesc, offsets,
325 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
326 op.getL2HintAttr(), op.getL3HintAttr(), layout);
327 newOps.push_back(newOp);
328 }
329 rewriter.replaceOpWithMultiple(op, {newOps});
330
331 return success();
332 }
333};
334
335// This pattern transforms the StoreNdOp with explicit offsets to store
336// subgroup data.
337struct WgToSgStoreNdOpWithOffset
338 : public OpConversionPattern<xegpu::StoreNdOp> {
339 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
340 LogicalResult
341 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter) const override {
343 SmallVector<SmallVector<OpFoldResult>> offsetsList;
344 if (failed(genOffsetsList(rewriter, op, offsetsList)))
345 return failure();
346
347 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
348 if (layout)
349 layout = layout.dropSgLayoutAndData();
350 for (auto [v, tdesc, offsets] :
351 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
352 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
353 op.getL1HintAttr(), op.getL2HintAttr(),
354 op.getL3HintAttr(), layout);
355 }
356 rewriter.eraseOp(op);
357
358 return success();
359 }
360};
361
362// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
363// subgroup data.
364struct WgToSgPrefetchNdOpWithOffset
365 : public OpConversionPattern<xegpu::PrefetchNdOp> {
366 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
367 LogicalResult
368 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter) const override {
370 SmallVector<SmallVector<OpFoldResult>> offsetsList;
371 if (failed(genOffsetsList(rewriter, op, offsetsList)))
372 return failure();
373
374 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
375 if (layout)
376 layout = layout.dropSgLayoutAndData();
377 for (auto [tdesc, offsets] :
378 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
379 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
380 op.getL1HintAttr(), op.getL2HintAttr(),
381 op.getL3HintAttr(), layout);
382 }
383 rewriter.eraseOp(op);
384
385 return success();
386 }
387};
388
389/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
390/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
391/// offsets of the new subgroup src tensor descriptors.
392struct WgToSgUpdateNdOffsetOp
393 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
394 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
395 LogicalResult
396 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
397 ConversionPatternRewriter &rewriter) const override {
398 llvm::SmallVector<Value> newUpdateTileOffsetOps;
399 for (auto tDesc : adaptor.getTensorDesc()) {
400 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
401 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
402 op.getConstOffsets());
403 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
404 }
405
406 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
407 return success();
408 }
409};
410
411/// This pattern transforms the DpasOp to work at subgroup level.
412struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
413 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
414 LogicalResult
415 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
416 ConversionPatternRewriter &rewriter) const override {
417 Location loc = op.getLoc();
418 VectorType resultTy = op.getResult().getType();
419 if (resultTy.getRank() != 2)
420 return failure();
421
422 auto layoutCd = op.getLayoutCdAttr();
423 auto layoutA = op.getLayoutAAttr();
424 auto layoutB = op.getLayoutBAttr();
425 if (!layoutCd || !layoutA || !layoutB)
426 return failure();
427 size_t i = 0;
428 SmallVector<Value> newDpasOps;
429 for (auto aVec : adaptor.getLhs()) {
430 for (auto bVec : adaptor.getRhs()) {
431
432 llvm::SmallVector<Value> operands({aVec, bVec});
433 Value tmpC;
434 if (op.getAcc()) {
435 tmpC = adaptor.getAcc()[i++];
436 operands.push_back(tmpC);
437 }
438
439 ArrayRef<int64_t> aVecShape =
440 llvm::cast<VectorType>(aVec.getType()).getShape();
441 ArrayRef<int64_t> bVecShape =
442 llvm::cast<VectorType>(bVec.getType()).getShape();
443 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
444 resultTy.getElementType());
445 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
446 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
447 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
448 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
449
450 newDpasOps.push_back(newDpasOp);
451 }
452 }
453 rewriter.replaceOpWithMultiple(op, {newDpasOps});
454 return success();
455 }
456};
457
458/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
459struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
460 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
461 LogicalResult
462 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
463 ConversionPatternRewriter &rewriter) const override {
464
465 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
466 if ((offsetSize != 0) || op.getConstOffsetsAttr())
467 return failure();
468
469 for (auto src : adaptor.getTensorDesc())
470 xegpu::PrefetchNdOp::create(
471 rewriter, op.getLoc(), TypeRange(), src,
472 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
473 rewriter.eraseOp(op);
474 return success();
475 }
476};
477
478/// This pattern transforms vector.broadcast ops to work at subgroup level.
479struct WgToSgVectorBroadcastOp
480 : public OpConversionPattern<vector::BroadcastOp> {
481 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486
487 VectorType resultType = op.getResult().getType();
488 ArrayRef<int64_t> wgShape = resultType.getShape();
489
490 xegpu::DistributeLayoutAttr layout =
491 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
492 if (!layout || !layout.isForWorkgroup())
493 return failure();
494
495 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
496 VectorType newResultType =
497 VectorType::get(sgShape, resultType.getElementType());
498
499 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
500 return failure();
501
502 SmallVector<Value> newBroadcastOps;
503 for (auto operand : adaptor.getOperands().front()) {
504 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
505 newResultType, operand);
506
507 newBroadcastOps.push_back(newBroadcast.getResult());
508 }
509 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
510 return success();
511 }
512};
513
// This pattern transforms elementwise ops to work at subgroup level: the
// op is re-created once per distributed operand tuple with the sg-shaped
// result type.
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with elementwise trait and single result.
    // NOTE(review): the condition guarding this early return is on a source
    // line not visible in this view (presumably checking the elementwise
    // trait and result count) — confirm against the full source.
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");

    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // Number of distributed variants; every operand must have been
    // distributed into the same number of values.
    size_t numVariants = operands.empty() ? 0 : operands.front().size();

    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    // Re-create the op once per variant, picking the i-th distributed value
    // of each operand and the subgroup-shaped result type.
    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      state.addAttributes(op->getAttrs());
      Operation *newOp = rewriter.create(state);
      // NOTE(review): a statement between the create and the push_back is on
      // a source line not visible in this view — confirm against the full
      // source.
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};
567
568// clang-format off
569// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
570// If input_layout and target_layout have identical sg_layout and sg_data,
571// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
572// dropped. For example:
573// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
574// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
575// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
576// becomes:
577// #a = #xegpu.layout<inst_data = [16, 16]>
578// #b = #xegpu.layout<inst_data = [8, 16]>
579// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
580// (vector<16x16xf32> is determined by sg_data = [16, 16])
581//
582// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
583// For example:
584// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
585// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
586// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
587// is lowered to:
588// #a = #xegpu.layout<inst_data = [16, 16]>
589// #b = #xegpu.layout<inst_data = [8, 16]>
590// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
591// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
592// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
593// clang-format on
594struct WgToSgConvertLayoutOp
595 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
596 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
597
598 LogicalResult
599 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
600 ConversionPatternRewriter &rewriter) const override {
601 Location loc = op.getLoc();
602 auto inputLayout = op.getInputLayout();
603 auto targetLayout = op.getTargetLayout();
604
605 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
606 !targetLayout.isForWorkgroup())
607 return rewriter.notifyMatchFailure(
608 op, "Input and target layouts must have subgroup layout");
609
610 Type resultType = op.getResult().getType();
611 if (resultType.isIntOrFloat()) {
612 rewriter.replaceOp(op, op.getSource());
613 assert(!inputLayout.dropSgLayoutAndData() &&
614 !targetLayout.dropSgLayoutAndData() &&
615 "unexpected layout attributes for scalar type");
616 return success();
617 }
618
619 ArrayRef<int64_t> wgShape = cast<VectorType>(resultType).getShape();
620 SmallVector<int64_t> inputSgLayout =
621 inputLayout.getEffectiveSgLayoutAsInt();
622 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
623 SmallVector<int64_t> targetSgLayout =
624 targetLayout.getEffectiveSgLayoutAsInt();
625 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
626
627 // Fast path: if sg_layout and sg_data are identical, no SLM needed
628 SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
629 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
630 xegpu::LayoutKind::Subgroup)) {
631 inputLayout = inputLayout.dropSgLayoutAndData();
632 targetLayout = targetLayout.dropSgLayoutAndData();
633
634 SmallVector<Value> newOps(adaptor.getSource());
635 if (inputLayout && targetLayout) {
636 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
637 auto newOp = xegpu::ConvertLayoutOp::create(
638 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
639 newOps[i] = newOp;
640 }
641 }
642 rewriter.replaceOpWithMultiple(op, {newOps});
643 return success();
644 }
645
646 // SLM path: layouts differ, need cross-subgroup data redistribution
647 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
648
649 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
650
651 // Calculate SLM size requirements
652 auto bitWidth = elemTy.getIntOrFloatBitWidth();
653 auto bytesPerElement = bitWidth / 8;
654 auto slmSize = computeProduct(slmShape) * bytesPerElement;
655
656 // Allocate SLM
657 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
658 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
659
660 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
661 elemTy, nullptr);
662 auto memDesc =
663 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
664
665 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
666 rewriter.getIndexType(), nullptr);
667
668 // STORE PHASE: Each subgroup stores in SLM using input layout
669 auto storeCoords = inputLayout.computeDistributedCoords(
670 rewriter, loc, sgId.getResult(), wgShape);
671 if (failed(storeCoords))
672 return failure();
673
674 // Store to SLM
675 for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
676 SmallVector<OpFoldResult> storeMatrixOffsets;
677 for (Value coord : coords) {
678 storeMatrixOffsets.push_back(coord);
679 }
680 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
681 storeMatrixOffsets, nullptr /*layout*/);
682 }
683
684 gpu::BarrierOp::create(rewriter, loc);
685
686 // LOAD PHASE: Each target subgroup loads from SLM using target layout
687 auto loadCoords = targetLayout.computeDistributedCoords(
688 rewriter, loc, sgId.getResult(), wgShape);
689 if (failed(loadCoords))
690 return failure();
691
692 VectorType loadType = VectorType::get(targetSgData, elemTy);
693
694 // Load vectors from SLM
695 SmallVector<Value> finalResults;
696 for (auto coords : *loadCoords) {
697 SmallVector<OpFoldResult> loadMatrixOffsets;
698 for (Value coord : coords) {
699 loadMatrixOffsets.push_back(coord);
700 }
701 auto loadOp = xegpu::LoadMatrixOp::create(
702 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
703 targetLayout.dropSgLayoutAndData());
704
705 finalResults.push_back(loadOp.getResult());
706 }
707
708 rewriter.replaceOpWithMultiple(op, {finalResults});
709 return success();
710 }
711};
712
713// Handles UnrealizedConversionCastOp generated during
714// SCFStructuralTypeConversions (step 1). This op may appear as either a
715// target or source materialization for Vector values, e.g.:
716// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
717// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
718// it could be either 1:N or N:1 cast. In both cases, the pattern
719// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
720// for example, the following scf::forOp
721// ```
722// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
723// %n = use(%arg1): vector<128x128xf16>
724// scf.yield %n : vector<128x128xf16>
725// }
726// ```
727// Could be converted to:
728// ```
729// %1 = unrealized_conversion_cast %0
730// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
731// %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2)
732// -> (vector<16x16xf16>, vector<16x16xf16) {
733// %m = unrealized_conversion_cast %arg1, %arg2
734// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
735// %n = use(%m): vector<128x128xf16>
736// %b = unrealized_conversion_cast %n
737// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
738// scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
739// }
740// %cast = unrealized_conversion_cast %for:2
741// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
742// ```
743// TODO: remove it when context-aware type converter is ready.
744struct UnrealizedConversionCastOpPattern
745 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
746 using OpConversionPattern<
747 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
748
749 mlir::LogicalResult
750 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
751 ConversionPatternRewriter &rewriter) const override {
752 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
753
754 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
755 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
756
757 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
758 !llvm::all_equal(ValueRange(inputs).getTypes()))
759 return failure();
760
761 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
762 // It is generated by source materialization (e.g., inits to scf forOp).
763 // The input values provided by the adaptor should already be distributed,
764 // and their types should correspond exactly to the result types of the
765 // operation.
766 if (op.getNumOperands() == 1 &&
767 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
768 rewriter.replaceOp(op, inputs);
769 return success();
770 }
771
772 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
773 // It is generated by target materialization (e.g., arguments/results
774 // of scf forOp). All input values must have the same vector type, and
775 // their shape must be evenly divisible by the output vector's shape
776 // (determined by the nature of the workgroup to subgroup distribution).
777 // TODO: it is not safe to do such forward, since such N:1 cast could be
778 // from others.
779 if (op.getNumResults() == 1 &&
780 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
781 rewriter.replaceOpWithMultiple(op, {inputs});
782 return success();
783 }
784
785 return mlir::failure();
786 }
787};
788
// This pattern distributes arith.constant op into subgroup-level constants
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only dense vector constants that carry a workgroup-level layout are
    // rewritten; anything else is left untouched for other patterns.
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // Derive the per-subgroup shape from the workgroup shape and layout.
    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    if (vecAttr.isSplat()) {
      // Splat: single value for all subgroups
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) { // if the entire vector is shared by all
                                     // subgroups, don't distribute
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      rewriter.replaceOp(op, newConstOp);
      return success();
    } else {
      // Non-splat constant
      // Only supports 1D & 2D
      // TODO: support other cases that require SLM access
      if (!eltType.isIndex())
        return rewriter.notifyMatchFailure(
            op, "Unsupported element type for non-splat constant op.");

      if (wgShape.size() > 2)
        return rewriter.notifyMatchFailure(
            op, "Only 1D & 2D vector constant supported");

      // A 1D vector is treated as a single row so the 2D stride logic below
      // covers both ranks uniformly.
      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
      int64_t rowStride = 0, colStride = 0;
      int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
      int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

      // Compute colStride and rowStride, and check for constant strides.
      if (cols > 1) {
        colStride = cast<IntegerAttr>(values[1]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }
      if (rows > 1) {
        rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }

      // The distribution below is only valid for affine (constant-stride)
      // constants: verify every adjacent pair along each dimension matches
      // the stride computed above, and bail out otherwise.
      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
          int64_t idx = r * cols + c;
          // Check column stride
          if (c > 0 && cols > 1) {
            int64_t prevIdx = r * cols + (c - 1);
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != colStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant column stride in constant op.");
          }
          // Check row stride
          if (r > 0 && rows > 1) {
            int64_t prevIdx = (r - 1) * cols + c;
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != rowStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant row stride in constant op.");
          }
        }
      }

      // Create a constant for the base tile.
      // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
      // For 1D case, extract the first sgShape[0] elements.
      SmallVector<Attribute> baseTileValues;
      int baseTileCols = sgShape[sgShape.size() - 1];
      int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
      for (int64_t r = 0; r < baseTileRows; ++r) {
        for (int64_t c = 0; c < baseTileCols; ++c) {
          baseTileValues.push_back(values[r * cols + c]);
        }
      }

      auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
                                             baseTileValues);
      auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

      // Get subgroup id
      Value sgId =
          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
      auto sgOffsets =
          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
      if (failed(sgOffsets))
        return failure();

      // Strides ordered to match the offset coordinates produced above:
      // [rowStride, colStride] for 2D, [colStride] for 1D.
      SmallVector<Value, 2> strideConsts;
      strideConsts.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, colStride));
      if (rows > 1)
        strideConsts.insert(
            strideConsts.begin(),
            arith::ConstantIndexOp::create(rewriter, loc, rowStride));

      // Each subgroup's tile is the base tile shifted by the linearized
      // element offset (dot product of subgroup coords and strides).
      SmallVector<Value> newConstOps;
      for (auto offsets : *sgOffsets) {
        // Multiply offset with stride, broadcast it and add to baseConstVec
        Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
        for (size_t i = 0; i < strideConsts.size(); ++i) {
          Value mul =
              arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                    offsets[i], strideConsts[i]);
          mulOffset = arith::AddIOp::create(
              rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
        }
        // Broadcast to baseConstVec size
        auto bcastOffset = vector::BroadcastOp::create(
            rewriter, loc, baseConstVec.getType(), mulOffset);
        auto finalConst =
            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
        newConstOps.push_back(finalConst);
      }
      rewriter.replaceOpWithMultiple(op, {newConstOps});
      return success();
    }
  }
};
934
935// This pattern transforms the LoadGatherOp with explicit offsets to load
936// subgroup data
937struct WgToSgLoadGatherOpWithOffset
938 : public OpConversionPattern<xegpu::LoadGatherOp> {
939 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
940 LogicalResult
941 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter) const override {
943
944 if (!op.getOffsets())
945 return failure();
946
947 Location loc = op.getLoc();
948 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
949 if (!resultType)
950 return failure();
951 ArrayRef<int64_t> wgShape = resultType.getShape();
952
953 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
954
955 if (!layout || !layout.isForWorkgroup())
956 return failure();
957
958 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
959
960 // The offsets need to be distributed
961 auto offsetsVecType =
962 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
963 auto maskVecType =
964 dyn_cast<VectorType>(adaptor.getMask().front().getType());
965 if (!offsetsVecType || !maskVecType ||
966 offsetsVecType.getShape() != maskVecType.getShape()) {
967 return rewriter.notifyMatchFailure(op,
968 "offsets have not been distributed");
969 }
970
971 SmallVector<Value> newLoadOps;
972 auto chunkSizeAttr =
973 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
974 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
975 for (auto [offsets, mask] :
976 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
977 auto newLayout = layout.dropSgLayoutAndData();
978 auto newLoadOp = xegpu::LoadGatherOp::create(
979 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
980 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
981 newLayout);
982 newLoadOps.push_back(newLoadOp);
983 }
984 rewriter.replaceOpWithMultiple(op, {newLoadOps});
985 return success();
986 }
987};
988
989// This pattern transforms the StoreScatterOp with explicit offsets to store
990// subgroup data
991struct WgToSgStoreScatterOpWithOffset
992 : public OpConversionPattern<xegpu::StoreScatterOp> {
993 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
994 LogicalResult
995 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
996 ConversionPatternRewriter &rewriter) const override {
997
998 if (!op.getOffsets())
999 return failure();
1000
1001 Location loc = op.getLoc();
1002 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1003 if (!valueType)
1004 return failure();
1005
1006 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1007
1008 if (!layout || !layout.isForWorkgroup())
1009 return failure();
1010
1011 // The offsets need to be distributed
1012 auto offsetsVecType =
1013 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1014 auto maskVecType =
1015 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1016 if (!offsetsVecType || !maskVecType ||
1017 offsetsVecType.getShape() != maskVecType.getShape()) {
1018 return rewriter.notifyMatchFailure(op,
1019 "offsets have not been distributed");
1020 }
1021
1022 auto chunkSizeOpt = op.getChunkSize();
1023 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
1024 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1025 for (auto [val, offs, mask] : llvm::zip(
1026 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1027 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1028 mask, chunkSizeAttr, op.getL1HintAttr(),
1029 op.getL2HintAttr(), op.getL3HintAttr(),
1030 layout.dropSgLayoutAndData());
1031 }
1032 rewriter.eraseOp(op);
1033 return success();
1034 }
1035};
1036
1037struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
1038 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1039 LogicalResult
1040 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1041 ConversionPatternRewriter &rewriter) const override {
1042
1043 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1044 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1045 return failure();
1046
1047 ArrayRef<int64_t> wgShape = op.getDataShape();
1048 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1049 assert(valueTy && "the value type must be vector type!");
1050 Type elemTy = valueTy.getElementType();
1051
1052 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1053 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1054 VectorType newResTy = VectorType::get(sgShape, elemTy);
1055 SmallVector<Value> newOps;
1056 for (auto offsets : offsetsList) {
1057 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1058 op.getMemDesc(), offsets,
1059 layout.dropSgLayoutAndData());
1060 newOps.push_back(newOp);
1061 }
1062 rewriter.replaceOpWithMultiple(op, {newOps});
1063
1064 return success();
1065 }
1066};
1067
1068struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1069 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1070 LogicalResult
1071 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter) const override {
1073
1074 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1075 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1076 return failure();
1077
1078 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1079 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1080 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1081 offsets, layout.dropSgLayoutAndData());
1082 rewriter.eraseOp(op);
1083 return success();
1084 }
1085};
1086
1087// This pattern distributes the vector.step ops to work at subgroup level
1088struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1089 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1090 LogicalResult
1091 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter) const override {
1093 xegpu::DistributeLayoutAttr layout =
1094 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1095 if (!layout || !layout.isForWorkgroup())
1096 return failure();
1097
1098 Location loc = op.getLoc();
1099 VectorType type = op.getResult().getType();
1100 auto wgShape = type.getShape();
1101 std::optional<SmallVector<int64_t>> sgShape =
1102 getSgShapeAndCount(wgShape, layout).first;
1103 if (!sgShape)
1104 return failure();
1105
1106 Value sgId =
1107 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1108 auto sgOffsets =
1109 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1110 if (failed(sgOffsets))
1111 return failure();
1112
1113 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1114 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1115 SmallVector<Value> newOps;
1116 for (auto offsets : *sgOffsets) {
1117 // Broadcast the offset scalar to a vector & add to the base steps
1118 auto bcastOffset =
1119 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1120 auto finalSteps =
1121 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1122 newOps.push_back(finalSteps);
1123 }
1124
1125 rewriter.replaceOpWithMultiple(op, {newOps});
1126 return success();
1127 }
1128};
1129
// This pattern transforms vector.shape_cast ops to work at subgroup level.
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // Check that srcShape and destShape, if they differ, only differ by
    // expand of unit dimensions.
    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
    if (!srcType)
      return failure();

    ArrayRef<int64_t> srcShape = srcType.getShape();

    xegpu::DistributeLayoutAttr layoutToDistribute = layout;
    SmallVector<int64_t> expandedUnitDims;
    if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
      // NOTE(review): sourceLayout is dereferenced below (isSliceOf) without
      // a null check — presumably getTemporaryLayout cannot fail for an
      // operand here; confirm against its contract.
      xegpu::DistributeLayoutAttr sourceLayout =
          xegpu::getTemporaryLayout(op->getOpOperand(0));

      // Unit-dim expansion is only supported when every consumer is a
      // broadcast, which tolerates the trivially-distributed unit dims.
      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
          return isa<vector::BroadcastOp>(user);
        });
      };

      if (!usedByBroadcastOp(op))
        return rewriter.notifyMatchFailure(
            op, "ShapeCast ops that expand unit dimensions and are used by "
                "non-broadcast operations are not supported.");

      if (!sourceLayout.isSliceOf(layout))
        return rewriter.notifyMatchFailure(
            op, "The ShapeCast op only expands dimensions, the input layout "
                "must be a slice of the result layout.");

      assert(layoutToDistribute.isEqualTo(
                 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
             "The sg_data for unit dimensions should be set as 1");
    }

    // Emit one subgroup-level shape_cast per distributed source fragment.
    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgShape, layoutToDistribute).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
1200
1201/// This pattern transforms vector.multi_dim_reduction operations from
1202/// workgroup-level to subgroup-level execution with support for multiple
1203/// reduction dimensions.
1204///
1205/// Steps include:
1206/// 1. LOCAL REDUCTION :
1207/// - Each subgroup performs local reduction on its data slice
1208/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1209/// phase
1210///
1211/// 2. CROSS-SUBGROUP :
1212/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1213/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1214/// - If not needed, adds original accumulator and returns local results
1215///
1216/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1217/// a) SLM Layout Design:
1218/// - Rows: subgroups participating in reduction (product of sg_layout in
1219/// reduction dims)
1220/// - Cols: total result elements across non-reduction dimensions
1221///
1222/// b) Store Phase:
1223/// - Each subgroup stores its local reduction result to SLM
1224/// - Row offset: linearized index of subgroup in reduction dimensions
1225/// - Col offset: linearized index of subgroup in non-reduction dimensions
1226///
1227/// c) Load and Final Reduction Phase:
1228/// - Each subgroup loads a column of data (all reduction participants for
1229/// its position)
1230/// - Performs final reduction along the loaded dimension
1231/// - Adds original accumulator to get final result
1232///
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType srcType = op.getSourceVectorType();
    Type resultTy = op.getResult().getType();
    VectorType dstVecType = dyn_cast<VectorType>(resultTy);
    // Reducing over all dimensions yields a scalar instead of a vector.
    bool isScalarResult = !dstVecType;

    auto originalSrcShape = srcType.getShape();
    int srcVecRank = originalSrcShape.size();
    Type elemTy = srcType.getElementType();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    // Get sg_layout and sg_data from the parent layout
    SmallVector<int64_t> sgLayout;
    SmallVector<int64_t> sgData;
    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
    } else
      return rewriter.notifyMatchFailure(
          op, "Reduction should have SliceAttr layout");

    // Step 1: perform local subgroup reductions with neutral accumulator
    // (the real accumulator is folded in exactly once, at the end, to avoid
    // double-counting it across subgroups).
    SmallVector<Value> localReductions;
    auto sgSrcs = adaptor.getSource();
    auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
    SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
                                    sgSrcType.getShape().end());

    // Determine the SG-level destination type.
    // For scalar results (all dims reduced), the sg result is also scalar.
    // For vector results, compute the sg destination shape from layout.
    Type sgDstType;
    if (dstVecType) {
      auto originalDstShape = dstVecType.getShape();
      SmallVector<int64_t> sgDstShape =
          getSgShapeAndCount(originalDstShape, layout).first;
      sgDstType = VectorType::get(sgDstShape, elemTy);
    } else {
      sgDstType = elemTy;
    }

    for (auto sgSrc : sgSrcs) {
      // Create neutral accumulator for local reduction
      Value neutralLocalAcc = xegpu::createReductionNeutralValue(
          rewriter, loc, sgDstType, op.getKind());
      // Local reduction with neutral accumulator
      auto localReduce = vector::MultiDimReductionOp::create(
          rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
          reductionDims);
      localReductions.push_back(localReduce.getResult());
    }

    // Check if cross-subgroup reduction is needed for any reduction dimension
    SmallVector<int64_t> crossSgReductionDims;
    for (int64_t reductionDim : reductionDims) {
      // Needed only when several subgroups each hold a partial slice of the
      // reduction dimension.
      bool needsCrossSubgroupReduction =
          (sgLayout[reductionDim] > 1) &&
          (sgData[reductionDim] < originalSrcShape[reductionDim]);

      if (needsCrossSubgroupReduction) {
        crossSgReductionDims.push_back(reductionDim);
      }
    }

    // If no cross-subgroup reduction needed, add accumulator and return
    if (crossSgReductionDims.empty()) {
      SmallVector<Value> results;
      for (auto localResult : localReductions) {
        auto finalResult = vector::makeArithReduction(
            rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
        results.push_back(finalResult);
      }
      rewriter.replaceOpWithMultiple(op, {results});
      return success();
    }

    // Step 2: cross-subgroup reduction using SLM
    // Each subgroup contributes one element per reduction dimension, so its
    // SLM store tile has the reduced dims collapsed to 1.
    auto slmStoreDataShape = sgSrcShape;
    for (int64_t dim : reductionDims)
      slmStoreDataShape[dim] = 1;
    VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
    Value slmStoreData;
    if (isScalarResult) {
      // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
      slmStoreData = vector::BroadcastOp::create(
          rewriter, loc, slmStoreDataType, localReductions[0]);
    } else {
      slmStoreData = vector::ShapeCastOp::create(
          rewriter, loc, slmStoreDataType, localReductions[0]);
    }

    SmallVector<int64_t> slmShape(originalSrcShape.begin(),
                                  originalSrcShape.end());
    // for reduction dimension, SLM stores partial results from each subgroup
    for (int64_t dim : reductionDims)
      slmShape[dim] = sgLayout[dim];

    // Allocate SLM
    // NOTE(review): bitWidth / 8 assumes an element width of at least 8 bits
    // (it would compute 0 bytes for i1) — confirm callers never reach here
    // with sub-byte element types.
    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    auto slmSize = computeProduct(slmShape) * bytesPerElement;
    // Address space 3 backs the buffer with shared local memory.
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
                                               elemTy, nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    // if localReductions have more than 1 result, not support
    if (localReductions.size() > 1) {
      return rewriter.notifyMatchFailure(
          op,
          "Multiple local reductions not supported in current implementation.");
    }

    // Step 4: Store local results to SLM
    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(), nullptr);

    // Convert sgLayout to Values for delinearizeIndex
    SmallVector<Value> sgLayoutValues;
    for (int64_t dim : sgLayout)
      sgLayoutValues.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, dim));

    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
                                                sgLayoutValues);
    if (failed(sgIdsResult))
      return failure();
    SmallVector<Value> sgIds = *sgIdsResult;

    // Computes this subgroup's SLM offsets. For reduction dims the stride is
    // the caller-supplied value (1 for storing one partial per subgroup, 0
    // for loading the whole column); other dims stride by the subgroup tile.
    auto getSlmOffsets = [&](int64_t reductionDimStride) {
      SmallVector<OpFoldResult> offsets;
      offsets.reserve(srcVecRank);
      for (int i = 0; i < srcVecRank; ++i) {
        Value dimVal = sgIds[i];
        int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
                                   ? reductionDimStride
                                   : sgSrcShape[i];
        Value strideVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
        Value offsetVal =
            arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
        offsets.push_back(offsetVal);
      }
      return offsets;
    };

    SmallVector<OpFoldResult> slmStoreOffsets =
        getSlmOffsets(/*reductionDimStride=*/1);

    xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
                                 memDesc.getResult(), slmStoreOffsets,
                                 /*layout=*/nullptr);

    // All partial results must land in SLM before any subgroup reads them.
    gpu::BarrierOp::create(rewriter, loc);

    // Step 5: Load from SLM for final reduction
    SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
    for (int64_t dim : reductionDims)
      slmLoadDataShape[dim] = slmShape[dim];

    SmallVector<OpFoldResult> slmLoadOffsets =
        getSlmOffsets(/*reductionDimStride=*/0);

    VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
    auto slmLoadOp = xegpu::LoadMatrixOp::create(
        rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
        /*layout=*/nullptr);

    // Step 6: Perform final reduction with neutral accumulator
    Value neutralFinalAcc = xegpu::createReductionNeutralValue(
        rewriter, loc, sgDstType, op.getKind());

    auto finalReduce = vector::MultiDimReductionOp::create(
        rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
        neutralFinalAcc, reductionDims);

    // Step 7: Add the original accumulator at the end
    auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
                                                  finalReduce.getResult(),
                                                  adaptor.getAcc()[0]);

    rewriter.replaceOp(op, finalResult);
    return success();
  }
};
1435
1436// This pattern transforms vector.transpose ops to work at subgroup level.
1437struct WgToSgVectorTransposeOp
1438 : public OpConversionPattern<vector::TransposeOp> {
1439 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1440
1441 LogicalResult
1442 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1443 ConversionPatternRewriter &rewriter) const override {
1444 VectorType resultType = op.getResultVectorType();
1445
1446 ArrayRef<int64_t> wgShape = resultType.getShape();
1447 xegpu::DistributeLayoutAttr layout =
1448 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1449 if (!layout || !layout.isForWorkgroup())
1450 return failure();
1451 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1452 xegpu::DistributeLayoutAttr sourceLayout =
1453 xegpu::getDistributeLayoutAttr(op.getVector());
1454 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1455 return failure();
1456
1457 SmallVector<int64_t> sourceSgLayout =
1458 sourceLayout.getEffectiveSgLayoutAsInt();
1459 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1460
1461 ArrayRef<int64_t> permutation = op.getPermutation();
1462 size_t permutationSize = permutation.size();
1463 if (sourceSgLayout.size() != permutationSize ||
1464 resultSgLayout.size() != permutationSize) {
1465 return rewriter.notifyMatchFailure(
1466 op, "Layouts and permutation must have the same rank");
1467 }
1468
1469 // Check that sgLayout, sgData & order are properly transposed for source
1470 // and result
1471 if (!layout.isTransposeOf(sourceLayout, permutation,
1472 xegpu::LayoutKind::Subgroup))
1473 return rewriter.notifyMatchFailure(
1474 op, "Result layout is not a valid transpose of source layout "
1475 "according to permutation");
1476
1477 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1478 VectorType newResultType =
1479 VectorType::get(sgShape, resultType.getElementType());
1480
1481 SmallVector<Value> newTransposeOps;
1482 for (auto src : adaptor.getVector()) {
1483 auto newTranspose = vector::TransposeOp::create(
1484 rewriter, op.getLoc(), newResultType, src, permutation);
1485 newTransposeOps.push_back(newTranspose.getResult());
1486 }
1487 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1488 return success();
1489 }
1490};
1491
// Distribute vector mask ops to work at subgroup level.
template <typename MaskOpType>
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
  using OpConversionPattern<MaskOpType>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MaskOpType op,
      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    // Normalize both mask ops to a list of per-dimension size Values:
    // constant_mask carries static sizes that are materialized as constants,
    // create_mask already has index operands.
    SmallVector<Value> wgMaskDimSizes;
    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
      for (int64_t maskSize : op.getMaskDimSizes()) {
        wgMaskDimSizes.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
      }
    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
      wgMaskDimSizes = llvm::to_vector(op.getOperands());
    }

    // Compute each subgroup's start coordinate from its linear id.
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());

    // In each dimension, each subgroup computes its local mask size as:
    // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
    SmallVector<Value> newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector<Value> maskOperands;

      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        // Clamp to [0, sgShape[i]] so out-of-range rows/cols mask to empty.
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      // Both source ops lower to create_mask since the sizes are now dynamic.
      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};
1559
1560using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1561using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1562} // namespace
1563
1564namespace mlir {
1565namespace xegpu {
1567 patterns
1568 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1569 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1570 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1571 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1572 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1573 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1574 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1575 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1576 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1577 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1578 patterns.getContext());
1579}
1580} // namespace xegpu
1581} // namespace mlir
1582
1583namespace {
/// Pass driver that applies the workgroup-to-subgroup distribution patterns
/// defined above; the actual work happens in runOnOperation (defined below).
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
1588} // namespace
1589
1590void XeGPUWgToSgDistributePass::runOnOperation() {
1591
1592 Operation *op = getOperation();
1594 signalPassFailure();
1595 return;
1596 }
1597
1598 // Track existing UnrealizedConversionCastOps
1599 SmallVector<Operation *> existingCastOps;
1600 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1601 existingCastOps.push_back(castOp.getOperation());
1602 });
1603
1604 {
1605 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1606 // VectorType operands. This first converts such operands to
1607 // RankedTensorType, propagates the layout attribute into the encoding
1608 // attribute, and finally converts the RankedTensorType to VectorType based
1609 // on the encoding.
1610
1611 TypeConverter converter;
1612 converter.addConversion([&](Type type) -> Type { return type; });
1613 converter.addConversion(
1614 [&](RankedTensorType type,
1615 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1616 // Only convert RankedTensorTypes that carry an XeGPU layout encoding.
1617 // Plain tensors (e.g. tensor<?xi32>) have no XeGPU encoding and must
1618 // not be converted: VectorType does not support dynamic dimensions.
1619 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1620 type.getEncoding());
1621 if (!encoding)
1622 return std::nullopt;
1623
1624 Type elemTy = type.getElementType();
1625 ArrayRef<int64_t> shape = type.getShape();
1626
1627 int count;
1628 SmallVector<int64_t> subShape;
1629 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1630
1631 auto newTy = VectorType::get(subShape, elemTy);
1632 result.append(count, newTy);
1633 return success();
1634 });
1635
1637 converter);
1638 }
1639
1640 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1641 // as well as XeGPU, Arith, and Vector operations.
1642 MLIRContext *ctx = &getContext();
1643 RewritePatternSet patterns(ctx);
1644 ConversionTarget target(*ctx);
1645 TypeConverter converter;
1646 converter.addConversion([&](Type type) -> Type { return type; });
1647 converter.addConversion(
1648 [&](xegpu::TensorDescType type,
1649 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1650 xegpu::LayoutAttr layout = type.getLayoutAttr();
1651 // Only convert WG-level tensor descs. SG-level or layout-less types
1652 // are already legal and should pass through unchanged.
1653 if (!layout || !layout.isForWorkgroup())
1654 return std::nullopt;
1655
1656 Type elemTy = type.getElementType();
1657 ArrayRef<int64_t> shape = type.getShape();
1658
1659 int count;
1660 SmallVector<int64_t> subShape;
1661 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1662
1663 layout = layout.dropSgLayoutAndData();
1664
1665 auto newTy = xegpu::TensorDescType::get(
1666 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1667 result.append(count, newTy);
1668 return success();
1669 });
1670
1671 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1672 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1673 return createOp.getType();
1674 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1675 return loadOp.getTensorDescType();
1676 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1677 return storeOp.getTensorDescType();
1678 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1679 return updateOp.getType();
1680 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1681 return prefetchOp.getTensorDescType();
1682 return xegpu::TensorDescType();
1683 };
1684
1685 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1686 return !layout || !layout.isForWorkgroup();
1687 };
1688
1689 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1690 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1691 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1692 auto tdescTy = getTensorDescType(op);
1693 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1694 return isLegal(layout);
1695 });
1696
1697 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1698 auto layout = op.getLayoutCdAttr();
1699 return isLegal(layout);
1700 });
1701
1702 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1703 [=](xegpu::LoadMatrixOp op) -> bool {
1704 return isLegal(op.getLayoutAttr());
1705 });
1706
1707 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1708 [=](xegpu::StoreMatrixOp op) -> bool {
1709 return isLegal(op.getLayoutAttr());
1710 });
1711
1712 target.addDynamicallyLegalOp<arith::ConstantOp>(
1713 [=](arith::ConstantOp op) -> bool {
1714 auto vecType = dyn_cast<VectorType>(op.getType());
1715 if (!vecType)
1716 return true;
1717
1718 auto layout =
1719 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1720 return isLegal(layout);
1721 });
1722
1723 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1724 vector::TransposeOp, vector::BroadcastOp,
1725 vector::MultiDimReductionOp,
1726 vector::ConstantMaskOp, vector::CreateMaskOp>(
1727 [=](Operation *op) -> bool {
1728 // Check for either a SliceAttr or LayoutAttr on the result.
1729 auto layout =
1730 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1731 return isLegal(layout);
1732 });
1733
1734 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1735 [=](xegpu::LoadGatherOp op) -> bool {
1736 auto layout = op.getLayoutAttr();
1737 return isLegal(layout);
1738 });
1739
1740 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1741 [=](xegpu::StoreScatterOp op) -> bool {
1742 auto layout = op.getLayoutAttr();
1743 return isLegal(layout);
1744 });
1745
1746 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1747 [=](xegpu::ConvertLayoutOp op) -> bool {
1748 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1749 });
1750
1751 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1752 [=](Operation *op) -> std::optional<bool> {
1753 // Only handle elementwise mappable ops
1755 return true;
1756
1757 VectorType resultType =
1758 dyn_cast<VectorType>(op->getResult(0).getType());
1759 if (!resultType)
1760 return true;
1761
1762 // Check if all operands are vectors of the same shape
1763 // TODO: Support other types.
1764 for (Value operand : op->getOperands()) {
1765 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1766 if (!operandType || operandType.getShape() != resultType.getShape()) {
1767 return true;
1768 }
1769 }
1770
1771 xegpu::DistributeLayoutAttr layout =
1773 return isLegal(layout);
1774 });
1775
1776 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1777 [=](UnrealizedConversionCastOp op) {
1778 return llvm::is_contained(existingCastOps, op.getOperation());
1779 });
1780
1781 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1782
1784 target);
1786 if (failed(
1787 applyPartialConversion(getOperation(), target, std::move(patterns))))
1788 return signalPassFailure();
1789
1790 // Remove layout attributes from SCF ops
1791 getOperation()->walk([](Operation *op) {
1792 if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
1793 return;
1794
1795 SmallVector<StringAttr> attrsToRemove;
1796 for (auto namedAttr : op->getDiscardableAttrs()) {
1797 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1798 attrsToRemove.push_back(namedAttr.getName());
1799 }
1800 for (auto attrName : attrsToRemove)
1801 op->removeDiscardableAttr(attrName);
1802 });
1803}
return success()
b getContext())
#define mul(a, b)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:498
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.