1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14#include "mlir/Dialect/Index/IR/IndexOps.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18#include "mlir/Dialect/Utils/IndexingUtils.h"
19#include "mlir/Dialect/Vector/IR/VectorOps.h"
20#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
21#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
22#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/STLExtras.h"
25#include <optional>
26
27namespace mlir {
28namespace xegpu {
29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31} // namespace xegpu
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Retrieve the sg_id_range RangeAttr, if specified on an enclosing scf.if op.
39static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
40 Operation *parent = op->getParentOfType<scf::IfOp>();
41 while (parent) {
42 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->getAttr("sg_id_range")))
44 return attr;
45 parent = parent->getParentOfType<scf::IfOp>();
46 }
47 return {};
48}
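// Illustrative IR (attribute syntax assumed from XeGPU's RangeAttr): the
// attribute is looked up on enclosing scf.if ops, e.g.
//   scf.if %cond {
//     ...
//   } {sg_id_range = #xegpu.range<[2, 8]>}
// restricts the body to subgroup ids 2..7.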
49
50static std::pair<SmallVector<int64_t>, int>
51getSgShapeAndCount(ArrayRef<int64_t> shape,
52 xegpu::DistributeLayoutAttr layout) {
53 int count = 1;
54 SmallVector<int64_t> sgShape(shape.begin(), shape.end());
55 auto distributedShape = layout.computeDistributedShape(
56 SmallVector<int64_t>(shape.begin(), shape.end()));
57 if (failed(distributedShape))
58 return std::make_pair(sgShape, count);
59 auto sgData = layout.getEffectiveSgDataAsInt();
60 count = computeProduct(distributedShape.value()) / computeProduct(sgData);
61 return std::make_pair(sgData, count);
62}
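// For example (using the distribution described below): with
// shape = [24, 24], sg_layout = [4, 4] and sg_data = [2, 2], each subgroup
// covers a 6x6 slice of the workgroup shape in round-robin fashion, so the
// function returns sgShape = [2, 2] and count = 36 / 4 = 9.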
63
64/// Utility helper for deriving a list of offsets for each sub-TensorDesc
65/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
66/// the associated distribute layout attribute, the shape, the subgroup id,
67/// and the original offsets of the op.
68template <typename OpType,
69 typename = std::enable_if_t<llvm::is_one_of<
70 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
71 xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
72static LogicalResult
73genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
74 SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
75 Location loc = op.getLoc();
76 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
77 // Not applicable to ops without offset operands.
78 if (origOffsets.empty())
79 return failure();
80
81 // Matrix ops carry the layout attribute directly; Nd ops take it from their descriptor.
82 xegpu::DistributeLayoutAttr layout;
83 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
84 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
85 layout = op.getLayoutAttr();
86 } else {
87 layout = op.getDescLayoutAttr();
88 }
89
90 // Not applicable to ops without a workgroup-level layout attribute.
91 if (!layout || !layout.isForWorkgroup())
92 return failure();
93
94 Value sgId =
95 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
96
97 // verify and adjust the sgId if the range specifier is present
98 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
99 if (sgIdRange) {
100 int64_t startOfRange = sgIdRange.getStart().getInt();
101 int64_t endOfRange = sgIdRange.getEnd().getInt();
102 // verify the RangeAttr against the layout attribute
103 if (layout.getNumSubgroups() != endOfRange - startOfRange)
104 return rewriter.notifyMatchFailure(
105 op, "sg_layout size must match the sg_id_range");
106 // adjust the sgId if necessary
107 if (startOfRange > 0) {
108 Value startOfRangeVal =
109 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
110 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
111 }
112 }
113
114 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
115 // descriptors to be accessed, based on the layout information.
116 ArrayRef<int64_t> wgShape = op.getDataShape();
117 auto maybeDescOffsets =
118 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
119 if (failed(maybeDescOffsets))
120 return failure();
121
122 // Compute the final global offsets for each accessed sub-tensor
123 // or sub-memory descriptor.
124 for (const auto &sgOffsets : *maybeDescOffsets) {
125 SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
126 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
127 offsetsList.push_back(std::move(newOffsets));
128 }
129
131 return success();
132}
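// Sketch of the derived offsets (assuming row-major subgroup ordering): for
// wgShape = [24, 24], sg_layout = [4, 4], sg_data = [2, 2], subgroup 5
// (row 1, col 1) gets the base coordinates (2, 2), (2, 10), (2, 18),
// (10, 2), ... - one per distribution unit - each added elementwise to the
// op's original offsets.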
133
134/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
135/// from a workgroup descriptor. It replaces the offsets and sizes with
136/// appropriate values for the subgroup.
137/// It uses round-robin assignment to distribute the work to the subgroups.
138/// The following create_nd_tdesc operation:
139/// %tdesc = xegpu.create_nd_tdesc %src : memref<24x24xf32>
140/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
141/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
142/// is converted to 9 subgroup-level operations based on the sg_layout &
143/// sg_data:
144/// %tdesc = xegpu.create_nd_tdesc %src : memref<24x24xf32> ->
145/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
146/// lane_data = [1, 1]>>
147///
148/// The sg_layout and sg_data attributes are dropped after the pass as they are
149/// no longer needed.
150///
151/// 24x24 matrix distribution example:
152/// sg_layout = [4, 4], sg_data = [2, 2]
153/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
154/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
155///
156/// +-----+-----+-----+
157/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles across
158/// +-----+-----+-----+
159/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles down
160/// +-----+-----+-----+
161/// | 8x8 | 8x8 | 8x8 |
162/// +-----+-----+-----+
163///
164/// Each 8x8 tile is further subdivided among subgroups:
165/// +-----------------+
166/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
167/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
168/// | 2x2 2x2 2x2 2x2 |
169/// | 2x2 2x2 2x2 2x2 |
170/// +-----------------+
171///
172/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
173/// 9 distribution units (3x3) in total. Hence the 9 subgroup-level operations.
174
175/// The pass currently keeps the entire distribution logic in the
176/// WgToSgCreateNdOp pattern; all the other ops simply follow.
177/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
178/// ops in the pass.
181struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
182 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
183
184 LogicalResult
185 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter) const override {
187
188 Location loc = op.getLoc();
189 MLIRContext *ctx = op.getContext();
190 xegpu::TensorDescType tdescTy = op.getType();
191 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
192 if (!layout || !layout.isForWorkgroup())
193 return failure();
194
195 Type elemTy = tdescTy.getElementType();
196 ArrayRef<int64_t> wgShape = tdescTy.getShape();
197
198 SmallVector<int64_t> sgShape;
199 int count;
200 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
201 xegpu::TensorDescType newTdescTy =
202 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
203 layout.dropSgLayoutAndData());
204
205 SmallVector<Value> newCreateNdOps(count);
206 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
207 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
208 op.getSource(), op.getMixedSizes(),
209 op.getMixedStrides());
210 });
211
212 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
213 return success();
214 }
215};
216
217/// This pattern transforms the LoadNdOp to load subgroup data.
218struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
219 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
220 LogicalResult
221 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter) const override {
223
224 SmallVector<SmallVector<OpFoldResult>> offsetsList;
225 if (failed(genOffsetsList(rewriter, op, offsetsList)))
226 return failure();
227
228 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
229 if (layout)
230 layout = layout.dropSgLayoutAndData();
231 SmallVector<Value> newOps;
232 for (auto [tdesc, offsets] :
233 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
234 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
235 VectorType newResTy =
236 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
237 auto newOp = xegpu::LoadNdOp::create(
238 rewriter, op.getLoc(), newResTy, tdesc, offsets,
239 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
240 op.getL2HintAttr(), op.getL3HintAttr(), layout);
241 newOps.push_back(newOp);
242 }
243 rewriter.replaceOpWithMultiple(op, {newOps});
244
245 return success();
246 }
247};
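// Illustrative conversion (layout syntax abbreviated): a workgroup-level
//   %v = xegpu.load_nd %tdesc[%x, %y]
//       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
//         sg_data = [32, 32], ...>> -> vector<256x128xf32>
// becomes, per subgroup, a load at subgroup-relative offsets:
//   %v0 = xegpu.load_nd %tdesc0[%x0, %y0]
//       : !xegpu.tensor_desc<32x32xf32, ...> -> vector<32x32xf32>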
248
249/// This pattern transforms the StoreNdOp to store subgroup data.
250struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
251 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
252 LogicalResult
253 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter) const override {
255 SmallVector<SmallVector<OpFoldResult>> offsetsList;
256 if (failed(genOffsetsList(rewriter, op, offsetsList)))
257 return failure();
258
259 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
260 if (layout)
261 layout = layout.dropSgLayoutAndData();
262 for (auto [v, tdesc, offsets] :
263 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
264 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
265 op.getL1HintAttr(), op.getL2HintAttr(),
266 op.getL3HintAttr(), layout);
267 }
268 rewriter.eraseOp(op);
269
270 return success();
271 }
272};
273
274/// This pattern transforms the PrefetchNdOp to prefetch subgroup data.
275struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
276 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
277 LogicalResult
278 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) const override {
280 SmallVector<SmallVector<OpFoldResult>> offsetsList;
281 if (failed(genOffsetsList(rewriter, op, offsetsList)))
282 return failure();
283
284 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
285 if (layout)
286 layout = layout.dropSgLayoutAndData();
287 for (auto [tdesc, offsets] :
288 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
289 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
290 op.getL1HintAttr(), op.getL2HintAttr(),
291 op.getL3HintAttr(), layout);
292 }
293 rewriter.eraseOp(op);
294
295 return success();
296 }
297};
298
299/// This pattern transforms the DpasOp to work at subgroup level.
300struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
301 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
302 LogicalResult
303 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter) const override {
305 Location loc = op.getLoc();
306 VectorType resultTy = op.getResult().getType();
307 if (resultTy.getRank() != 2)
308 return failure();
309
310 auto layoutCd = op.getLayoutCdAttr();
311 auto layoutA = op.getLayoutAAttr();
312 auto layoutB = op.getLayoutBAttr();
313 if (!layoutCd || !layoutA || !layoutB)
314 return failure();
315 size_t i = 0;
316 SmallVector<Value> newDpasOps;
317 for (auto aVec : adaptor.getLhs()) {
318 for (auto bVec : adaptor.getRhs()) {
319
320 llvm::SmallVector<Value> operands({aVec, bVec});
321 Value tmpC;
322 if (op.getAcc()) {
323 tmpC = adaptor.getAcc()[i++];
324 operands.push_back(tmpC);
325 }
326
327 ArrayRef<int64_t> aVecShape =
328 cast<VectorType>(aVec.getType()).getShape();
329 ArrayRef<int64_t> bVecShape =
330 cast<VectorType>(bVec.getType()).getShape();
331 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
332 resultTy.getElementType());
333 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
334 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
335 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
336 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
337
338 newDpasOps.push_back(newDpasOp);
339 }
340 }
341 rewriter.replaceOpWithMultiple(op, {newDpasOps});
342 return success();
343 }
344};
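// Note: the nested loop above forms a small block matmul over fragments.
// For example, with two A fragments and two B fragments per subgroup, four
// subgroup-level dpas ops are created, consuming accumulator fragments in
// row-major order.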
345
346/// This pattern transforms the DpasMxOp to work at subgroup level.
347struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
348 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
349 LogicalResult
350 matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
351 ConversionPatternRewriter &rewriter) const override {
352
353 Location loc = op.getLoc();
354 VectorType resultTy = op.getResult().getType();
355
356 if (resultTy.getRank() != 2)
357 return failure();
358
359 auto layoutCd = op.getLayoutCdAttr();
360 auto layoutA = op.getLayoutAAttr();
361 auto layoutB = op.getLayoutBAttr();
362 auto layoutAScale = op.getLayoutAScaleAttr();
363 auto layoutBScale = op.getLayoutBScaleAttr();
364
365 if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
366 return failure();
367
368 size_t index_c = 0;
369 SmallVector<Value> newDpasMxOps;
370 for (auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
371 for (auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
372 Value accVal = (op.getAcc()) ? adaptor.getAcc()[index_c++] : Value();
373 Value scaleAVal =
374 (op.getScaleA()) ? adaptor.getScaleA()[index_a] : Value();
375 Value scaleBVal =
376 (op.getScaleB()) ? adaptor.getScaleB()[index_b] : Value();
377
378 ArrayRef<int64_t> aVecShape =
379 cast<VectorType>(aVec.getType()).getShape();
380 ArrayRef<int64_t> bVecShape =
381 cast<VectorType>(bVec.getType()).getShape();
382 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
383 resultTy.getElementType());
384 auto newDpasMxOp = xegpu::DpasMxOp::create(
385 rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
386 layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
387 layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
388 layoutBScale.dropSgLayoutAndData());
389
390 newDpasMxOps.push_back(newDpasMxOp);
391 }
392 }
393 rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
394 return success();
395 }
396};
397
398/// This pattern transforms vector.broadcast ops to work at subgroup level.
399struct WgToSgVectorBroadcastOp
400 : public OpConversionPattern<vector::BroadcastOp> {
401 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
402
403 LogicalResult
404 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override {
406
407 VectorType resultType = op.getResult().getType();
408 ArrayRef<int64_t> wgShape = resultType.getShape();
409
410 xegpu::DistributeLayoutAttr layout =
411 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
412 if (!layout || !layout.isForWorkgroup())
413 return failure();
414
415 SmallVector<int64_t> sgShape;
416 int count;
417 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
418 VectorType newResultType =
419 VectorType::get(sgShape, resultType.getElementType());
420
421 SmallVector<Value> newBroadcastOps;
422 auto distSource = adaptor.getOperands().front();
423 int numDistributions = count / distSource.size();
424 for (int i = 0; i < numDistributions; ++i) {
425 for (auto operand : distSource) {
426 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
427 newResultType, operand);
428
429 newBroadcastOps.push_back(newBroadcast.getResult());
430 }
431 }
432 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
433 return success();
434 }
435};
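// For example (assumed shapes): broadcasting a scalar to vector<64x64xf32>
// with sg_layout = [2, 2], sg_data = [16, 16] gives count = 4; with a single
// distributed source value, numDistributions = 4, so four vector.broadcast
// ops producing vector<16x16xf32> are emitted per subgroup.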
436
437// This pattern transforms elementwise ops to work at subgroup level.
438struct WgToSgElementwiseOp : public ConversionPattern {
439 WgToSgElementwiseOp(MLIRContext *ctx)
440 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
441
442 LogicalResult
443 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
444 ConversionPatternRewriter &rewriter) const override {
445 // Only match ops with elementwise trait and single result.
446 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
447 return failure();
448
449 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
450 assert(resultType && "Expected result to be a VectorType");
451
452 ArrayRef<int64_t> wgShape = resultType.getShape();
453
454 xegpu::DistributeLayoutAttr layout =
455 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
456 if (!layout || !layout.isForWorkgroup())
457 return failure();
458
459 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
460
461 size_t numVariants = operands.empty() ? 0 : operands.front().size();
462
463 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
464 return operandVec.size() != numVariants;
465 }))
466 return failure();
467
468 SmallVector<Value> newResults;
469 VectorType newResultType =
470 VectorType::get(sgShape, resultType.getElementType());
471
472 for (size_t i = 0; i < numVariants; ++i) {
473 SmallVector<Value> opOperands;
474 for (auto &operandVec : operands)
475 opOperands.push_back(operandVec[i]);
476
477 OperationState state(op->getLoc(), op->getName());
478 state.addOperands(opOperands);
479 state.addTypes(newResultType);
480 state.addAttributes(op->getAttrs());
481 Operation *newOp = rewriter.create(state);
483 newResults.push_back(newOp->getResult(0));
484 }
485
486 rewriter.replaceOpWithMultiple(op, {newResults});
487 return success();
488 }
489};
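// Illustrative conversion: with a result layout of sg_layout = [8, 4],
// sg_data = [32, 32],
//   %r = arith.addf %a, %b : vector<256x128xf32>
// is rewritten into one op per distributed operand tuple, e.g.
//   %r0 = arith.addf %a0, %b0 : vector<32x32xf32>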
490
491// clang-format off
492// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
493// If input_layout and target_layout have identical sg_layout and sg_data,
494// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
495// dropped. For example:
496// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
497// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
498// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
499// becomes:
500// #a = #xegpu.layout<inst_data = [16, 16]>
501// #b = #xegpu.layout<inst_data = [8, 16]>
502// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
503// (vector<16x16xf32> is determined by sg_data = [16, 16])
504//
505// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
506// For example:
507// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
508// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
509// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
510// is lowered to:
511// #a = #xegpu.layout<inst_data = [16, 16]>
512// #b = #xegpu.layout<inst_data = [8, 16]>
513// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16xf32>, mem_desc<32x64xf32>
514// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
515// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
516// clang-format on
517struct WgToSgConvertLayoutOp
518 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
519 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
520
521 LogicalResult
522 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
523 ConversionPatternRewriter &rewriter) const override {
524 Location loc = op.getLoc();
525 auto inputLayout = op.getInputLayout();
526 auto targetLayout = op.getTargetLayout();
527
528 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
529 !targetLayout.isForWorkgroup())
530 return rewriter.notifyMatchFailure(
531 op, "Input and target layouts must have subgroup layout");
533 Type resultType = op.getResult().getType();
534 if (resultType.isIntOrFloat()) {
535 rewriter.replaceOp(op, op.getSource());
536 assert(!inputLayout.dropSgLayoutAndData() &&
537 !targetLayout.dropSgLayoutAndData() &&
538 "unexpected layout attributes for scalar type");
539 return success();
540 }
542 ArrayRef<int64_t> wgShape = cast<VectorType>(resultType).getShape();
543 SmallVector<int64_t> inputSgLayout =
544 inputLayout.getEffectiveSgLayoutAsInt();
545 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
546 SmallVector<int64_t> targetSgLayout =
547 targetLayout.getEffectiveSgLayoutAsInt();
548 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
549
550 // Fast path: if sg_layout and sg_data are identical, no SLM needed
551 SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
552 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
553 xegpu::LayoutKind::Subgroup)) {
554 inputLayout = inputLayout.dropSgLayoutAndData();
555 targetLayout = targetLayout.dropSgLayoutAndData();
556
557 SmallVector<Value> newOps(adaptor.getSource());
558 if (inputLayout && targetLayout) {
559 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
560 auto newOp = xegpu::ConvertLayoutOp::create(
561 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
562 newOps[i] = newOp;
563 }
564 }
565 rewriter.replaceOpWithMultiple(op, {newOps});
566 return success();
567 }
568
569 // SLM path: layouts differ, need cross-subgroup data redistribution
570 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
571
572 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
573
574 // Calculate SLM size requirements
575 auto bitWidth = elemTy.getIntOrFloatBitWidth();
576 auto bytesPerElement = bitWidth / 8;
577 auto slmSize = computeProduct(slmShape) * bytesPerElement;
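// For example (illustrative): redistributing a 32x64 tile of f32 requires
// 32 * 64 * 4 = 8192 bytes of SLM.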
578
579 // Allocate SLM
580 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
581 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
582
583 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
584 elemTy, nullptr);
585 auto memDesc =
586 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
587
588 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
589 rewriter.getIndexType(), nullptr);
590
591 // STORE PHASE: Each subgroup stores in SLM using input layout
592 auto storeCoords = inputLayout.computeDistributedCoords(
593 rewriter, loc, sgId.getResult(), wgShape);
594 if (failed(storeCoords))
595 return failure();
596
597 // Store to SLM
598 for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
599 SmallVector<OpFoldResult> storeMatrixOffsets;
600 for (Value coord : coords) {
601 storeMatrixOffsets.push_back(coord);
602 }
603 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
604 storeMatrixOffsets, nullptr /*layout*/);
605 }
606
607 gpu::BarrierOp::create(rewriter, loc);
608
609 // LOAD PHASE: Each target subgroup loads from SLM using target layout
610 auto loadCoords = targetLayout.computeDistributedCoords(
611 rewriter, loc, sgId.getResult(), wgShape);
612 if (failed(loadCoords))
613 return failure();
614
615 VectorType loadType = VectorType::get(targetSgData, elemTy);
616
617 // Load vectors from SLM
618 SmallVector<Value> finalResults;
619 for (auto coords : *loadCoords) {
620 SmallVector<OpFoldResult> loadMatrixOffsets;
621 for (Value coord : coords) {
622 loadMatrixOffsets.push_back(coord);
623 }
624 auto loadOp = xegpu::LoadMatrixOp::create(
625 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
626 targetLayout.dropSgLayoutAndData());
627
628 finalResults.push_back(loadOp.getResult());
629 }
630
631 rewriter.replaceOpWithMultiple(op, {finalResults});
632 return success();
633 }
634};
635
636// Handles UnrealizedConversionCastOp generated during
637// SCFStructuralTypeConversions (step 1). This op may appear as either a
638// target or source materialization for Vector values, e.g.:
639// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
640// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
641// it could be either 1:N or N:1 cast. In both cases, the pattern
642// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
643// for example, the following scf::forOp
644// ```
645// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
646// %n = use(%arg1): vector<128x128xf16>
647// scf.yield %n : vector<128x128xf16>
648// }
649// ```
650// Could be converted to:
651// ```
652// %1 = unrealized_conversion_cast %0
653// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
654// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
655// -> (vector<16x16xf16>, vector<16x16xf16>) {
656// %m = unrealized_conversion_cast %arg1, %arg2
657// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
658// %n = use(%m): vector<128x128xf16>
659// %b = unrealized_conversion_cast %n
660// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
661// scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
662// }
663// %cast = unrealized_conversion_cast %for#0, %for#1
664// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
665// ```
666// TODO: remove it when context-aware type converter is ready.
667struct UnrealizedConversionCastOpPattern
668 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
669 using OpConversionPattern<
670 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
671
672 mlir::LogicalResult
673 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
674 ConversionPatternRewriter &rewriter) const override {
675 SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
676
677 auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
678 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
679
680 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
681 !llvm::all_equal(ValueRange(inputs).getTypes()))
682 return failure();
683
684 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
685 // It is generated by source materialization (e.g., inits to scf forOp).
686 // The input values provided by the adaptor should already be distributed,
687 // and their types should correspond exactly to the result types of the
688 // operation.
689 if (op.getNumOperands() == 1 &&
690 llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
691 rewriter.replaceOp(op, inputs);
692 return success();
693 }
694
695 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
696 // It is generated by target materialization (e.g., arguments/results
697 // of scf forOp). All input values must have the same vector type, and
698 // their shape must be evenly divisible by the output vector's shape
699 // (determined by the nature of the workgroup to subgroup distribution).
700 // TODO: it is not safe to do such forward, since such N:1 cast could be
701 // from others.
702 if (op.getNumResults() == 1 &&
703 computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
704 rewriter.replaceOpWithMultiple(op, {inputs});
705 return success();
706 }
707
708 return mlir::failure();
709 }
710};
711
712// This pattern distributes an arith.constant op into subgroup-level constants.
713struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
714 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
715
716 LogicalResult
717 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
718 ConversionPatternRewriter &rewriter) const override {
719 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
720 auto vecType = dyn_cast<VectorType>(op.getType());
721 if (!vecAttr || !vecType)
722 return failure();
723
724 xegpu::DistributeLayoutAttr layout =
725 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
726 if (!layout || !layout.isForWorkgroup())
727 return failure();
728
729 ArrayRef<int64_t> wgShape = vecType.getShape();
730 SmallVector<int64_t> sgShape;
731 int count;
732 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
733
734 auto newType = VectorType::get(sgShape, vecType.getElementType());
735 Location loc = op.getLoc();
736 auto eltType = vecType.getElementType();
737
738 if (vecAttr.isSplat()) {
739 // Splat: single value for all subgroups
740 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
741 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
742 SmallVector<Value> newConstOps;
743 for (int i = 0; i < count; ++i) {
744 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
745 newConstOps.push_back(cstOp);
746 }
747 rewriter.replaceOpWithMultiple(op, {newConstOps});
748 return success();
749 } else if (sgShape == wgShape) { // if the entire vector is shared by all
750 // subgroups, don't distribute
751 auto newConstOp =
752 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
753 rewriter.replaceOp(op, newConstOp);
754 return success();
755 } else {
756 // Non-splat constant: only index element types and 1D/2D vectors with
757 // constant strides are supported.
758 // TODO: support other cases that require SLM access.
759 if (!eltType.isIndex())
760 return rewriter.notifyMatchFailure(
761 op, "Unsupported element type for non-splat constant op.");
762
763 if (wgShape.size() > 2)
764 return rewriter.notifyMatchFailure(
765 op, "Only 1D & 2D vector constant supported");
766
767 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
768 int64_t rowStride = 0, colStride = 0;
769 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
770 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
771
772 // Compute colStride and rowStride, and check for constant strides.
773 if (cols > 1) {
774 colStride = cast<IntegerAttr>(values[1]).getInt() -
775 cast<IntegerAttr>(values[0]).getInt();
776 }
777 if (rows > 1) {
778 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
779 cast<IntegerAttr>(values[0]).getInt();
780 }
781
782 for (int64_t r = 0; r < rows; ++r) {
783 for (int64_t c = 0; c < cols; ++c) {
784 int64_t idx = r * cols + c;
785 // Check column stride
786 if (c > 0 && cols > 1) {
787 int64_t prevIdx = r * cols + (c - 1);
788 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
789 cast<IntegerAttr>(values[prevIdx]).getInt();
790 if (diff != colStride)
791 return rewriter.notifyMatchFailure(
792 op, "Non-constant column stride in constant op.");
793 }
794 // Check row stride
795 if (r > 0 && rows > 1) {
796 int64_t prevIdx = (r - 1) * cols + c;
797 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
798 cast<IntegerAttr>(values[prevIdx]).getInt();
799 if (diff != rowStride)
800 return rewriter.notifyMatchFailure(
801 op, "Non-constant row stride in constant op.");
802 }
803 }
804 }
805
806 // Create a constant for the base tile.
807 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
808 // For 1D case, extract the first sgShape[0] elements.
809 SmallVector<Attribute> baseTileValues;
810 int baseTileCols = sgShape[sgShape.size() - 1];
811 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
812 for (int64_t r = 0; r < baseTileRows; ++r) {
813 for (int64_t c = 0; c < baseTileCols; ++c) {
814 baseTileValues.push_back(values[r * cols + c]);
815 }
816 }
817
818 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
819 baseTileValues);
820 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
821
822 // Get subgroup id
823 Value sgId =
824 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
825 auto sgOffsets =
826 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
827 if (failed(sgOffsets))
828 return failure();
829
830 SmallVector<Value, 2> strideConsts;
831 strideConsts.push_back(
832 arith::ConstantIndexOp::create(rewriter, loc, colStride));
833 if (rows > 1)
834 strideConsts.insert(
835 strideConsts.begin(),
836 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
837
838 SmallVector<Value> newConstOps;
839 for (auto offsets : *sgOffsets) {
840 // Multiply offset with stride, broadcast it and add to baseConstVec
841 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
842 for (size_t i = 0; i < strideConsts.size(); ++i) {
843 Value mul =
844 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
845 offsets[i], strideConsts[i]);
846 mulOffset = arith::AddIOp::create(
847 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
848 }
849 // Broadcast to baseConstVec size
850 auto bcastOffset = vector::BroadcastOp::create(
851 rewriter, loc, baseConstVec.getType(), mulOffset);
852 auto finalConst =
853 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
854 newConstOps.push_back(finalConst);
855 }
856 rewriter.replaceOpWithMultiple(op, {newConstOps});
857 return success();
858 }
859 }
860};
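// Worked example (1D, assumed layout):
//   %c = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56]> : vector<8xindex>
// with sg_layout = [4], sg_data = [2] has colStride = 8. The base tile is
// dense<[0, 8]> : vector<2xindex>, and each subgroup adds
// broadcast(offset * 8); e.g. subgroup 1 (offset 2) produces [16, 24].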
861
862// This pattern transforms the LoadGatherOp with explicit offsets to load
863// subgroup data
864struct WgToSgLoadGatherOp : public OpConversionPattern<xegpu::LoadGatherOp> {
865 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
866 LogicalResult
867 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
868 ConversionPatternRewriter &rewriter) const override {
869
870 Location loc = op.getLoc();
871 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
872 if (!resultType)
873 return failure();
874 ArrayRef<int64_t> wgShape = resultType.getShape();
875
876 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
877
878 if (!layout || !layout.isForWorkgroup())
879 return failure();
880
881 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
882
883 // The offsets need to be distributed
884 auto offsetsVecType =
885 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
886 auto maskVecType =
887 dyn_cast<VectorType>(adaptor.getMask().front().getType());
888 if (!offsetsVecType || !maskVecType ||
889 offsetsVecType.getShape() != maskVecType.getShape()) {
890 return rewriter.notifyMatchFailure(op,
891 "offsets have not been distributed");
892 }
893
894 SmallVector<Value> newLoadOps;
895 auto chunkSizeAttr =
896 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
897 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
898 for (auto [offsets, mask] :
899 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
900 auto newLayout = layout.dropSgLayoutAndData();
901 auto newLoadOp = xegpu::LoadGatherOp::create(
902 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
903 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
904 newLayout);
905 newLoadOps.push_back(newLoadOp);
906 }
907 rewriter.replaceOpWithMultiple(op, {newLoadOps});
908 return success();
909 }
910};
911
912// This pattern transforms the StoreScatterOp with explicit offsets to store
913// subgroup data
914struct WgToSgStoreScatterOp
915 : public OpConversionPattern<xegpu::StoreScatterOp> {
916 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
917 LogicalResult
918 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter) const override {
920
921 Location loc = op.getLoc();
922 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
923 if (!valueType)
924 return failure();
925
926 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
927
928 if (!layout || !layout.isForWorkgroup())
929 return failure();
930
931 // The offsets need to be distributed
932 auto offsetsVecType =
933 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
934 auto maskVecType =
935 dyn_cast<VectorType>(adaptor.getMask().front().getType());
936 if (!offsetsVecType || !maskVecType ||
937 offsetsVecType.getShape() != maskVecType.getShape()) {
938 return rewriter.notifyMatchFailure(op,
939 "offsets have not been distributed");
940 }
941
942 auto chunkSizeOpt = op.getChunkSize();
943 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
944 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
945 for (auto [val, offs, mask] : llvm::zip(
946 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
947 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
948 mask, chunkSizeAttr, op.getL1HintAttr(),
949 op.getL2HintAttr(), op.getL3HintAttr(),
950 layout.dropSgLayoutAndData());
951 }
952 rewriter.eraseOp(op);
953 return success();
954 }
955};
956
957struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
958 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
959 LogicalResult
960 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter) const override {
962
963 SmallVector<SmallVector<OpFoldResult>> offsetsList;
964 if (failed(genOffsetsList(rewriter, op, offsetsList)))
965 return failure();
966
967 ArrayRef<int64_t> wgShape = op.getDataShape();
968 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
969 assert(valueTy && "the value type must be vector type!");
970 Type elemTy = valueTy.getElementType();
971
972 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
973 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
974 VectorType newResTy = VectorType::get(sgShape, elemTy);
975 SmallVector<Value> newOps;
976 for (auto offsets : offsetsList) {
977 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
978 op.getMemDesc(), offsets,
979 layout.dropSgLayoutAndData());
980 newOps.push_back(newOp);
981 }
982 rewriter.replaceOpWithMultiple(op, {newOps});
983
984 return success();
985 }
986};
987
988struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
989 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
990 LogicalResult
991 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter) const override {
993
994 SmallVector<SmallVector<OpFoldResult>> offsetsList;
995 if (failed(genOffsetsList(rewriter, op, offsetsList)))
996 return failure();
997
998 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
999 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1000 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1001 offsets, layout.dropSgLayoutAndData());
1002 rewriter.eraseOp(op);
1003 return success();
1004 }
1005};
1006
1007// This pattern distributes the vector.step op to work at subgroup level.
1008struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1009 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1010 LogicalResult
1011 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1012 ConversionPatternRewriter &rewriter) const override {
1013 xegpu::DistributeLayoutAttr layout =
1014 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1015 if (!layout || !layout.isForWorkgroup())
1016 return failure();
1017
1018 Location loc = op.getLoc();
1019 VectorType type = op.getResult().getType();
1020 auto wgShape = type.getShape();
1021 SmallVector<int64_t> sgShape =
1022 getSgShapeAndCount(wgShape, layout).first;
1025
1026 Value sgId =
1027 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1028 auto sgOffsets =
1029 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1030 if (failed(sgOffsets))
1031 return failure();
1032
1033 VectorType newTy = type.cloneWith(sgShape, type.getElementType());
1034 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1035 SmallVector<Value> newOps;
1036 for (auto offsets : *sgOffsets) {
1037 // Broadcast the offset scalar to a vector & add to the base steps
1038 auto bcastOffset =
1039 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1040 auto finalSteps =
1041 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1042 newOps.push_back(finalSteps);
1043 }
1044
1045 rewriter.replaceOpWithMultiple(op, {newOps});
1046 return success();
1047 }
1048};
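// Worked example (assumed layout): %s = vector.step : vector<128xindex> with
// sg_layout = [4], sg_data = [32] creates a base vector.step : vector<32xindex>;
// subgroup k then adds broadcast(32 * k), yielding [32k, 32k+1, ..., 32k+31].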
1049
1050// This pattern transforms vector.shape_cast ops to work at subgroup level.
1051struct WgToSgVectorShapeCastOp
1052 : public OpConversionPattern<vector::ShapeCastOp> {
1053 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1054
1055 LogicalResult
1056 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter) const override {
1058
1059 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1060 if (!resultType)
1061 return failure();
1062
1063 ArrayRef<int64_t> wgShape = resultType.getShape();
1064 xegpu::DistributeLayoutAttr layout =
1065 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1066 if (!layout || !layout.isForWorkgroup())
1067 return failure();
1068
1069 // Check that srcShape and destShape, if they differ, only differ by
1070 // expansion of unit dimensions.
1071 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1072 if (!srcType)
1073 return failure();
1074
1075 ArrayRef<int64_t> srcShape = srcType.getShape();
1076
1077 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1078 SmallVector<int64_t> expandedUnitDims;
1079 if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
1080 xegpu::DistributeLayoutAttr sourceLayout =
1081 xegpu::getTemporaryLayout(op->getOpOperand(0));
1082
1083 if (!sourceLayout.isSliceOf(layout))
1084 return rewriter.notifyMatchFailure(
1085 op, "The ShapeCast op only expands dimensions, the input layout "
1086 "must be a slice of the result layout.");
1087
1088 assert(layoutToDistribute.isEqualTo(
1089 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1090 "The sg_data for unit dimensions should be set as 1");
1091 }
1092
1093 SmallVector<int64_t> sgShape =
1094 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1095 VectorType newResultType =
1096 VectorType::get(sgShape, resultType.getElementType());
1097
1098 SmallVector<Value> newShapeCastOps;
1099 for (auto src : adaptor.getSource()) {
1100 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1101 newResultType, src);
1102 newShapeCastOps.push_back(newShapeCast.getResult());
1103 }
1104
1105 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1106 return success();
1107 }
1108};
1109
1110/// This pattern transforms vector.multi_dim_reduction operations from
1111/// workgroup-level to subgroup-level execution with support for multiple
1112/// reduction dimensions.
1113///
1114/// Steps include:
1115/// 1. LOCAL REDUCTION:
1116/// - Each subgroup performs local reduction on its data slice
1117/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1118/// phase
1119///
1120/// 2. CROSS-SUBGROUP:
1121/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1122/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1123/// - If not needed, adds original accumulator and returns local results
1124///
1125/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1126/// a) SLM Layout Design:
1127/// - Rows: subgroups participating in reduction (product of sg_layout in
1128/// reduction dims)
1129/// - Cols: total result elements across non-reduction dimensions
1130///
1131/// b) Store Phase:
1132/// - Each subgroup stores its local reduction result to SLM
1133/// - Row offset: linearized index of subgroup in reduction dimensions
1134/// - Col offset: linearized index of subgroup in non-reduction dimensions
1135///
1136/// c) Load and Final Reduction Phase:
1137/// - Each subgroup loads a column of data (all reduction participants for
1138/// its position)
1139/// - Performs final reduction along the loaded dimension
1140/// - Adds original accumulator to get final result
1141///
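/// Worked example (assumed shapes): reducing vector<32x64xf32> along dim 0
/// with a parent layout of sg_layout = [4, 8], sg_data = [8, 8]:
///   - Local: each subgroup reduces its 8x8 tile to 1x8 with a ZERO acc.
///   - SLM: shaped [4, 64]; subgroup (r, c) stores its 1x8 partial result
///     at row r, columns [8c, 8c + 8).
///   - Final: each subgroup loads a 4x8 block from SLM, reduces it along
///     dim 0, and adds the original accumulator.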
1142struct WgToSgMultiDimReductionOp
1143 : public OpConversionPattern<vector::MultiDimReductionOp> {
1144 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1145
1146 LogicalResult
1147 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const override {
1149 Location loc = op.getLoc();
1150
1151 VectorType srcType = op.getSourceVectorType();
1152 Type resultTy = op.getResult().getType();
1153 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
1154 bool isScalarResult = !dstVecType;
1155
1156 auto originalSrcShape = srcType.getShape();
1157 Type elemTy = srcType.getElementType();
1158
1159 xegpu::DistributeLayoutAttr layout =
1160 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1161 if (!layout || !layout.isForWorkgroup())
1162 return failure();
1163
1164 auto reductionDims = llvm::to_vector(op.getReductionDims());
1165
1166 // Get sg_layout and sg_data from the parent layout
1167 SmallVector<int64_t> sgLayout;
1168 SmallVector<int64_t> sgData;
1169 xegpu::DistributeLayoutAttr parentLayout;
1170 if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1171 parentLayout = sliceAttr.getParent();
1172 sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
1173 sgData = parentLayout.getEffectiveSgDataAsInt();
1174 } else
1175 return rewriter.notifyMatchFailure(
1176 op, "Reduction should have SliceAttr layout");
1177
1178 // Step 1: perform local subgroup reductions with neutral accumulator
1179 SmallVector<Value> localReductions;
1180 auto sgSrcs = adaptor.getSource();
1181 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1182 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1183 sgSrcType.getShape().end());
1184
1185 // Determine the SG-level destination type.
1186 // For scalar results (all dims reduced), the sg result is also scalar.
1187 // For vector results, compute the sg destination shape from layout.
1188 Type sgDstType;
1189 if (dstVecType) {
1190 auto originalDstShape = dstVecType.getShape();
1191 SmallVector<int64_t> sgDstShape =
1192 getSgShapeAndCount(originalDstShape, layout).first;
1193 sgDstType = VectorType::get(sgDstShape, elemTy);
1194 } else {
1195 sgDstType = elemTy;
1196 }
1197
1198 for (auto sgSrc : sgSrcs) {
1199 // Create neutral accumulator for local reduction
1200 Value neutralLocalAcc = xegpu::createReductionNeutralValue(
1201 rewriter, loc, sgDstType, op.getKind());
1202 // Local reduction with neutral accumulator
1203 auto localReduce = vector::MultiDimReductionOp::create(
1204 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1205 reductionDims);
1206 localReductions.push_back(localReduce.getResult());
1207 }
1208
1209 // Check if cross-subgroup reduction is needed for any reduction dimension
1210 SmallVector<int64_t> crossSgReductionDims;
1211 for (int64_t reductionDim : reductionDims) {
1212 bool needsCrossSubgroupReduction =
1213 (sgLayout[reductionDim] > 1) &&
1214 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1215
1216 if (needsCrossSubgroupReduction) {
1217 crossSgReductionDims.push_back(reductionDim);
1218 }
1219 }
1220
1221 // If no cross-subgroup reduction needed, add accumulator and return
1222 if (crossSgReductionDims.empty()) {
1223 SmallVector<Value> results;
1224 for (auto localResult : localReductions) {
1225 auto finalResult = vector::makeArithReduction(
1226 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1227 results.push_back(finalResult);
1228 }
1229 rewriter.replaceOpWithMultiple(op, {results});
1230 return success();
1231 }
1232
1233 // Step 2: cross-subgroup reduction using SLM - allocating slm memory
1234 auto slmStoreDataShape = sgSrcShape;
1235 for (int64_t dim : reductionDims)
1236 slmStoreDataShape[dim] = 1;
1237 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1238 SmallVector<Value> slmStoreData;
1239 for (auto localResult : localReductions) {
1240 if (isScalarResult) {
1241 // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
1242 slmStoreData.push_back(vector::BroadcastOp::create(
1243 rewriter, loc, slmStoreDataType, localResult));
1244 } else {
1245 slmStoreData.push_back(vector::ShapeCastOp::create(
1246 rewriter, loc, slmStoreDataType, localResult));
1247 }
1248 }
1249 // for reduction dimension, SLM stores partial results from each subgroup
1250 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1251 originalSrcShape.end());
1252 SmallVector<int> slmSgData(sgData.begin(), sgData.end());
1253 SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
1254 for (int dim : reductionDims) {
1255 slmShape[dim] = sgLayout[dim];
1256 slmSgData[dim] = 1;
1257 }
1258 xegpu::LayoutAttr slmStoreLayout =
1259 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1260
1261 // Allocate SLM
1262 auto bitWidth = elemTy.getIntOrFloatBitWidth();
1263 auto bytesPerElement = bitWidth / 8;
1264 auto slmSize = computeProduct(slmShape) * bytesPerElement;
1265 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1266 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1267
1268 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1269 elemTy, nullptr);
1270 auto memDesc =
1271 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1272
1273 // Step 3: Store local results to SLM
1274 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1275 rewriter.getIndexType(), nullptr);
1276
1277 auto slmStoreCoords =
1278 slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1279 if (failed(slmStoreCoords))
1280 return failure();
1281 for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
1282 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1283 xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
1284 coordOfr,
1285 /*layout=*/nullptr);
1286 }
1287
1288 gpu::BarrierOp::create(rewriter, loc);
1289
1290 // Step 4: Load from SLM for final reduction
1291 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1292 for (int64_t dim : reductionDims) {
1293 slmLoadDataShape[dim] = slmShape[dim];
1294 slmSgData[dim] = slmShape[dim];
1295 }
1296 xegpu::LayoutAttr slmLoadLayout =
1297 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1298 auto slmLoadCoords =
1299 slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1300 if (failed(slmLoadCoords))
1301 return failure();
1302
1303 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1304 SmallVector<Value> slmLoadData;
1305 for (auto coord : *slmLoadCoords) {
1306 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1307 slmLoadData.push_back(xegpu::LoadMatrixOp::create(
1308 rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
1309 /*layout=*/nullptr));
1310 }
1311
1312 // Step 5: Perform final reduction with neutral accumulator and add the
1313 // original accumulator at the end
1314 Value neutralFinalAcc = xegpu::createReductionNeutralValue(
1315 rewriter, loc, sgDstType, op.getKind());
1316
1317 SmallVector<Value> finalResults;
1318 for (size_t i = 0; i < slmLoadData.size(); ++i) {
1319 auto loaded = slmLoadData[i];
1320 auto finalReduce = vector::MultiDimReductionOp::create(
1321 rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
1322 reductionDims);
1323 finalResults.push_back(vector::makeArithReduction(
1324 rewriter, loc, op.getKind(), finalReduce.getResult(),
1325 adaptor.getAcc()[i]));
1326 }
1327 rewriter.replaceOpWithMultiple(op, {finalResults});
1328 return success();
1329 }
1330};
1331
1332// This pattern transforms vector.transpose ops to work at subgroup level.
1333struct WgToSgVectorTransposeOp
1334 : public OpConversionPattern<vector::TransposeOp> {
1335 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1336
1337 LogicalResult
1338 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1339 ConversionPatternRewriter &rewriter) const override {
1340 VectorType resultType = op.getResultVectorType();
1341
1342 ArrayRef<int64_t> wgShape = resultType.getShape();
1343 xegpu::DistributeLayoutAttr layout =
1344 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1345 if (!layout || !layout.isForWorkgroup())
1346 return failure();
1347 // TODO-LayoutRefactor: handle the case using getTemporaryLayout
1348 xegpu::DistributeLayoutAttr sourceLayout =
1349 xegpu::getDistributeLayoutAttr(op.getVector());
1350 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1351 return failure();
1352
1353 SmallVector<int64_t> sourceSgLayout =
1354 sourceLayout.getEffectiveSgLayoutAsInt();
1355 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1356
1357 ArrayRef<int64_t> permutation = op.getPermutation();
1358 size_t permutationSize = permutation.size();
1359 if (sourceSgLayout.size() != permutationSize ||
1360 resultSgLayout.size() != permutationSize) {
1361 return rewriter.notifyMatchFailure(
1362 op, "Layouts and permutation must have the same rank");
1363 }
1364
1365 // Check that sgLayout, sgData & order are properly transposed for source
1366 // and result
1367 if (!layout.isTransposeOf(sourceLayout, permutation,
1368 xegpu::LayoutKind::Subgroup))
1369 return rewriter.notifyMatchFailure(
1370 op, "Result layout is not a valid transpose of source layout "
1371 "according to permutation");
1372
1373 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1374 VectorType newResultType =
1375 VectorType::get(sgShape, resultType.getElementType());
1376
1377 SmallVector<Value> newTransposeOps;
1378 for (auto src : adaptor.getVector()) {
1379 auto newTranspose = vector::TransposeOp::create(
1380 rewriter, op.getLoc(), newResultType, src, permutation);
1381 newTransposeOps.push_back(newTranspose.getResult());
1382 }
1383 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1384 return success();
1385 }
1386};
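// Illustrative example: transposing vector<32x64xf32> to vector<64x32xf32>
// with permutation [1, 0] is accepted when, e.g., the source layout is
// sg_layout = [2, 4], sg_data = [16, 16] and the result layout is
// sg_layout = [4, 2], sg_data = [16, 16]; each subgroup then transposes its
// vector<16x16xf32> fragment.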
1387
1388// Distribute vector mask ops to work at subgroup level.
1389template <typename MaskOpType>
1390struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1391 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1392
1393 LogicalResult matchAndRewrite(
1394 MaskOpType op,
1395 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter) const override {
1397 xegpu::DistributeLayoutAttr layout =
1398 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1399 if (!layout || !layout.isForWorkgroup())
1400 return failure();
1401
1402 Location loc = op.getLoc();
1403 VectorType type = op.getResult().getType();
1404 auto wgShape = type.getShape();
1405
1406 SmallVector<Value> wgMaskDimSizes;
1407 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1408 for (int64_t maskSize : op.getMaskDimSizes()) {
1409 wgMaskDimSizes.push_back(
1410 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1411 }
1412 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1413 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1414 }
1415
1416 Value sgId =
1417 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1418 auto sgOffsets =
1419 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1420 if (failed(sgOffsets))
1421 return failure();
1422
1423 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1424 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1425
1426 // In each dimension, each subgroup computes its local mask size as:
1427 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
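// For example (1D, illustrative): wgMaskDimSize = 20 with sgShape = [8] and
// subgroup offsets 0, 8, 16 yields local mask sizes 8, 8, and 4.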
1428 SmallVector<Value> newCreateMaskOps;
1429 for (auto offsetSet : *sgOffsets) {
1430 SmallVector<Value> maskOperands;
1431
1432 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1433 Value dimSizeVal =
1434 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1435 Value offset = offsetSet[i];
1436 Value adjustedMaskSize =
1437 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1438 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1439 Value nonNegative =
1440 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1441 Value sgMaskSize =
1442 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1443 maskOperands.push_back(sgMaskSize);
1444 }
1445
1446 auto newCreateMaskOp =
1447 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1448 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1449 }
1450
1451 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1452 return success();
1453 }
1454};
1455
1456using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1457using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1458
1459// This pattern transforms vector.bitcast ops to work at subgroup level.
1460struct WgToSgVectorBitCastOp : public OpConversionPattern<vector::BitCastOp> {
1461 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1462
1463 LogicalResult
1464 matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
1465 ConversionPatternRewriter &rewriter) const override {
1466 VectorType resultType = op.getResultVectorType();
1467
1468 ArrayRef<int64_t> wgShape = resultType.getShape();
1469 xegpu::DistributeLayoutAttr layout =
1470 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1471 if (!layout || !layout.isForWorkgroup())
1472 return failure();
1473
1474 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1475 VectorType newResultType =
1476 VectorType::get(sgShape, resultType.getElementType());
1477
1478 SmallVector<Value> newBitCastOps;
1479 for (auto src : adaptor.getSource()) {
1480 auto newBitCast =
1481 vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
1482 newBitCastOps.push_back(newBitCast.getResult());
1483 }
1484
1485 rewriter.replaceOpWithMultiple(op, {newBitCastOps});
1486 return success();
1487 }
1488};
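// Illustrative example (hypothetical layouts, not from the source): a
// wg-level
//   vector.bitcast %v : vector<256x64xi32> to vector<256x128xf16>
// where the result carries sg_layout = [8, 4], sg_data = [32, 32] (and the
// source the matching sg_data = [32, 16]) becomes, per subgroup-local piece:
//   vector.bitcast %v_sg : vector<32x16xi32> to vector<32x32xf16>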
1489
1490// This pattern transforms vector.interleave ops to work at subgroup level.
1491struct WgToSgVectorInterleaveOp
1492 : public OpConversionPattern<vector::InterleaveOp> {
1493 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1494
1495 LogicalResult
1496 matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1497 ConversionPatternRewriter &rewriter) const override {
1498 VectorType resultType = op.getResultVectorType();
1499
1500 ArrayRef<int64_t> wgShape = resultType.getShape();
1501 xegpu::DistributeLayoutAttr layout =
1502 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1503 if (!layout || !layout.isForWorkgroup())
1504 return failure();
1505
1506 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1507 VectorType newResultType =
1508 VectorType::get(sgShape, resultType.getElementType());
1509
1510 SmallVector<Value> newInterleaveOps;
1511 // Interleave operates pairwise: each lhs value is interleaved with the
1512 // corresponding rhs value.
1513 for (auto [lhs, rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
1514 auto newInterleave = vector::InterleaveOp::create(
1515 rewriter, op.getLoc(), newResultType, lhs, rhs);
1516 newInterleaveOps.push_back(newInterleave.getResult());
1517 }
1518
1519 rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
1520 return success();
1521 }
1522};
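// Illustrative example (hypothetical layouts, not from the source): a
// wg-level
//   vector.interleave %lhs, %rhs : vector<32x64xf16> -> vector<32x128xf16>
// where the result carries sg_layout = [2, 2], sg_data = [16, 64] (and the
// operands sg_data = [16, 32]) becomes, per pair of subgroup-local pieces:
//   vector.interleave %l_sg, %r_sg : vector<16x32xf16> -> vector<16x64xf16>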
1523
1524// This pattern transforms vector.deinterleave ops to work at subgroup level.
1525struct WgToSgVectorDeinterleaveOp
1526 : public OpConversionPattern<vector::DeinterleaveOp> {
1527 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1528
1529 LogicalResult
1530 matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1531 ConversionPatternRewriter &rewriter) const override {
1532 SmallVector<Value> newRes1Ops;
1533 SmallVector<Value> newRes2Ops;
1534
1535 for (auto src : adaptor.getSource()) {
1536 auto newDeinterleave =
1537 vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
1538 newRes1Ops.push_back(newDeinterleave.getRes1());
1539 newRes2Ops.push_back(newDeinterleave.getRes2());
1540 }
1541
1542 SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1543 rewriter.replaceOpWithMultiple(op, results);
1544 return success();
1545 }
1546};
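// Illustrative example (hypothetical types): each subgroup-local piece
// deinterleaves independently, e.g.
//   %r1, %r2 = vector.deinterleave %v_sg : vector<16x64xf16> -> vector<16x32xf16>
// and the per-piece results are regrouped into the two wg-level results.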
1547
1548} // namespace
1549
1550namespace mlir {
1551namespace xegpu {
1552 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1553 patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1554 WgToSgDpasMxOp, WgToSgPrefetchNdOp,
1555 UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
1556 WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1557 WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1558 WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1559 WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1560 WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1561 WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1562 WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
1563 patterns.getContext());
1564}
1565} // namespace xegpu
1566} // namespace mlir
1567
1568namespace {
1569struct XeGPUWgToSgDistributePass
1570 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1571 void runOnOperation() override;
1572};
1573} // namespace
1574
1575void XeGPUWgToSgDistributePass::runOnOperation() {
1576
1577 Operation *op = getOperation();
1578 if (!xegpu::recoverTemporaryLayouts(op)) {
1579 signalPassFailure();
1580 return;
1581 }
1582
1583 // Track existing UnrealizedConversionCastOps
1584 SmallVector<Operation *> existingCastOps;
1585 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1586 existingCastOps.push_back(castOp.getOperation());
1587 });
1588
1589 {
1590 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1591 // VectorType operands. This first converts such operands to
1592 // RankedTensorType, propagates the layout attribute into the encoding
1593 // attribute, and finally converts the RankedTensorType to VectorType based
1594 // on the encoding.
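// For example (illustrative, hypothetical layout): an scf.for iter_arg of
// type vector<256x128xf32> whose layout is sg_layout = [8, 4], sg_data =
// [32, 32] round-trips through tensor<256x128xf32, #layout> and is finally
// rewritten to a single vector<32x32xf32> per subgroup.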
1595
1596 TypeConverter converter;
1597 converter.addConversion([&](Type type) -> Type { return type; });
1598 converter.addConversion(
1599 [&](RankedTensorType type,
1600 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1601 // Only convert RankedTensorTypes that carry an XeGPU layout encoding.
1602 // Plain tensors (e.g. tensor<?xi32>) have no XeGPU encoding and must
1603 // not be converted: VectorType does not support dynamic dimensions.
1604 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1605 type.getEncoding());
1606 if (!encoding)
1607 return std::nullopt;
1608
1609 Type elemTy = type.getElementType();
1610 ArrayRef<int64_t> shape = type.getShape();
1611
1612 int count;
1613 SmallVector<int64_t> subShape;
1614 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1615
1616 auto newTy = VectorType::get(subShape, elemTy);
1617 result.append(count, newTy);
1618 return success();
1619 });
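// Illustrative count computation (hypothetical layout, assuming the
// layout's round-robin distribution): for tensor<256x128xf32> with
// sg_layout = [4, 2] and sg_data = [32, 32], each subgroup covers a
// [64, 64] region, so count = (64 * 64) / (32 * 32) = 4 and the value maps
// to four vector<32x32xf32> values.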
1620
1621 xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1622 converter);
1623 }
1624
1625 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1626 // as well as XeGPU, Arith, and Vector operations.
1627 MLIRContext *ctx = &getContext();
1628 RewritePatternSet patterns(ctx);
1629 ConversionTarget target(*ctx);
1630 TypeConverter converter;
1631 converter.addConversion([&](Type type) -> Type { return type; });
1632 converter.addConversion(
1633 [&](xegpu::TensorDescType type,
1634 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1635 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
1636 // Only convert WG-level tensor descs. SG-level or layout-less types
1637 // are already legal and should pass through unchanged.
1638 if (!layout || !layout.isForWorkgroup())
1639 return std::nullopt;
1640
1641 Type elemTy = type.getElementType();
1642 ArrayRef<int64_t> shape = type.getShape();
1643
1644 int count;
1645 SmallVector<int64_t> subShape;
1646 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1647
1648 layout = layout.dropSgLayoutAndData();
1649
1650 auto newTy = xegpu::TensorDescType::get(
1651 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1652 result.append(count, newTy);
1653 return success();
1654 });
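// Illustrative example (hypothetical layout):
//   !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
//                                                 sg_data = [32, 32]>>
// converts to one !xegpu.tensor_desc<32x32xf32> per subgroup, with the
// sg_layout/sg_data fields dropped from the layout.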
1655
1656 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1657 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1658 return createOp.getType();
1659 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1660 return loadOp.getTensorDescType();
1661 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1662 return storeOp.getTensorDescType();
1663 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1664 return prefetchOp.getTensorDescType();
1665 return xegpu::TensorDescType();
1666 };
1667
1668 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1669 return !layout || !layout.isForWorkgroup();
1670 };
1671
1672 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1673 xegpu::StoreNdOp, xegpu::PrefetchNdOp>(
1674 [=](Operation *op) -> bool {
1675 auto tdescTy = getTensorDescType(op);
1676 auto layout =
1677 dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1678 return isLegal(layout);
1679 });
1680
1681 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1682 auto layout = op.getLayoutCdAttr();
1683 return isLegal(layout);
1684 });
1685
1686 target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
1687 [=](xegpu::DpasMxOp op) -> bool {
1688 auto layout = op.getLayoutCdAttr();
1689 return isLegal(layout);
1690 });
1691
1692 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1693 [=](xegpu::LoadMatrixOp op) -> bool {
1694 return isLegal(op.getLayoutAttr());
1695 });
1696
1697 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1698 [=](xegpu::StoreMatrixOp op) -> bool {
1699 return isLegal(op.getLayoutAttr());
1700 });
1701
1702 target.addDynamicallyLegalOp<arith::ConstantOp>(
1703 [=](arith::ConstantOp op) -> bool {
1704 auto vecType = dyn_cast<VectorType>(op.getType());
1705 if (!vecType)
1706 return true;
1707
1708 auto layout =
1709 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1710 return isLegal(layout);
1711 });
1712
1713 target.addDynamicallyLegalOp<
1714 vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1715 vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
1716 vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp,
1717 vector::DeinterleaveOp>([=](Operation *op) -> bool {
1718 // Check for either a SliceAttr or LayoutAttr on the result.
1719 auto layout =
1720 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1721 return isLegal(layout);
1722 });
1723
1724 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1725 [=](xegpu::LoadGatherOp op) -> bool {
1726 auto layout = op.getLayoutAttr();
1727 return isLegal(layout);
1728 });
1729
1730 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1731 [=](xegpu::StoreScatterOp op) -> bool {
1732 auto layout = op.getLayoutAttr();
1733 return isLegal(layout);
1734 });
1735
1736 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1737 [=](xegpu::ConvertLayoutOp op) -> bool {
1738 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1739 });
1740
1741 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1742 [=](Operation *op) -> std::optional<bool> {
1743 // Only handle elementwise mappable ops
1744 if (!OpTrait::hasElementwiseMappableTraits(op))
1745 return true;
1746
1747 VectorType resultType =
1748 dyn_cast<VectorType>(op->getResult(0).getType());
1749 if (!resultType)
1750 return true;
1751
1752 // Check if all operands are vectors of the same shape
1753 // TODO: Support other types.
1754 for (Value operand : op->getOperands()) {
1755 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1756 if (!operandType || operandType.getShape() != resultType.getShape()) {
1757 return true;
1758 }
1759 }
1760
1761 xegpu::DistributeLayoutAttr layout =
1762 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1763 return isLegal(layout);
1764 });
1765
1766 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1767 [=](UnrealizedConversionCastOp op) {
1768 return llvm::is_contained(existingCastOps, op.getOperation());
1769 });
1770
1771 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1772
1773 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1774 target);
1775 xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1776 if (failed(
1777 applyPartialConversion(getOperation(), target, std::move(patterns))))
1778 return signalPassFailure();
1779
1780 xegpu::removeTemporaryLayoutAttrs(getOperation());
1781}
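// As a usage sketch: assuming the pass keeps its usual registration name
// (an assumption, not confirmed by this listing), it can be exercised in
// isolation with something like:
//   mlir-opt --xegpu-wg-to-sg-distribute input.mlir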