MLIR 23.0.0git
XeGPUWgToSgDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
25#include <optional>
26
27namespace mlir {
28namespace xegpu {
29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31} // namespace xegpu
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Retrieve the RangeAttr if it is specified.
39static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
40 Operation *parent = op->getParentOfType<scf::IfOp>();
41 while (parent) {
42 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->getAttr("sg_id_range")))
44 return attr;
45 parent = parent->getParentOfType<scf::IfOp>();
46 }
47 return {};
48}
49
50static std::pair<SmallVector<int64_t>, int>
51getSgShapeAndCount(ArrayRef<int64_t> shape,
52 xegpu::DistributeLayoutAttr layout) {
53 int count = 1;
55 auto distributedShape = layout.computeDistributedShape(
56 SmallVector<int64_t>(shape.begin(), shape.end()));
57 if (failed(distributedShape))
58 return std::make_pair(sgShape, count);
59 auto sgData = layout.getEffectiveSgDataAsInt();
60 count = computeProduct(distributedShape.value()) / computeProduct(sgData);
61 return std::make_pair(sgData, count);
62}
63
64/// Utility helper for deriving a list of offsets for each sub-TensorDescs
65/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
66/// associated distribute layout attribute, the shape, subgroup id and the
67/// original offsets of the op
68template <
69 typename OpType,
70 typename = std::enable_if_t<llvm::is_one_of<
71 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
72 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
73static LogicalResult
74genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
76 Location loc = op.getLoc();
77 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
78 // not applicable to ops without offsets operands.
79 if (origOffsets.empty())
80 return failure();
81
82 // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr()
83 xegpu::DistributeLayoutAttr layout;
84 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
85 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
86 layout = op.getLayoutAttr();
87 } else {
88 layout = op.getDescLayoutAttr();
89 }
90
91 // not applicable to ops without workgroup layout attributes
92 if (!layout || !layout.isForWorkgroup())
93 return failure();
94
95 Value sgId =
96 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
97
98 // verify and adjust the sgId if the range specifier is present
99 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
100 if (sgIdRange) {
101 int64_t startOfRange = sgIdRange.getStart().getInt();
102 int64_t endOfRange = sgIdRange.getEnd().getInt();
103 // verify the RangeAttr against the layout attribute
104 if (layout.getNumSubgroups() != endOfRange - startOfRange)
105 return rewriter.notifyMatchFailure(
106 op, "sg_layout size must match the sg_id_range");
107 // adjust the sgId if necessary
108 if (startOfRange > 0) {
109 Value startOfRangeVal =
110 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
111 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
112 }
113 }
114
115 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
116 // descriptors to be accessed, based on the layout information.
117 ArrayRef<int64_t> wgShape = op.getDataShape();
118 auto maybeDescOffsets =
119 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
120 if (failed(maybeDescOffsets))
121 return failure();
122
123 // Compute the final global offsets for each accessed sub-tensor
124 // or sub-memory descriptor.
125 for (const auto &sgOffsets : *maybeDescOffsets) {
127 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
128 offsetsList.push_back(std::move(newOffsets));
129 }
130
131 // callback(offsetsList);
132 return success();
133}
134
135/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
136/// from a workgroup descriptor. It replaces the offsets and sizes with
137/// appropriate values for the subgroup.
138/// It uses round-robin assignment to distribute the work to the subgroups.
139/// Following create_nd_desc operation:,
140/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
141/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
142/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
143/// is converted to 9 subgroup level operations based on the sg_layout &
144/// sg_data:
145/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
146/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
147/// lane_data = [1, 1]>>
148///
149/// The sg_layout and sg_data attributes are dropped after the pass as they are
150/// no longer needed.
151///
152/// 24x24 matrix distribution example:
153/// sg_layout = [4, 4], sg_data = [2, 2]
154/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
155/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
156///
157/// +------------------------+
158/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
159/// |-----+-----+-----|
160/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
161/// |-----+-----+-----|
162/// | 8x8 | 8x8 | 8x8 |
163/// +------------------------+
164///
165/// Each 8x8 tile is further subdivided among subgroups:
166/// +------------------------+
167/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
168/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
169/// | 2x2 2x2 2x2 2x2 |
170/// | 2x2 2x2 2x2 2x2 |
171/// +------------------------+
172///
173/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
174/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
175
176/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
177/// pattern and all the other ops just follow.
178/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
179/// ops in the pass.
180struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
181 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
182
183 LogicalResult
184 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const override {
186 SmallVector<SmallVector<OpFoldResult>> offsetsList;
187 if (failed(genOffsetsList(rewriter, op, offsetsList)))
188 return failure();
189
190 MLIRContext *ctx = op.getContext();
191 xegpu::TensorDescType tdescTy = op.getType();
192 ArrayRef<int64_t> wgShape = tdescTy.getShape();
193 Type elemTy = tdescTy.getElementType();
194 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
195 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
196 auto newTdescTy =
197 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
198 layout.dropSgLayoutAndData());
199
200 SmallVector<Value> newOps;
201 for (auto offsets : offsetsList) {
202 auto newOp = xegpu::CreateNdDescOp::create(
203 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
204 op.getMixedSizes(), op.getMixedStrides());
205
206 newOps.push_back(newOp);
207 }
208 rewriter.replaceOpWithMultiple(op, {newOps});
209
210 return success();
211 }
212};
213
214// This pattern transforms the CreateNdDescOp without offsets to create a
215// subgroup descriptor from a workgroup descriptor
216struct WgToSgCreateNdOpNoOffset
217 : public OpConversionPattern<xegpu::CreateNdDescOp> {
218 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
219
220 LogicalResult
221 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223
224 // Check no offsets are specified.
225 if (!op.getMixedOffsets().empty())
226 return failure();
227
228 Location loc = op.getLoc();
229 MLIRContext *ctx = op.getContext();
230 xegpu::TensorDescType tdescTy = op.getType();
231 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
232 if (!layout || !layout.isForWorkgroup())
233 return failure();
234
235 Type elemTy = tdescTy.getElementType();
236 ArrayRef<int64_t> wgShape = tdescTy.getShape();
237
238 SmallVector<int64_t> sgShape;
239 int count;
240 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
241 xegpu::TensorDescType newTdescTy =
242 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
243 layout.dropSgLayoutAndData());
244
245 SmallVector<Value> newCreateNdOps(count);
246 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
247 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
248 op.getSource(), op.getMixedSizes(),
249 op.getMixedStrides());
250 });
251
252 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
253 return success();
254 }
255};
256
257/// This pattern transforms the LoadNdOp to load subgroup data.
258struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
259 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
260 LogicalResult
261 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override {
263 if (!op.getMixedOffsets().empty())
264 return failure();
265
266 SmallVector<Value> newLoadOps;
267 for (auto src : adaptor.getTensorDesc()) {
268 xegpu::TensorDescType tdescTy =
269 dyn_cast<xegpu::TensorDescType>(src.getType());
270 ArrayRef<int64_t> srcShape = tdescTy.getShape();
271 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
272 auto newLoadOp = xegpu::LoadNdOp::create(
273 rewriter, op.getLoc(), newResTy, src,
274 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
275 newLoadOps.push_back(newLoadOp);
276 }
277 rewriter.replaceOpWithMultiple(op, {newLoadOps});
278 return mlir::success();
279 }
280};
281
282/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
283/// It creates a StoreNdOp op to store the updated values to the new subgroup
284/// src tensor descriptors.
285struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
286 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
287 LogicalResult
288 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 if (!op.getMixedOffsets().empty())
291 return failure();
292
293 for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
294 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
295 op.getL2HintAttr(), op.getL3HintAttr());
296
297 rewriter.eraseOp(op);
298 return success();
299 }
300};
301
302// This pattern transforms the LoadNdOp with explicit offsets to load
303// subgroup data.
304struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
305 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
306 LogicalResult
307 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter) const override {
309
310 SmallVector<SmallVector<OpFoldResult>> offsetsList;
311 if (failed(genOffsetsList(rewriter, op, offsetsList)))
312 return failure();
313
314 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
315 if (layout)
316 layout = layout.dropSgLayoutAndData();
317 SmallVector<Value> newOps;
318 for (auto [tdesc, offsets] :
319 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
320 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
321 VectorType newResTy =
322 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
323 auto newOp = xegpu::LoadNdOp::create(
324 rewriter, op.getLoc(), newResTy, tdesc, offsets,
325 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
326 op.getL2HintAttr(), op.getL3HintAttr(), layout);
327 newOps.push_back(newOp);
328 }
329 rewriter.replaceOpWithMultiple(op, {newOps});
330
331 return success();
332 }
333};
334
335// This pattern transforms the StoreNdOp with explicit offsets to store
336// subgroup data.
337struct WgToSgStoreNdOpWithOffset
338 : public OpConversionPattern<xegpu::StoreNdOp> {
339 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
340 LogicalResult
341 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter) const override {
343 SmallVector<SmallVector<OpFoldResult>> offsetsList;
344 if (failed(genOffsetsList(rewriter, op, offsetsList)))
345 return failure();
346
347 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
348 if (layout)
349 layout = layout.dropSgLayoutAndData();
350 for (auto [v, tdesc, offsets] :
351 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
352 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
353 op.getL1HintAttr(), op.getL2HintAttr(),
354 op.getL3HintAttr(), layout);
355 }
356 rewriter.eraseOp(op);
357
358 return success();
359 }
360};
361
362// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
363// subgroup data.
364struct WgToSgPrefetchNdOpWithOffset
365 : public OpConversionPattern<xegpu::PrefetchNdOp> {
366 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
367 LogicalResult
368 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter) const override {
370 SmallVector<SmallVector<OpFoldResult>> offsetsList;
371 if (failed(genOffsetsList(rewriter, op, offsetsList)))
372 return failure();
373
374 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
375 if (layout)
376 layout = layout.dropSgLayoutAndData();
377 for (auto [tdesc, offsets] :
378 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
379 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
380 op.getL1HintAttr(), op.getL2HintAttr(),
381 op.getL3HintAttr(), layout);
382 }
383 rewriter.eraseOp(op);
384
385 return success();
386 }
387};
388
389/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
390/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
391/// offsets of the new subgroup src tensor descriptors.
392struct WgToSgUpdateNdOffsetOp
393 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
394 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
395 LogicalResult
396 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
397 ConversionPatternRewriter &rewriter) const override {
398 llvm::SmallVector<Value> newUpdateTileOffsetOps;
399 for (auto tDesc : adaptor.getTensorDesc()) {
400 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
401 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
402 op.getConstOffsets());
403 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
404 }
405
406 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
407 return success();
408 }
409};
410
411/// This pattern transforms the DpasOp to work at subgroup level.
/// This pattern transforms the DpasOp to work at subgroup level.
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    // Only 2D-result DPAS ops are supported by this distribution.
    if (resultTy.getRank() != 2)
      return failure();

    // All three layout attributes (A, B, C/D) are required; bail out if any
    // is missing.
    auto layoutCd = op.getLayoutCdAttr();
    auto layoutA = op.getLayoutAAttr();
    auto layoutB = op.getLayoutBAttr();
    if (!layoutCd || !layoutA || !layoutB)
      return failure();
    // Linear cursor into the distributed accumulator tiles; advanced once
    // per emitted DPAS in the nested loop order below.
    size_t i = 0;
    SmallVector<Value> newDpasOps;
    // Emit one subgroup-level DPAS per (A-tile, B-tile) pair.
    // NOTE(review): assumes the lhs/rhs tile counts multiply out to the acc
    // tile count so `i` stays in bounds — confirm against the distribution
    // performed by the create/load patterns above.
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {

        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        // Result tile shape MxN: rows from the A tile, columns from the B
        // tile.
        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        // Drop the subgroup fields; remaining (lane-level) layout info is
        // kept for later lowering stages.
        newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
        newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
        newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());

        newDpasOps.push_back(newDpasOp);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};
457
458/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
459struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
460 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
461 LogicalResult
462 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
463 ConversionPatternRewriter &rewriter) const override {
464
465 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
466 if ((offsetSize != 0) || op.getConstOffsetsAttr())
467 return failure();
468
469 for (auto src : adaptor.getTensorDesc())
470 xegpu::PrefetchNdOp::create(
471 rewriter, op.getLoc(), TypeRange(), src,
472 xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
473 rewriter.eraseOp(op);
474 return success();
475 }
476};
477
478/// This pattern transforms vector.broadcast ops to work at subgroup level.
479struct WgToSgVectorBroadcastOp
480 : public OpConversionPattern<vector::BroadcastOp> {
481 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486
487 VectorType resultType = op.getResult().getType();
488 ArrayRef<int64_t> wgShape = resultType.getShape();
489
490 xegpu::DistributeLayoutAttr layout =
491 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
492 if (!layout || !layout.isForWorkgroup())
493 return failure();
494
495 SmallVector<int64_t> sgShape;
496 int count;
497 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
498 VectorType newResultType =
499 VectorType::get(sgShape, resultType.getElementType());
500
501 SmallVector<Value> newBroadcastOps;
502 auto distSource = adaptor.getOperands().front();
503 int numDistributions = count / distSource.size();
504 for (int i = 0; i < numDistributions; ++i) {
505 for (auto operand : distSource) {
506 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
507 newResultType, operand);
508
509 newBroadcastOps.push_back(newBroadcast.getResult());
510 }
511 }
512 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
513 return success();
514 }
515};
516
517// This pattern transforms elementwise ops to work at subgroup level.
518struct WgToSgElementwiseOp : public ConversionPattern {
519 WgToSgElementwiseOp(MLIRContext *ctx)
520 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
521
522 LogicalResult
523 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
524 ConversionPatternRewriter &rewriter) const override {
525 // Only match ops with elementwise trait and single result.
527 return failure();
528
529 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
530 assert(resultType && "Expected result to be a VectorType");
531
532 ArrayRef<int64_t> wgShape = resultType.getShape();
533
534 xegpu::DistributeLayoutAttr layout =
535 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
536 if (!layout || !layout.isForWorkgroup())
537 return failure();
538
539 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
540
541 size_t numVariants = operands.empty() ? 0 : operands.front().size();
542
543 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
544 return operandVec.size() != numVariants;
545 }))
546 return failure();
547
548 SmallVector<Value> newResults;
549 VectorType newResultType =
550 VectorType::get(sgShape, resultType.getElementType());
551
552 for (size_t i = 0; i < numVariants; ++i) {
553 SmallVector<Value> opOperands;
554 for (auto &operandVec : operands)
555 opOperands.push_back(operandVec[i]);
556
557 OperationState state(op->getLoc(), op->getName());
558 state.addOperands(opOperands);
559 state.addTypes(newResultType);
560 state.addAttributes(op->getAttrs());
561 Operation *newOp = rewriter.create(state);
563 newResults.push_back(newOp->getResult(0));
564 }
565
566 rewriter.replaceOpWithMultiple(op, {newResults});
567 return success();
568 }
569};
570
571// clang-format off
572// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
573// If input_layout and target_layout have identical sg_layout and sg_data,
574// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
575// dropped. For example:
576// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
577// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
578// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
579// becomes:
580// #a = #xegpu.layout<inst_data = [16, 16]>
581// #b = #xegpu.layout<inst_data = [8, 16]>
582// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
583// (vector<16x16xf32> is determined by sg_data = [16, 16])
584//
585// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
586// For example:
587// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
588// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
589// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
590// is lowered to:
591// #a = #xegpu.layout<inst_data = [16, 16]>
592// #b = #xegpu.layout<inst_data = [8, 16]>
593// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
594// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
595// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
596// clang-format on
597struct WgToSgConvertLayoutOp
598 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
599 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
600
601 LogicalResult
602 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
603 ConversionPatternRewriter &rewriter) const override {
604 Location loc = op.getLoc();
605 auto inputLayout = op.getInputLayout();
606 auto targetLayout = op.getTargetLayout();
607
608 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
609 !targetLayout.isForWorkgroup())
610 return rewriter.notifyMatchFailure(
611 op, "Input and target layouts must have subgroup layout");
612
613 Type resultType = op.getResult().getType();
614 if (resultType.isIntOrFloat()) {
615 rewriter.replaceOp(op, op.getSource());
616 assert(!inputLayout.dropSgLayoutAndData() &&
617 !targetLayout.dropSgLayoutAndData() &&
618 "unexpected layout attributes for scalar type");
619 return success();
620 }
621
622 ArrayRef<int64_t> wgShape = cast<VectorType>(resultType).getShape();
623 SmallVector<int64_t> inputSgLayout =
624 inputLayout.getEffectiveSgLayoutAsInt();
625 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
626 SmallVector<int64_t> targetSgLayout =
627 targetLayout.getEffectiveSgLayoutAsInt();
628 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
629
630 // Fast path: if sg_layout and sg_data are identical, no SLM needed
631 SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
632 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
633 xegpu::LayoutKind::Subgroup)) {
634 inputLayout = inputLayout.dropSgLayoutAndData();
635 targetLayout = targetLayout.dropSgLayoutAndData();
636
637 SmallVector<Value> newOps(adaptor.getSource());
638 if (inputLayout && targetLayout) {
639 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
640 auto newOp = xegpu::ConvertLayoutOp::create(
641 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
642 newOps[i] = newOp;
643 }
644 }
645 rewriter.replaceOpWithMultiple(op, {newOps});
646 return success();
647 }
648
649 // SLM path: layouts differ, need cross-subgroup data redistribution
650 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
651
652 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
653
654 // Calculate SLM size requirements
655 auto bitWidth = elemTy.getIntOrFloatBitWidth();
656 auto bytesPerElement = bitWidth / 8;
657 auto slmSize = computeProduct(slmShape) * bytesPerElement;
658
659 // Allocate SLM
660 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
661 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
662
663 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
664 elemTy, nullptr);
665 auto memDesc =
666 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
667
668 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
669 rewriter.getIndexType(), nullptr);
670
671 // STORE PHASE: Each subgroup stores in SLM using input layout
672 auto storeCoords = inputLayout.computeDistributedCoords(
673 rewriter, loc, sgId.getResult(), wgShape);
674 if (failed(storeCoords))
675 return failure();
676
677 // Store to SLM
678 for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
679 SmallVector<OpFoldResult> storeMatrixOffsets;
680 for (Value coord : coords) {
681 storeMatrixOffsets.push_back(coord);
682 }
683 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
684 storeMatrixOffsets, nullptr /*layout*/);
685 }
686
687 gpu::BarrierOp::create(rewriter, loc);
688
689 // LOAD PHASE: Each target subgroup loads from SLM using target layout
690 auto loadCoords = targetLayout.computeDistributedCoords(
691 rewriter, loc, sgId.getResult(), wgShape);
692 if (failed(loadCoords))
693 return failure();
694
695 VectorType loadType = VectorType::get(targetSgData, elemTy);
696
697 // Load vectors from SLM
698 SmallVector<Value> finalResults;
699 for (auto coords : *loadCoords) {
700 SmallVector<OpFoldResult> loadMatrixOffsets;
701 for (Value coord : coords) {
702 loadMatrixOffsets.push_back(coord);
703 }
704 auto loadOp = xegpu::LoadMatrixOp::create(
705 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
706 targetLayout.dropSgLayoutAndData());
707
708 finalResults.push_back(loadOp.getResult());
709 }
710
711 rewriter.replaceOpWithMultiple(op, {finalResults});
712 return success();
713 }
714};
715
716// Handles UnrealizedConversionCastOp generated during
717// SCFStructuralTypeConversions (step 1). This op may appear as either a
718// target or source materialization for Vector values, e.g.:
719// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
720// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
721// it could be either 1:N or N:1 cast. In both cases, the pattern
722// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
723// for example, the following scf::forOp
724// ```
725// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
726// %n = use(%arg1): vector<128x128xf16>
727// scf.yield %n : vector<128x128xf16>
728// }
729// ```
730// Could be converted to:
731// ```
732// %1 = unrealized_conversion_cast %0
733// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
734// %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2)
735// -> (vector<16x16xf16>, vector<16x16xf16) {
736// %m = unrealized_conversion_cast %arg1, %arg2
737// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
738// %n = use(%m): vector<128x128xf16>
739// %b = unrealized_conversion_cast %n
740// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
741// scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
742// }
743// %cast = unrealized_conversion_cast %for:2
744// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
745// ```
746// TODO: remove it when context-aware type converter is ready.
747struct UnrealizedConversionCastOpPattern
748 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
749 using OpConversionPattern<
750 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
751
752 mlir::LogicalResult
753 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
754 ConversionPatternRewriter &rewriter) const override {
755 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
756
757 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
758 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
759
760 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
761 !llvm::all_equal(ValueRange(inputs).getTypes()))
762 return failure();
763
764 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
765 // It is generated by source materialization (e.g., inits to scf forOp).
766 // The input values provided by the adaptor should already be distributed,
767 // and their types should correspond exactly to the result types of the
768 // operation.
769 if (op.getNumOperands() == 1 &&
770 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
771 rewriter.replaceOp(op, inputs);
772 return success();
773 }
774
775 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
776 // It is generated by target materialization (e.g., arguments/results
777 // of scf forOp). All input values must have the same vector type, and
778 // their shape must be evenly divisible by the output vector's shape
779 // (determined by the nature of the workgroup to subgroup distribution).
780 // TODO: it is not safe to do such forward, since such N:1 cast could be
781 // from others.
782 if (op.getNumResults() == 1 &&
783 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
784 rewriter.replaceOpWithMultiple(op, {inputs});
785 return success();
786 }
787
788 return mlir::failure();
789 }
790};
791
792// This pattern distributes arith.constant op into subgroup-level constants
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only dense vector constants are handled by this pattern.
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    // The constant must carry a workgroup-level distribute layout.
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    if (vecAttr.isSplat()) {
      // Splat: single value for all subgroups
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      SmallVector<Value> newConstOps;
      // `count` sg-sized constants replace the one wg-sized constant.
      for (int i = 0; i < count; ++i) {
        auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
        newConstOps.push_back(cstOp);
      }
      rewriter.replaceOpWithMultiple(op, {newConstOps});
      return success();
    } else if (sgShape == wgShape) { // if the entire vector is shared by all
                                     // subgroups, don't distribute
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      rewriter.replaceOp(op, newConstOp);
      return success();
    } else {
      // Non-splat constant
      // Only supports 1D & 2D
      // TODO: support other cases that require SLM access
      if (!eltType.isIndex())
        return rewriter.notifyMatchFailure(
            op, "Unsupported element type for non-splat constant op.");

      if (wgShape.size() > 2)
        return rewriter.notifyMatchFailure(
            op, "Only 1D & 2D vector constant supported");

      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
      int64_t rowStride = 0, colStride = 0;
      // A 1D vector is treated as a single row so both ranks share the same
      // 2D stride logic below.
      int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
      int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

      // Compute colStride and rowStride, and check for constant strides.
      if (cols > 1) {
        colStride = cast<IntegerAttr>(values[1]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }
      if (rows > 1) {
        rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                    cast<IntegerAttr>(values[0]).getInt();
      }

      // Verify the constant is an affine ramp: adjacent elements within a
      // row must differ by colStride and within a column by rowStride.
      // Otherwise the per-subgroup offset arithmetic below is invalid.
      for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < cols; ++c) {
          int64_t idx = r * cols + c;
          // Check column stride
          if (c > 0 && cols > 1) {
            int64_t prevIdx = r * cols + (c - 1);
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != colStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant column stride in constant op.");
          }
          // Check row stride
          if (r > 0 && rows > 1) {
            int64_t prevIdx = (r - 1) * cols + c;
            int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                           cast<IntegerAttr>(values[prevIdx]).getInt();
            if (diff != rowStride)
              return rewriter.notifyMatchFailure(
                  op, "Non-constant row stride in constant op.");
          }
        }
      }

      // Create a constant for the base tile.
      // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
      // For 1D case, extract the first sgShape[0] elements.
      SmallVector<Attribute> baseTileValues;
      int baseTileCols = sgShape[sgShape.size() - 1];
      int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
      for (int64_t r = 0; r < baseTileRows; ++r) {
        for (int64_t c = 0; c < baseTileCols; ++c) {
          baseTileValues.push_back(values[r * cols + c]);
        }
      }

      auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
                                             baseTileValues);
      auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

      // Get subgroup id
      Value sgId =
          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
      auto sgOffsets =
          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
      if (failed(sgOffsets))
        return failure();

      // Per-dimension strides as index constants: [rowStride, colStride]
      // for 2D, just [colStride] for 1D.
      SmallVector<Value, 2> strideConsts;
      strideConsts.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, colStride));
      if (rows > 1)
        strideConsts.insert(
            strideConsts.begin(),
            arith::ConstantIndexOp::create(rewriter, loc, rowStride));

      SmallVector<Value> newConstOps;
      for (auto offsets : *sgOffsets) {
        // Multiply offset with stride, broadcast it and add to baseConstVec
        Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
        for (size_t i = 0; i < strideConsts.size(); ++i) {
          Value mul =
              arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                    offsets[i], strideConsts[i]);
          mulOffset = arith::AddIOp::create(
              rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
        }
        // Broadcast to baseConstVec size
        auto bcastOffset = vector::BroadcastOp::create(
            rewriter, loc, baseConstVec.getType(), mulOffset);
        auto finalConst =
            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
        newConstOps.push_back(finalConst);
      }
      rewriter.replaceOpWithMultiple(op, {newConstOps});
      return success();
    }
  }
};
941
942// This pattern transforms the LoadGatherOp with explicit offsets to load
943// subgroup data
944struct WgToSgLoadGatherOpWithOffset
945 : public OpConversionPattern<xegpu::LoadGatherOp> {
946 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
947 LogicalResult
948 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
949 ConversionPatternRewriter &rewriter) const override {
950
951 Location loc = op.getLoc();
952 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
953 if (!resultType)
954 return failure();
955 ArrayRef<int64_t> wgShape = resultType.getShape();
956
957 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
958
959 if (!layout || !layout.isForWorkgroup())
960 return failure();
961
962 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
963
964 // The offsets need to be distributed
965 auto offsetsVecType =
966 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
967 auto maskVecType =
968 dyn_cast<VectorType>(adaptor.getMask().front().getType());
969 if (!offsetsVecType || !maskVecType ||
970 offsetsVecType.getShape() != maskVecType.getShape()) {
971 return rewriter.notifyMatchFailure(op,
972 "offsets have not been distributed");
973 }
974
975 SmallVector<Value> newLoadOps;
976 auto chunkSizeAttr =
977 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
978 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
979 for (auto [offsets, mask] :
980 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
981 auto newLayout = layout.dropSgLayoutAndData();
982 auto newLoadOp = xegpu::LoadGatherOp::create(
983 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
984 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
985 newLayout);
986 newLoadOps.push_back(newLoadOp);
987 }
988 rewriter.replaceOpWithMultiple(op, {newLoadOps});
989 return success();
990 }
991};
992
993// This pattern transforms the StoreScatterOp with explicit offsets to store
994// subgroup data
995struct WgToSgStoreScatterOpWithOffset
996 : public OpConversionPattern<xegpu::StoreScatterOp> {
997 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
998 LogicalResult
999 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
1000 ConversionPatternRewriter &rewriter) const override {
1001
1002 Location loc = op.getLoc();
1003 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1004 if (!valueType)
1005 return failure();
1006
1007 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1008
1009 if (!layout || !layout.isForWorkgroup())
1010 return failure();
1011
1012 // The offsets need to be distributed
1013 auto offsetsVecType =
1014 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1015 auto maskVecType =
1016 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1017 if (!offsetsVecType || !maskVecType ||
1018 offsetsVecType.getShape() != maskVecType.getShape()) {
1019 return rewriter.notifyMatchFailure(op,
1020 "offsets have not been distributed");
1021 }
1022
1023 auto chunkSizeOpt = op.getChunkSize();
1024 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
1025 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1026 for (auto [val, offs, mask] : llvm::zip(
1027 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1028 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1029 mask, chunkSizeAttr, op.getL1HintAttr(),
1030 op.getL2HintAttr(), op.getL3HintAttr(),
1031 layout.dropSgLayoutAndData());
1032 }
1033 rewriter.eraseOp(op);
1034 return success();
1035 }
1036};
1037
1038struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
1039 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1040 LogicalResult
1041 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1042 ConversionPatternRewriter &rewriter) const override {
1043
1044 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1045 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1046 return failure();
1047
1048 ArrayRef<int64_t> wgShape = op.getDataShape();
1049 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1050 assert(valueTy && "the value type must be vector type!");
1051 Type elemTy = valueTy.getElementType();
1052
1053 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1054 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1055 VectorType newResTy = VectorType::get(sgShape, elemTy);
1056 SmallVector<Value> newOps;
1057 for (auto offsets : offsetsList) {
1058 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1059 op.getMemDesc(), offsets,
1060 layout.dropSgLayoutAndData());
1061 newOps.push_back(newOp);
1062 }
1063 rewriter.replaceOpWithMultiple(op, {newOps});
1064
1065 return success();
1066 }
1067};
1068
1069struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1070 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1071 LogicalResult
1072 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1073 ConversionPatternRewriter &rewriter) const override {
1074
1075 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1076 if (failed(genOffsetsList(rewriter, op, offsetsList)))
1077 return failure();
1078
1079 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1080 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1081 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1082 offsets, layout.dropSgLayoutAndData());
1083 rewriter.eraseOp(op);
1084 return success();
1085 }
1086};
1087
1088// This pattern distributes the vector.step ops to work at subgroup level
1089struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1090 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1091 LogicalResult
1092 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1093 ConversionPatternRewriter &rewriter) const override {
1094 xegpu::DistributeLayoutAttr layout =
1095 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1096 if (!layout || !layout.isForWorkgroup())
1097 return failure();
1098
1099 Location loc = op.getLoc();
1100 VectorType type = op.getResult().getType();
1101 auto wgShape = type.getShape();
1102 std::optional<SmallVector<int64_t>> sgShape =
1103 getSgShapeAndCount(wgShape, layout).first;
1104 if (!sgShape)
1105 return failure();
1106
1107 Value sgId =
1108 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1109 auto sgOffsets =
1110 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1111 if (failed(sgOffsets))
1112 return failure();
1113
1114 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1115 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1116 SmallVector<Value> newOps;
1117 for (auto offsets : *sgOffsets) {
1118 // Broadcast the offset scalar to a vector & add to the base steps
1119 auto bcastOffset =
1120 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1121 auto finalSteps =
1122 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1123 newOps.push_back(finalSteps);
1124 }
1125
1126 rewriter.replaceOpWithMultiple(op, {newOps});
1127 return success();
1128 }
1129};
1130
1131// This pattern transforms vector.shape_cast ops to work at subgroup level.
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    // The cast result must carry a workgroup-level distribute layout.
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // Check that srcShape and destShape, if they differ, only differ by
    // expand of unit dimensions.
    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
    if (!srcType)
      return failure();

    ArrayRef<int64_t> srcShape = srcType.getShape();

    xegpu::DistributeLayoutAttr layoutToDistribute = layout;
    SmallVector<int64_t> expandedUnitDims;
    // Unit-dim expansion is only supported in the restricted pattern
    // shape_cast -> broadcast, with the source layout a slice of the
    // result layout; anything else is rejected below.
    if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
      xegpu::DistributeLayoutAttr sourceLayout =
          xegpu::getTemporaryLayout(op->getOpOperand(0));

      // True iff every user of the shape_cast result is a vector.broadcast.
      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
          return isa<vector::BroadcastOp>(user);
        });
      };

      if (!usedByBroadcastOp(op))
        return rewriter.notifyMatchFailure(
            op, "ShapeCast ops that expand unit dimensions and are used by "
                "non-broadcast operations are not supported.");

      if (!sourceLayout.isSliceOf(layout))
        return rewriter.notifyMatchFailure(
            op, "The ShapeCast op only expands dimensions, the input layout "
                "must be a slice of the result layout.");

      assert(layoutToDistribute.isEqualTo(
                 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
             "The sg_data for unit dimensions should be set as 1");
    }

    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgShape, layoutToDistribute).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    // Emit one sg-level shape_cast per distributed source value.
    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
1201
1202/// This pattern transforms vector.multi_dim_reduction operations from
1203/// workgroup-level to subgroup-level execution with support for multiple
1204/// reduction dimensions.
1205///
1206/// Steps include:
1207/// 1. LOCAL REDUCTION :
1208/// - Each subgroup performs local reduction on its data slice
1209/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1210/// phase
1211///
1212/// 2. CROSS-SUBGROUP :
1213/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1214/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1215/// - If not needed, adds original accumulator and returns local results
1216///
1217/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1218/// a) SLM Layout Design:
1219/// - Rows: subgroups participating in reduction (product of sg_layout in
1220/// reduction dims)
1221/// - Cols: total result elements across non-reduction dimensions
1222///
1223/// b) Store Phase:
1224/// - Each subgroup stores its local reduction result to SLM
1225/// - Row offset: linearized index of subgroup in reduction dimensions
1226/// - Col offset: linearized index of subgroup in non-reduction dimensions
1227///
1228/// c) Load and Final Reduction Phase:
1229/// - Each subgroup loads a column of data (all reduction participants for
1230/// its position)
1231/// - Performs final reduction along the loaded dimension
1232/// - Adds original accumulator to get final result
1233///
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType srcType = op.getSourceVectorType();
    Type resultTy = op.getResult().getType();
    VectorType dstVecType = dyn_cast<VectorType>(resultTy);
    // A non-vector result means all dimensions are reduced to a scalar.
    bool isScalarResult = !dstVecType;

    auto originalSrcShape = srcType.getShape();
    Type elemTy = srcType.getElementType();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    // Get sg_layout and sg_data from the parent layout
    SmallVector<int64_t> sgLayout;
    SmallVector<int64_t> sgData;
    xegpu::DistributeLayoutAttr parentLayout;
    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
      parentLayout = sliceAttr.getParent();
      sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
      sgData = parentLayout.getEffectiveSgDataAsInt();
    } else
      return rewriter.notifyMatchFailure(
          op, "Reduction should have SliceAttr layout");

    // Step 1: perform local subgroup reductions with neutral accumulator
    SmallVector<Value> localReductions;
    auto sgSrcs = adaptor.getSource();
    auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
    SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
                                    sgSrcType.getShape().end());

    // Determine the SG-level destination type.
    // For scalar results (all dims reduced), the sg result is also scalar.
    // For vector results, compute the sg destination shape from layout.
    Type sgDstType;
    if (dstVecType) {
      auto originalDstShape = dstVecType.getShape();
      SmallVector<int64_t> sgDstShape =
          getSgShapeAndCount(originalDstShape, layout).first;
      sgDstType = VectorType::get(sgDstShape, elemTy);
    } else {
      sgDstType = elemTy;
    }

    for (auto sgSrc : sgSrcs) {
      // Create neutral accumulator for local reduction
      Value neutralLocalAcc = xegpu::createReductionNeutralValue(
          rewriter, loc, sgDstType, op.getKind());
      // Local reduction with neutral accumulator
      auto localReduce = vector::MultiDimReductionOp::create(
          rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
          reductionDims);
      localReductions.push_back(localReduce.getResult());
    }

    // Check if cross-subgroup reduction is needed for any reduction dimension
    SmallVector<int64_t> crossSgReductionDims;
    for (int64_t reductionDim : reductionDims) {
      // Needed when multiple subgroups each hold only part of the reduction
      // dimension's extent.
      bool needsCrossSubgroupReduction =
          (sgLayout[reductionDim] > 1) &&
          (sgData[reductionDim] < originalSrcShape[reductionDim]);

      if (needsCrossSubgroupReduction) {
        crossSgReductionDims.push_back(reductionDim);
      }
    }

    // If no cross-subgroup reduction needed, add accumulator and return
    if (crossSgReductionDims.empty()) {
      SmallVector<Value> results;
      for (auto localResult : localReductions) {
        auto finalResult = vector::makeArithReduction(
            rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
        results.push_back(finalResult);
      }
      rewriter.replaceOpWithMultiple(op, {results});
      return success();
    }

    // Step 2: cross-subgroup reduction using SLM - allocating slm memory
    // Each subgroup stores a tile whose reduction dims are collapsed to 1.
    auto slmStoreDataShape = sgSrcShape;
    for (int64_t dim : reductionDims)
      slmStoreDataShape[dim] = 1;
    VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
    SmallVector<Value> slmStoreData;
    for (auto localResult : localReductions) {
      if (isScalarResult) {
        // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
        slmStoreData.push_back(vector::BroadcastOp::create(
            rewriter, loc, slmStoreDataType, localResult));
      } else {
        slmStoreData.push_back(vector::ShapeCastOp::create(
            rewriter, loc, slmStoreDataType, localResult));
      }
    }
    // for reduction dimension, SLM stores partial results from each subgroup
    SmallVector<int64_t> slmShape(originalSrcShape.begin(),
                                  originalSrcShape.end());
    SmallVector<int> slmSgData(sgData.begin(), sgData.end());
    SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
    for (int dim : reductionDims) {
      slmShape[dim] = sgLayout[dim];
      slmSgData[dim] = 1;
    }
    xegpu::LayoutAttr slmStoreLayout =
        xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);

    // Allocate SLM
    // NOTE(review): assumes elemTy is byte-sized or larger; a sub-byte
    // element type (e.g. i1) would make bytesPerElement 0 — confirm such
    // types cannot reach this point.
    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    auto slmSize = computeProduct(slmShape) * bytesPerElement;
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
                                               elemTy, nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    // Step 3: Store local results to SLM
    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(), nullptr);

    auto slmStoreCoords =
        slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
    if (failed(slmStoreCoords))
      return failure();
    for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
      SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
      xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
                                   coordOfr,
                                   /*layout=*/nullptr);
    }

    // Barrier so all partial results are visible before the loads below.
    gpu::BarrierOp::create(rewriter, loc);

    // Step 4: Load from SLM for final reduction
    // Note: slmSgData is deliberately reused (mutated) here to describe the
    // load-side layout, where each subgroup reads the full reduction extent.
    SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
    for (int64_t dim : reductionDims) {
      slmLoadDataShape[dim] = slmShape[dim];
      slmSgData[dim] = slmShape[dim];
    }
    xegpu::LayoutAttr slmLoadLayout =
        xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
    auto slmLoadCoords =
        slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
    if (failed(slmLoadCoords))
      return failure();

    VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
    SmallVector<Value> slmLoadData;
    for (auto coord : *slmLoadCoords) {
      SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
      slmLoadData.push_back(xegpu::LoadMatrixOp::create(
          rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
          /*layout=*/nullptr));
    }

    // Step 5: Perform final reduction with neutral accumulator and add the
    // original accumulator at the end
    Value neutralFinalAcc = xegpu::createReductionNeutralValue(
        rewriter, loc, sgDstType, op.getKind());

    SmallVector<Value> finalResults;
    for (size_t i = 0; i < slmLoadData.size(); ++i) {
      auto loaded = slmLoadData[i];
      auto finalReduce = vector::MultiDimReductionOp::create(
          rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
          reductionDims);
      finalResults.push_back(vector::makeArithReduction(
          rewriter, loc, op.getKind(), finalReduce.getResult(),
          adaptor.getAcc()[i]));
    }
    rewriter.replaceOpWithMultiple(op, {finalResults});
    return success();
  }
};
1423
1424// This pattern transforms vector.transpose ops to work at subgroup level.
1425struct WgToSgVectorTransposeOp
1426 : public OpConversionPattern<vector::TransposeOp> {
1427 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1428
1429 LogicalResult
1430 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1431 ConversionPatternRewriter &rewriter) const override {
1432 VectorType resultType = op.getResultVectorType();
1433
1434 ArrayRef<int64_t> wgShape = resultType.getShape();
1435 xegpu::DistributeLayoutAttr layout =
1436 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1437 if (!layout || !layout.isForWorkgroup())
1438 return failure();
1439 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1440 xegpu::DistributeLayoutAttr sourceLayout =
1441 xegpu::getDistributeLayoutAttr(op.getVector());
1442 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1443 return failure();
1444
1445 SmallVector<int64_t> sourceSgLayout =
1446 sourceLayout.getEffectiveSgLayoutAsInt();
1447 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1448
1449 ArrayRef<int64_t> permutation = op.getPermutation();
1450 size_t permutationSize = permutation.size();
1451 if (sourceSgLayout.size() != permutationSize ||
1452 resultSgLayout.size() != permutationSize) {
1453 return rewriter.notifyMatchFailure(
1454 op, "Layouts and permutation must have the same rank");
1455 }
1456
1457 // Check that sgLayout, sgData & order are properly transposed for source
1458 // and result
1459 if (!layout.isTransposeOf(sourceLayout, permutation,
1460 xegpu::LayoutKind::Subgroup))
1461 return rewriter.notifyMatchFailure(
1462 op, "Result layout is not a valid transpose of source layout "
1463 "according to permutation");
1464
1465 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1466 VectorType newResultType =
1467 VectorType::get(sgShape, resultType.getElementType());
1468
1469 SmallVector<Value> newTransposeOps;
1470 for (auto src : adaptor.getVector()) {
1471 auto newTranspose = vector::TransposeOp::create(
1472 rewriter, op.getLoc(), newResultType, src, permutation);
1473 newTransposeOps.push_back(newTranspose.getResult());
1474 }
1475 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1476 return success();
1477 }
1478};
1479
1480// Distribute vector mask ops to work at subgroup level.
1481template <typename MaskOpType>
1482struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1483 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1484
1485 LogicalResult matchAndRewrite(
1486 MaskOpType op,
1487 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1488 ConversionPatternRewriter &rewriter) const override {
1489 xegpu::DistributeLayoutAttr layout =
1490 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1491 if (!layout || !layout.isForWorkgroup())
1492 return failure();
1493
1494 Location loc = op.getLoc();
1495 VectorType type = op.getResult().getType();
1496 auto wgShape = type.getShape();
1497
1498 SmallVector<Value> wgMaskDimSizes;
1499 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1500 for (int64_t maskSize : op.getMaskDimSizes()) {
1501 wgMaskDimSizes.push_back(
1502 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1503 }
1504 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1505 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1506 }
1507
1508 Value sgId =
1509 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1510 auto sgOffsets =
1511 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1512 if (failed(sgOffsets))
1513 return failure();
1514
1515 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1516 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1517
1518 // In each dimension, each subgroup computes its local mask size as:
1519 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
1520 SmallVector<Value> newCreateMaskOps;
1521 for (auto offsetSet : *sgOffsets) {
1522 SmallVector<Value> maskOperands;
1523
1524 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1525 Value dimSizeVal =
1526 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1527 Value offset = offsetSet[i];
1528 Value adjustedMaskSize =
1529 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1530 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1531 Value nonNegative =
1532 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1533 Value sgMaskSize =
1534 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1535 maskOperands.push_back(sgMaskSize);
1536 }
1537
1538 auto newCreateMaskOp =
1539 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1540 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1541 }
1542
1543 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1544 return success();
1545 }
1546};
1547
1548using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1549using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1550} // namespace
1551
1552namespace mlir {
1553namespace xegpu {
1555 patterns
1556 .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1557 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1558 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1559 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1560 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1561 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1562 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1563 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1564 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1565 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1566 patterns.getContext());
1567}
1568} // namespace xegpu
1569} // namespace mlir
1570
1571namespace {
// Pass driver that applies the workgroup-to-subgroup distribution patterns
// defined above.
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
1576} // namespace
1577
1578void XeGPUWgToSgDistributePass::runOnOperation() {
1579
1580 Operation *op = getOperation();
1582 signalPassFailure();
1583 return;
1584 }
1585
1586 // Track existing UnrealizedConversionCastOps
1587 SmallVector<Operation *> existingCastOps;
1588 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1589 existingCastOps.push_back(castOp.getOperation());
1590 });
1591
1592 {
1593 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1594 // VectorType operands. This first converts such operands to
1595 // RankedTensorType, propagates the layout attribute into the encoding
1596 // attribute, and finally converts the RankedTensorType to VectorType based
1597 // on the encoding.
1598
1599 TypeConverter converter;
1600 converter.addConversion([&](Type type) -> Type { return type; });
1601 converter.addConversion(
1602 [&](RankedTensorType type,
1603 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1604 // Only convert RankedTensorTypes that carry an XeGPU layout encoding.
1605 // Plain tensors (e.g. tensor<?xi32>) have no XeGPU encoding and must
1606 // not be converted: VectorType does not support dynamic dimensions.
1607 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1608 type.getEncoding());
1609 if (!encoding)
1610 return std::nullopt;
1611
1612 Type elemTy = type.getElementType();
1613 ArrayRef<int64_t> shape = type.getShape();
1614
1615 int count;
1616 SmallVector<int64_t> subShape;
1617 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1618
1619 auto newTy = VectorType::get(subShape, elemTy);
1620 result.append(count, newTy);
1621 return success();
1622 });
1623
1625 converter);
1626 }
1627
1628 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1629 // as well as XeGPU, Arith, and Vector operations.
1630 MLIRContext *ctx = &getContext();
1631 RewritePatternSet patterns(ctx);
1632 ConversionTarget target(*ctx);
1633 TypeConverter converter;
1634 converter.addConversion([&](Type type) -> Type { return type; });
1635 converter.addConversion(
1636 [&](xegpu::TensorDescType type,
1637 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1638 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
1639 // Only convert WG-level tensor descs. SG-level or layout-less types
1640 // are already legal and should pass through unchanged.
1641 if (!layout || !layout.isForWorkgroup())
1642 return std::nullopt;
1643
1644 Type elemTy = type.getElementType();
1645 ArrayRef<int64_t> shape = type.getShape();
1646
1647 int count;
1648 SmallVector<int64_t> subShape;
1649 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1650
1651 layout = layout.dropSgLayoutAndData();
1652
1653 auto newTy = xegpu::TensorDescType::get(
1654 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1655 result.append(count, newTy);
1656 return success();
1657 });
1658
1659 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1660 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1661 return createOp.getType();
1662 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1663 return loadOp.getTensorDescType();
1664 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1665 return storeOp.getTensorDescType();
1666 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1667 return updateOp.getType();
1668 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1669 return prefetchOp.getTensorDescType();
1670 return xegpu::TensorDescType();
1671 };
1672
1673 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1674 return !layout || !layout.isForWorkgroup();
1675 };
1676
1677 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1678 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1679 xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1680 auto tdescTy = getTensorDescType(op);
1681 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1682 return isLegal(layout);
1683 });
1684
1685 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1686 auto layout = op.getLayoutCdAttr();
1687 return isLegal(layout);
1688 });
1689
1690 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1691 [=](xegpu::LoadMatrixOp op) -> bool {
1692 return isLegal(op.getLayoutAttr());
1693 });
1694
1695 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1696 [=](xegpu::StoreMatrixOp op) -> bool {
1697 return isLegal(op.getLayoutAttr());
1698 });
1699
1700 target.addDynamicallyLegalOp<arith::ConstantOp>(
1701 [=](arith::ConstantOp op) -> bool {
1702 auto vecType = dyn_cast<VectorType>(op.getType());
1703 if (!vecType)
1704 return true;
1705
1706 auto layout =
1707 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1708 return isLegal(layout);
1709 });
1710
1711 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1712 vector::TransposeOp, vector::BroadcastOp,
1713 vector::MultiDimReductionOp,
1714 vector::ConstantMaskOp, vector::CreateMaskOp>(
1715 [=](Operation *op) -> bool {
1716 // Check for either a SliceAttr or LayoutAttr on the result.
1717 auto layout =
1718 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1719 return isLegal(layout);
1720 });
1721
1722 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1723 [=](xegpu::LoadGatherOp op) -> bool {
1724 auto layout = op.getLayoutAttr();
1725 return isLegal(layout);
1726 });
1727
1728 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1729 [=](xegpu::StoreScatterOp op) -> bool {
1730 auto layout = op.getLayoutAttr();
1731 return isLegal(layout);
1732 });
1733
1734 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1735 [=](xegpu::ConvertLayoutOp op) -> bool {
1736 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1737 });
1738
1739 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1740 [=](Operation *op) -> std::optional<bool> {
1741 // Only handle elementwise mappable ops
1743 return true;
1744
1745 VectorType resultType =
1746 dyn_cast<VectorType>(op->getResult(0).getType());
1747 if (!resultType)
1748 return true;
1749
1750 // Check if all operands are vectors of the same shape
1751 // TODO: Support other types.
1752 for (Value operand : op->getOperands()) {
1753 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1754 if (!operandType || operandType.getShape() != resultType.getShape()) {
1755 return true;
1756 }
1757 }
1758
1759 xegpu::DistributeLayoutAttr layout =
1761 return isLegal(layout);
1762 });
1763
1764 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1765 [=](UnrealizedConversionCastOp op) {
1766 return llvm::is_contained(existingCastOps, op.getOperation());
1767 });
1768
1769 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1770
1772 target);
1774 if (failed(
1775 applyPartialConversion(getOperation(), target, std::move(patterns))))
1776 return signalPassFailure();
1777
1778 xegpu::removeTemporaryLayoutAttrs(getOperation());
1779}
return success()
b getContext())
#define mul(a, b)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's nested regions.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
Get and set the distribute layout attribute for non-anchor operations (and the offsets/masks of load/store ops).
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.