MLIR 23.0.0git
XeGPUWgToSgDistribute.cpp
Go to the documentation of this file.
1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
9
25#include "llvm/ADT/SetVector.h"
26#include <optional>
27
28namespace mlir {
29namespace xegpu {
30#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
31#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
32} // namespace xegpu
33} // namespace mlir
34
35using namespace mlir;
36
37namespace {
38
39// Retrieve the RangeAttr if it is specified.
40static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
41 Operation *parent = op->getParentOfType<scf::IfOp>();
42 while (parent) {
43 if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
44 parent->getAttr("sg_id_range")))
45 return attr;
46 parent = parent->getParentOfType<scf::IfOp>();
47 }
48 return {};
49}
50
51static std::pair<SmallVector<int64_t>, int>
52getSgShapeAndCount(ArrayRef<int64_t> shape,
53 xegpu::DistributeLayoutAttr layout) {
54 int count = 1;
56 auto distributedShape = layout.computeDistributedShape(
57 SmallVector<int64_t>(shape.begin(), shape.end()));
58 if (failed(distributedShape))
59 return std::make_pair(sgShape, count);
60 auto sgData = layout.getEffectiveSgDataAsInt();
61 count = computeProduct(distributedShape.value()) / computeProduct(sgData);
62 return std::make_pair(sgData, count);
63}
64
65/// Utility helper for deriving a list of offsets for each sub-TensorDescs
66/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
67/// associated distribute layout attribute, the shape, subgroup id and the
68/// original offsets of the op
69template <typename OpType,
70 typename = std::enable_if_t<llvm::is_one_of<
71 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
72 xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
73static LogicalResult
74genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
76 Location loc = op.getLoc();
77 SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
78 // not applicable to ops without offsets operands.
79 if (origOffsets.empty())
80 return failure();
81
82 // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr()
83 xegpu::DistributeLayoutAttr layout;
84 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
85 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
86 layout = op.getLayoutAttr();
87 } else {
88 layout = op.getDescLayoutAttr();
89 }
90
91 // not applicable to ops without workgroup layout attributes
92 if (!layout || !layout.isForWorkgroup())
93 return failure();
94
95 Value sgId =
96 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
97
98 // verify and adjust the sgId if the range specifier is present
99 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
100 if (sgIdRange) {
101 int64_t startOfRange = sgIdRange.getStart().getInt();
102 int64_t endOfRange = sgIdRange.getEnd().getInt();
103 // verify the RangeAttr against the layout attribute
104 if (layout.getNumSubgroups() != endOfRange - startOfRange)
105 return rewriter.notifyMatchFailure(
106 op, "sg_layout size must match the sg_id_range");
107 // adjust the sgId if necessary
108 if (startOfRange > 0) {
109 Value startOfRangeVal =
110 arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
111 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
112 }
113 }
114
115 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
116 // descriptors to be accessed, based on the layout information.
117 ArrayRef<int64_t> wgShape = op.getDataShape();
118 auto maybeDescOffsets =
119 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
120 if (failed(maybeDescOffsets))
121 return failure();
122
123 // Compute the final global offsets for each accessed sub-tensor
124 // or sub-memory descriptor.
125 for (const auto &sgOffsets : *maybeDescOffsets) {
127 rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
128 offsetsList.push_back(std::move(newOffsets));
129 }
130
131 // callback(offsetsList);
132 return success();
133}
134
135/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
136/// from a workgroup descriptor. It replaces the offsets and sizes with
137/// appropriate values for the subgroup.
138/// It uses round-robin assignment to distribute the work to the subgroups.
139/// Following create_nd_desc operation:
140/// %tdesc = xegpu.create_nd_tdesc %src : memref<24x24xf32>
141/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
142/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
143/// is converted to 9 subgroup level operations based on the sg_layout &
144/// sg_data:
145/// %tdesc = xegpu.create_nd_tdesc %src : memref<24x24xf32> ->
146/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
147/// lane_data = [1, 1]>>
148///
149/// The sg_layout and sg_data attributes are dropped after the pass as they are
150/// no longer needed.
151///
152/// 24x24 matrix distribution example:
153/// sg_layout = [4, 4], sg_data = [2, 2]
154/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
155/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
156///
157/// +------------------------+
158/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
159/// |-----+-----+-----|
160/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
161/// |-----+-----+-----|
162/// | 8x8 | 8x8 | 8x8 |
163/// +------------------------+
164///
165/// Each 8x8 tile is further subdivided among subgroups:
166/// +------------------------+
167/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
168/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
169/// | 2x2 2x2 2x2 2x2 |
170/// | 2x2 2x2 2x2 2x2 |
171/// +------------------------+
172///
173/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
174/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
175
176/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
177/// pattern and all the other ops just follow.
178/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
179/// ops in the pass.
180// This pattern transforms the CreateNdDescOp to create a
181// subgroup descriptor from a workgroup descriptor.
182struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
183 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
184
185 LogicalResult
186 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
187 ConversionPatternRewriter &rewriter) const override {
188
189 Location loc = op.getLoc();
190 MLIRContext *ctx = op.getContext();
191 xegpu::TensorDescType tdescTy = op.getType();
192 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
193 if (!layout || !layout.isForWorkgroup())
194 return failure();
195
196 Type elemTy = tdescTy.getElementType();
197 ArrayRef<int64_t> wgShape = tdescTy.getShape();
198
199 SmallVector<int64_t> sgShape;
200 int count;
201 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
202 xegpu::TensorDescType newTdescTy =
203 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
204 layout.dropSgLayoutAndData());
205
206 SmallVector<Value> newCreateNdOps(count);
207 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
208 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
209 op.getSource(), op.getMixedSizes(),
210 op.getMixedStrides());
211 });
212
213 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
214 return success();
215 }
216};
217
218/// This pattern transforms the LoadNdOp to load subgroup data.
219struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
220 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
221 LogicalResult
222 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
223 ConversionPatternRewriter &rewriter) const override {
224
225 SmallVector<SmallVector<OpFoldResult>> offsetsList;
226 if (failed(genOffsetsList(rewriter, op, offsetsList)))
227 return failure();
228
229 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
230 if (layout)
231 layout = layout.dropSgLayoutAndData();
232 SmallVector<Value> newOps;
233 for (auto [tdesc, offsets] :
234 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
235 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
236 VectorType newResTy =
237 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
238 auto newOp = xegpu::LoadNdOp::create(
239 rewriter, op.getLoc(), newResTy, tdesc, offsets,
240 /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
241 op.getL2HintAttr(), op.getL3HintAttr(), layout);
242 newOps.push_back(newOp);
243 }
244 rewriter.replaceOpWithMultiple(op, {newOps});
245
246 return success();
247 }
248};
249
250/// This pattern transforms the StoreNdOp to store subgroup data.
251struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
252 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
253 LogicalResult
254 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
255 ConversionPatternRewriter &rewriter) const override {
256 SmallVector<SmallVector<OpFoldResult>> offsetsList;
257 if (failed(genOffsetsList(rewriter, op, offsetsList)))
258 return failure();
259
260 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
261 if (layout)
262 layout = layout.dropSgLayoutAndData();
263 for (auto [v, tdesc, offsets] :
264 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
265 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
266 op.getL1HintAttr(), op.getL2HintAttr(),
267 op.getL3HintAttr(), layout);
268 }
269 rewriter.eraseOp(op);
270
271 return success();
272 }
273};
274
275/// This pattern transforms the PrefetchNdOp to prefetch subgroup data.
276struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
277 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
278 LogicalResult
279 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter) const override {
281 SmallVector<SmallVector<OpFoldResult>> offsetsList;
282 if (failed(genOffsetsList(rewriter, op, offsetsList)))
283 return failure();
284
285 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
286 if (layout)
287 layout = layout.dropSgLayoutAndData();
288 for (auto [tdesc, offsets] :
289 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
290 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
291 op.getL1HintAttr(), op.getL2HintAttr(),
292 op.getL3HintAttr(), layout);
293 }
294 rewriter.eraseOp(op);
295
296 return success();
297 }
298};
299
300/// This pattern transforms the DpasOp to work at subgroup level.
301struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
302 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
303 LogicalResult
304 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter) const override {
306 Location loc = op.getLoc();
307 VectorType resultTy = op.getResult().getType();
308 if (resultTy.getRank() < 2)
309 return failure();
310
311 auto layoutCd = op.getLayoutCdAttr();
312 auto layoutA = op.getLayoutAAttr();
313 auto layoutB = op.getLayoutBAttr();
314 if (!layoutCd || !layoutA || !layoutB)
315 return failure();
316 size_t i = 0;
317 SmallVector<Value> newDpasOps;
318 for (auto aVec : adaptor.getLhs()) {
319 for (auto bVec : adaptor.getRhs()) {
320
321 llvm::SmallVector<Value> operands({aVec, bVec});
322 Value tmpC;
323 if (op.getAcc()) {
324 tmpC = adaptor.getAcc()[i++];
325 operands.push_back(tmpC);
326 }
327
328 ArrayRef<int64_t> aVecShape =
329 cast<VectorType>(aVec.getType()).getShape();
330 ArrayRef<int64_t> bVecShape =
331 cast<VectorType>(bVec.getType()).getShape();
332 // Build result shape: batch dims from A + [M, N] from last dims of
333 // A and B.
334 SmallVector<int64_t> resShape(aVecShape.drop_back(2));
335 resShape.push_back(aVecShape[aVecShape.size() - 2]);
336 resShape.push_back(bVecShape[bVecShape.size() - 1]);
337 VectorType resTy = VectorType::get(resShape, resultTy.getElementType());
338 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
339 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
340 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
341 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
342
343 newDpasOps.push_back(newDpasOp);
344 }
345 }
346 rewriter.replaceOpWithMultiple(op, {newDpasOps});
347 return success();
348 }
349};
350
351/// This pattern transforms the DpasMxOp to work at subgroup level.
352struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
353 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
354 LogicalResult
355 matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
356 ConversionPatternRewriter &rewriter) const override {
357
358 Location loc = op.getLoc();
359 VectorType resultTy = op.getResult().getType();
360
361 if (resultTy.getRank() < 2)
362 return failure();
363
364 auto layoutCd = op.getLayoutCdAttr();
365 auto layoutA = op.getLayoutAAttr();
366 auto layoutB = op.getLayoutBAttr();
367 auto layoutAScale = op.getLayoutAScaleAttr();
368 auto layoutBScale = op.getLayoutBScaleAttr();
369
370 if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
371 return failure();
372
373 size_t index_c = 0;
374 SmallVector<Value> newDpasMxOps;
375 for (auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
376 for (auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
377 Value accVal = (op.getAcc()) ? adaptor.getAcc()[index_c++] : Value();
378 Value scaleAVal =
379 (op.getScaleA()) ? adaptor.getScaleA()[index_a] : Value();
380 Value scaleBVal =
381 (op.getScaleB()) ? adaptor.getScaleB()[index_b] : Value();
382
383 ArrayRef<int64_t> aVecShape =
384 cast<VectorType>(aVec.getType()).getShape();
385 ArrayRef<int64_t> bVecShape =
386 cast<VectorType>(bVec.getType()).getShape();
387 // Build result shape: batch dims from A + [M, N]
388 SmallVector<int64_t> resShape(aVecShape.drop_back(2));
389 resShape.push_back(aVecShape[aVecShape.size() - 2]);
390 resShape.push_back(bVecShape[bVecShape.size() - 1]);
391 VectorType resTy = VectorType::get(resShape, resultTy.getElementType());
392 auto newDpasMxOp = xegpu::DpasMxOp::create(
393 rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
394 layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
395 layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
396 layoutBScale.dropSgLayoutAndData());
397
398 newDpasMxOps.push_back(newDpasMxOp);
399 }
400 }
401 rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
402 return success();
403 }
404};
405
406/// This pattern transforms vector.broadcast ops to work at subgroup level.
407struct WgToSgVectorBroadcastOp
408 : public OpConversionPattern<vector::BroadcastOp> {
409 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
410
411 LogicalResult
412 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter) const override {
414
415 VectorType resultType = op.getResult().getType();
416 ArrayRef<int64_t> wgShape = resultType.getShape();
417
418 xegpu::DistributeLayoutAttr layout =
419 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
420 if (!layout || !layout.isForWorkgroup())
421 return failure();
422
423 SmallVector<int64_t> sgShape;
424 int count;
425 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
426 VectorType newResultType =
427 VectorType::get(sgShape, resultType.getElementType());
428
429 SmallVector<Value> newBroadcastOps;
430 auto distSource = adaptor.getOperands().front();
431 int numDistributions = count / distSource.size();
432 for (int i = 0; i < numDistributions; ++i) {
433 for (auto operand : distSource) {
434 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
435 newResultType, operand);
436
437 newBroadcastOps.push_back(newBroadcast.getResult());
438 }
439 }
440 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
441 return success();
442 }
443};
444
445// This pattern transforms elementwise ops to work at subgroup level.
446struct WgToSgElementwiseOp : public ConversionPattern {
447 WgToSgElementwiseOp(MLIRContext *ctx)
448 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
450 LogicalResult
451 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
452 ConversionPatternRewriter &rewriter) const override {
453 // Only match ops with elementwise trait and single result.
455 return failure();
456
457 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
458 assert(resultType && "Expected result to be a VectorType");
459
460 ArrayRef<int64_t> wgShape = resultType.getShape();
461
462 xegpu::DistributeLayoutAttr layout =
463 xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
464 if (!layout || !layout.isForWorkgroup())
465 return failure();
466
467 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
469 size_t numVariants = operands.empty() ? 0 : operands.front().size();
470
471 if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
472 return operandVec.size() != numVariants;
473 }))
474 return failure();
475
477 VectorType newResultType =
478 VectorType::get(sgShape, resultType.getElementType());
479
480 for (size_t i = 0; i < numVariants; ++i) {
482 for (auto &operandVec : operands)
483 opOperands.push_back(operandVec[i]);
484
485 OperationState state(op->getLoc(), op->getName());
486 state.addOperands(opOperands);
487 state.addTypes(newResultType);
488 state.addAttributes(op->getAttrs());
489 Operation *newOp = rewriter.create(state);
491 newResults.push_back(newOp->getResult(0));
492 }
493
494 rewriter.replaceOpWithMultiple(op, {newResults});
495 return success();
496 }
497};
499// clang-format off
500// Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
501// If input_layout and target_layout have identical sg_layout and sg_data,
502// the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
503// dropped. For example:
504// #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
505// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
506// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
507// becomes:
508// #a = #xegpu.layout<inst_data = [16, 16]>
509// #b = #xegpu.layout<inst_data = [8, 16]>
510// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
511// (vector<16x16xf32> is determined by sg_data = [16, 16])
512//
513// If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
514// For example:
515// #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
516// #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
517// xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
518// is lowered to:
519// #a = #xegpu.layout<inst_data = [16, 16]>
520// #b = #xegpu.layout<inst_data = [8, 16]>
521// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
522// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
523// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
524// clang-format on
525struct WgToSgConvertLayoutOp
526 : public OpConversionPattern<xegpu::ConvertLayoutOp> {
527 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
528
529 LogicalResult
530 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
531 ConversionPatternRewriter &rewriter) const override {
532 Location loc = op.getLoc();
533 auto inputLayout = op.getInputLayout();
534 auto targetLayout = op.getTargetLayout();
535
536 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
537 !targetLayout.isForWorkgroup())
538 return rewriter.notifyMatchFailure(
539 op, "Input and target layouts must have subgroup layout");
540
541 Type resultType = op.getResult().getType();
542 if (resultType.isIntOrFloat()) {
543 rewriter.replaceOp(op, op.getSource());
544 assert(!inputLayout.dropSgLayoutAndData() &&
545 !targetLayout.dropSgLayoutAndData() &&
546 "unexpected layout attributes for scalar type");
547 return success();
548 }
549
550 ArrayRef<int64_t> wgShape = cast<VectorType>(resultType).getShape();
551 SmallVector<int64_t> inputSgLayout =
552 inputLayout.getEffectiveSgLayoutAsInt();
553 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
554 SmallVector<int64_t> targetSgLayout =
555 targetLayout.getEffectiveSgLayoutAsInt();
556 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
557
558 // Fast path: if sg_layout and sg_data are identical, no SLM needed
559 SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
560 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
562 inputLayout = inputLayout.dropSgLayoutAndData();
563 targetLayout = targetLayout.dropSgLayoutAndData();
564
565 SmallVector<Value> newOps(adaptor.getSource());
566 if (inputLayout && targetLayout) {
567 for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
568 auto newOp = xegpu::ConvertLayoutOp::create(
569 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
570 newOps[i] = newOp;
571 }
572 }
573 rewriter.replaceOpWithMultiple(op, {newOps});
574 return success();
575 }
576
577 // SLM path: layouts differ, need cross-subgroup data redistribution
578 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
579
580 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
581
582 // Calculate SLM size requirements
583 auto bitWidth = elemTy.getIntOrFloatBitWidth();
584 auto bytesPerElement = bitWidth / 8;
585 auto slmSize = computeProduct(slmShape) * bytesPerElement;
586
587 // Allocate SLM
588 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
589 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
590
591 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
592 elemTy, nullptr);
593 auto memDesc =
594 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
595
596 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
597 rewriter.getIndexType(), nullptr);
598
599 // STORE PHASE: Each subgroup stores in SLM using input layout
600 auto storeCoords = inputLayout.computeDistributedCoords(
601 rewriter, loc, sgId.getResult(), wgShape);
602 if (failed(storeCoords))
603 return failure();
604
605 // Store to SLM
606 for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
607 SmallVector<OpFoldResult> storeMatrixOffsets;
608 for (Value coord : coords) {
609 storeMatrixOffsets.push_back(coord);
610 }
611 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
612 storeMatrixOffsets, nullptr /*layout*/);
613 }
614
615 gpu::BarrierOp::create(rewriter, loc);
616
617 // LOAD PHASE: Each target subgroup loads from SLM using target layout
618 auto loadCoords = targetLayout.computeDistributedCoords(
619 rewriter, loc, sgId.getResult(), wgShape);
620 if (failed(loadCoords))
621 return failure();
622
623 VectorType loadType = VectorType::get(targetSgData, elemTy);
624
625 // Load vectors from SLM
626 SmallVector<Value> finalResults;
627 for (auto coords : *loadCoords) {
628 SmallVector<OpFoldResult> loadMatrixOffsets;
629 for (Value coord : coords) {
630 loadMatrixOffsets.push_back(coord);
631 }
632 auto loadOp = xegpu::LoadMatrixOp::create(
633 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
634 targetLayout.dropSgLayoutAndData());
635
636 finalResults.push_back(loadOp.getResult());
637 }
638
639 rewriter.replaceOpWithMultiple(op, {finalResults});
640 return success();
641 }
642};
643
644// This pattern distributes arith.constant op into subgroup-level constants
645struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
646 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
647
648 LogicalResult
649 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
650 ConversionPatternRewriter &rewriter) const override {
651 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
652 auto vecType = dyn_cast<VectorType>(op.getType());
653 if (!vecAttr || !vecType)
654 return failure();
655
656 xegpu::DistributeLayoutAttr layout =
657 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
658 if (!layout || !layout.isForWorkgroup())
659 return failure();
660
661 ArrayRef<int64_t> wgShape = vecType.getShape();
662 SmallVector<int64_t> sgShape;
663 int count;
664 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
665
666 auto newType = VectorType::get(sgShape, vecType.getElementType());
667 Location loc = op.getLoc();
668 auto eltType = vecType.getElementType();
669
670 if (vecAttr.isSplat()) {
671 // Splat: single value for all subgroups
672 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
673 auto sgAttr = DenseElementsAttr::get(newType, singleVal);
674 SmallVector<Value> newConstOps;
675 for (int i = 0; i < count; ++i) {
676 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
677 newConstOps.push_back(cstOp);
678 }
679 rewriter.replaceOpWithMultiple(op, {newConstOps});
680 return success();
681 } else if (sgShape == wgShape) { // if the entire vector is shared by all
682 // subgroups, don't distribute
683 auto newConstOp =
684 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
685 rewriter.replaceOp(op, newConstOp);
686 return success();
687 } else {
688 // Non-splat constant
689 // Only supports 1D & 2D
690 // TODO: support other cases that require SLM access
691 if (!eltType.isIndex())
692 return rewriter.notifyMatchFailure(
693 op, "Unsupported element type for non-splat constant op.");
694
695 if (wgShape.size() > 2)
696 return rewriter.notifyMatchFailure(
697 op, "Only 1D & 2D vector constant supported");
698
699 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
700 int64_t rowStride = 0, colStride = 0;
701 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
702 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
703
704 // Compute colStride and rowStride, and check for constant strides.
705 if (cols > 1) {
706 colStride = cast<IntegerAttr>(values[1]).getInt() -
707 cast<IntegerAttr>(values[0]).getInt();
708 }
709 if (rows > 1) {
710 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
711 cast<IntegerAttr>(values[0]).getInt();
712 }
713
714 for (int64_t r = 0; r < rows; ++r) {
715 for (int64_t c = 0; c < cols; ++c) {
716 int64_t idx = r * cols + c;
717 // Check column stride
718 if (c > 0 && cols > 1) {
719 int64_t prevIdx = r * cols + (c - 1);
720 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
721 cast<IntegerAttr>(values[prevIdx]).getInt();
722 if (diff != colStride)
723 return rewriter.notifyMatchFailure(
724 op, "Non-constant column stride in constant op.");
725 }
726 // Check row stride
727 if (r > 0 && rows > 1) {
728 int64_t prevIdx = (r - 1) * cols + c;
729 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
730 cast<IntegerAttr>(values[prevIdx]).getInt();
731 if (diff != rowStride)
732 return rewriter.notifyMatchFailure(
733 op, "Non-constant row stride in constant op.");
734 }
735 }
736 }
737
738 // Create a constant for the base tile.
739 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
740 // For 1D case, extract the first sgShape[0] elements.
741 SmallVector<Attribute> baseTileValues;
742 int baseTileCols = sgShape[sgShape.size() - 1];
743 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
744 for (int64_t r = 0; r < baseTileRows; ++r) {
745 for (int64_t c = 0; c < baseTileCols; ++c) {
746 baseTileValues.push_back(values[r * cols + c]);
747 }
748 }
749
750 auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
751 baseTileValues);
752 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
753
754 // Get subgroup id
755 Value sgId =
756 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
757 auto sgOffsets =
758 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
759 if (failed(sgOffsets))
760 return failure();
761
762 SmallVector<Value, 2> strideConsts;
763 strideConsts.push_back(
764 arith::ConstantIndexOp::create(rewriter, loc, colStride));
765 if (rows > 1)
766 strideConsts.insert(
767 strideConsts.begin(),
768 arith::ConstantIndexOp::create(rewriter, loc, rowStride));
769
770 SmallVector<Value> newConstOps;
771 for (auto offsets : *sgOffsets) {
772 // Multiply offset with stride, broadcast it and add to baseConstVec
773 Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
774 for (size_t i = 0; i < strideConsts.size(); ++i) {
775 Value mul =
776 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
777 offsets[i], strideConsts[i]);
778 mulOffset = arith::AddIOp::create(
779 rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
780 }
781 // Broadcast to baseConstVec size
782 auto bcastOffset = vector::BroadcastOp::create(
783 rewriter, loc, baseConstVec.getType(), mulOffset);
784 auto finalConst =
785 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
786 newConstOps.push_back(finalConst);
787 }
788 rewriter.replaceOpWithMultiple(op, {newConstOps});
789 return success();
790 }
791 }
792};
793
794// This pattern transforms the LoadGatherOp with explicit offsets to load
795// subgroup data
796struct WgToSgLoadGatherOp : public OpConversionPattern<xegpu::LoadGatherOp> {
797 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
798 LogicalResult
799 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter) const override {
801
802 Location loc = op.getLoc();
803 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
804 if (!resultType)
805 return failure();
806 ArrayRef<int64_t> wgShape = resultType.getShape();
807
808 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
809
810 if (!layout || !layout.isForWorkgroup())
811 return failure();
812
813 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
814
815 // The offsets need to be distributed
816 auto offsetsVecType =
817 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
818 auto maskVecType =
819 dyn_cast<VectorType>(adaptor.getMask().front().getType());
820 if (!offsetsVecType || !maskVecType ||
821 offsetsVecType.getShape() != maskVecType.getShape()) {
822 return rewriter.notifyMatchFailure(op,
823 "offsets have not been distributed");
824 }
825
826 SmallVector<Value> newLoadOps;
827 auto chunkSizeAttr =
828 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
829 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
830 for (auto [offsets, mask] :
831 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
832 auto newLayout = layout.dropSgLayoutAndData();
833 auto newLoadOp = xegpu::LoadGatherOp::create(
834 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
835 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
836 newLayout);
837 newLoadOps.push_back(newLoadOp);
838 }
839 rewriter.replaceOpWithMultiple(op, {newLoadOps});
840 return success();
841 }
842};
843
844// This pattern transforms the StoreScatterOp with explicit offsets to store
845// subgroup data
846struct WgToSgStoreScatterOp
847 : public OpConversionPattern<xegpu::StoreScatterOp> {
848 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
849 LogicalResult
850 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
851 ConversionPatternRewriter &rewriter) const override {
852
853 Location loc = op.getLoc();
854 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
855 if (!valueType)
856 return failure();
857
858 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
859
860 if (!layout || !layout.isForWorkgroup())
861 return failure();
862
863 // The offsets need to be distributed
864 auto offsetsVecType =
865 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
866 auto maskVecType =
867 dyn_cast<VectorType>(adaptor.getMask().front().getType());
868 if (!offsetsVecType || !maskVecType ||
869 offsetsVecType.getShape() != maskVecType.getShape()) {
870 return rewriter.notifyMatchFailure(op,
871 "offsets have not been distributed");
872 }
873
874 auto chunkSizeOpt = op.getChunkSize();
875 int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
876 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
877 for (auto [val, offs, mask] : llvm::zip(
878 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
879 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
880 mask, chunkSizeAttr, op.getL1HintAttr(),
881 op.getL2HintAttr(), op.getL3HintAttr(),
882 layout.dropSgLayoutAndData());
883 }
884 rewriter.eraseOp(op);
885 return success();
886 }
887};
888
889struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
890 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
891 LogicalResult
892 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter) const override {
894
895 SmallVector<SmallVector<OpFoldResult>> offsetsList;
896 if (failed(genOffsetsList(rewriter, op, offsetsList)))
897 return failure();
898
899 ArrayRef<int64_t> wgShape = op.getDataShape();
900 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
901 assert(valueTy && "the value type must be vector type!");
902 Type elemTy = valueTy.getElementType();
903
904 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
905 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
906 VectorType newResTy = VectorType::get(sgShape, elemTy);
907 SmallVector<Value> newOps;
908 for (auto offsets : offsetsList) {
909 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
910 op.getMemDesc(), offsets,
911 layout.dropSgLayoutAndData());
912 newOps.push_back(newOp);
913 }
914 rewriter.replaceOpWithMultiple(op, {newOps});
915
916 return success();
917 }
918};
919
920struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
921 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
922 LogicalResult
923 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
924 ConversionPatternRewriter &rewriter) const override {
925
926 SmallVector<SmallVector<OpFoldResult>> offsetsList;
927 if (failed(genOffsetsList(rewriter, op, offsetsList)))
928 return failure();
929
930 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
931 for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
932 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
933 offsets, layout.dropSgLayoutAndData());
934 rewriter.eraseOp(op);
935 return success();
936 }
937};
938
939// This pattern distributes the vector.step ops to work at subgroup level
940struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
941 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
942 LogicalResult
943 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
944 ConversionPatternRewriter &rewriter) const override {
945 xegpu::DistributeLayoutAttr layout =
946 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
947 if (!layout || !layout.isForWorkgroup())
948 return failure();
949
950 Location loc = op.getLoc();
951 VectorType type = op.getResult().getType();
952 auto wgShape = type.getShape();
953 std::optional<SmallVector<int64_t>> sgShape =
954 getSgShapeAndCount(wgShape, layout).first;
955 if (!sgShape)
956 return failure();
957
958 Value sgId =
959 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
960 auto sgOffsets =
961 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
962 if (failed(sgOffsets))
963 return failure();
964
965 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
966 auto steps = vector::StepOp::create(rewriter, loc, newTy);
967 SmallVector<Value> newOps;
968 for (auto offsets : *sgOffsets) {
969 // Broadcast the offset scalar to a vector & add to the base steps
970 auto bcastOffset =
971 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
972 auto finalSteps =
973 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
974 newOps.push_back(finalSteps);
975 }
976
977 rewriter.replaceOpWithMultiple(op, {newOps});
978 return success();
979 }
980};
981
982// This pattern transforms vector.shape_cast ops to work at subgroup level.
983struct WgToSgVectorShapeCastOp
984 : public OpConversionPattern<vector::ShapeCastOp> {
985 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
986
987 LogicalResult
988 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
989 ConversionPatternRewriter &rewriter) const override {
990
991 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
992 if (!resultType)
993 return failure();
994
995 ArrayRef<int64_t> wgShape = resultType.getShape();
996 xegpu::DistributeLayoutAttr layout =
997 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
998 if (!layout || !layout.isForWorkgroup())
999 return failure();
1000
1001 // Check that srcShape and destShape, if they differ, only differ by
1002 // expand of unit dimensions.
1003 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1004 if (!srcType)
1005 return failure();
1006
1007 ArrayRef<int64_t> srcShape = srcType.getShape();
1008
1009 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1010 SmallVector<int64_t> expandedUnitDims;
1011 if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
1012 xegpu::DistributeLayoutAttr sourceLayout =
1013 xegpu::getTemporaryLayout(op->getOpOperand(0));
1014
1015 if (!sourceLayout.isSliceOf(layout))
1016 return rewriter.notifyMatchFailure(
1017 op, "The ShapeCast op only expands dimensions, the input layout "
1018 "must be a slice of the result layout.");
1019
1020 assert(layoutToDistribute.isEqualTo(
1021 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1022 "The sg_data for unit dimensions should be set as 1");
1023 }
1024
1025 SmallVector<int64_t> sgShape =
1026 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1027 VectorType newResultType =
1028 VectorType::get(sgShape, resultType.getElementType());
1029
1030 SmallVector<Value> newShapeCastOps;
1031 for (auto src : adaptor.getSource()) {
1032 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1033 newResultType, src);
1034 newShapeCastOps.push_back(newShapeCast.getResult());
1035 }
1036
1037 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1038 return success();
1039 }
1040};
1041
1042/// This pattern transforms vector.multi_dim_reduction operations from
1043/// workgroup-level to subgroup-level execution with support for multiple
1044/// reduction dimensions.
1045///
1046/// Steps include:
1047/// 1. LOCAL REDUCTION :
1048/// - Each subgroup performs local reduction on its data slice
1049/// - Uses ZERO accumulator to avoid double-counting during cross-subgroup
1050/// phase
1051///
1052/// 2. CROSS-SUBGROUP :
1053/// - Determines if cross-subgroup reduction is needed (when sg_layout > 1 in
1054/// reduction dims & sgData[reduction dims] < wgData[reduction dims])
1055/// - If not needed, adds original accumulator and returns local results
1056///
1057/// 3. SHARED LOCAL MEMORY (SLM) PHASE (when cross-subgroup reduction needed):
1058/// a) SLM Layout Design:
1059/// - Rows: subgroups participating in reduction (product of sg_layout in
1060/// reduction dims)
1061/// - Cols: total result elements across non-reduction dimensions
1062///
1063/// b) Store Phase:
1064/// - Each subgroup stores its local reduction result to SLM
1065/// - Row offset: linearized index of subgroup in reduction dimensions
1066/// - Col offset: linearized index of subgroup in non-reduction dimensions
1067///
1068/// c) Load and Final Reduction Phase:
1069/// - Each subgroup loads a column of data (all reduction participants for
1070/// its position)
1071/// - Performs final reduction along the loaded dimension
1072/// - Adds original accumulator to get final result
1073///
1074struct WgToSgMultiDimReductionOp
1075 : public OpConversionPattern<vector::MultiDimReductionOp> {
1076 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1077
1078 LogicalResult
1079 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1080 ConversionPatternRewriter &rewriter) const override {
1081 Location loc = op.getLoc();
1082
1083 VectorType srcType = op.getSourceVectorType();
1084 Type resultTy = op.getResult().getType();
1085 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
1086 bool isScalarResult = !dstVecType;
1087
1088 auto originalSrcShape = srcType.getShape();
1089 Type elemTy = srcType.getElementType();
1090
1091 xegpu::DistributeLayoutAttr layout =
1092 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1093 if (!layout || !layout.isForWorkgroup())
1094 return failure();
1095
1096 auto reductionDims = llvm::to_vector(op.getReductionDims());
1097
1098 // Get sg_layout and sg_data from the parent layout
1099 SmallVector<int64_t> sgLayout;
1100 SmallVector<int64_t> sgData;
1101 xegpu::DistributeLayoutAttr parentLayout;
1102 if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1103 parentLayout = sliceAttr.getParent();
1104 sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
1105 sgData = parentLayout.getEffectiveSgDataAsInt();
1106 } else
1107 return rewriter.notifyMatchFailure(
1108 op, "Reduction should have SliceAttr layout");
1109
1110 // Step 1: perform local subgroup reductions with neutral accumulator
1111 SmallVector<Value> localReductions;
1112 auto sgSrcs = adaptor.getSource();
1113 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1114 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1115 sgSrcType.getShape().end());
1116
1117 // Determine the SG-level destination type.
1118 // For scalar results (all dims reduced), the sg result is also scalar.
1119 // For vector results, compute the sg destination shape from layout.
1120 Type sgDstType;
1121 if (dstVecType) {
1122 auto originalDstShape = dstVecType.getShape();
1123 SmallVector<int64_t> sgDstShape =
1124 getSgShapeAndCount(originalDstShape, layout).first;
1125 sgDstType = VectorType::get(sgDstShape, elemTy);
1126 } else {
1127 sgDstType = elemTy;
1128 }
1129
1130 for (auto sgSrc : sgSrcs) {
1131 // Create neutral accumulator for local reduction
1132 Value neutralLocalAcc = xegpu::createReductionNeutralValue(
1133 rewriter, loc, sgDstType, op.getKind());
1134 // Local reduction with neutral accumulator
1135 auto localReduce = vector::MultiDimReductionOp::create(
1136 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1137 reductionDims);
1138 localReductions.push_back(localReduce.getResult());
1139 }
1140
1141 // Check if cross-subgroup reduction is needed for any reduction dimension
1142 SmallVector<int64_t> crossSgReductionDims;
1143 for (int64_t reductionDim : reductionDims) {
1144 bool needsCrossSubgroupReduction =
1145 (sgLayout[reductionDim] > 1) &&
1146 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1147
1148 if (needsCrossSubgroupReduction) {
1149 crossSgReductionDims.push_back(reductionDim);
1150 }
1151 }
1152
1153 // If no cross-subgroup reduction needed, add accumulator and return
1154 if (crossSgReductionDims.empty()) {
1155 SmallVector<Value> results;
1156 for (auto localResult : localReductions) {
1157 auto finalResult = vector::makeArithReduction(
1158 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1159 results.push_back(finalResult);
1160 }
1161 rewriter.replaceOpWithMultiple(op, {results});
1162 return success();
1163 }
1164
1165 // Step 2: cross-subgroup reduction using SLM - allocating slm memory
1166 auto slmStoreDataShape = sgSrcShape;
1167 for (int64_t dim : reductionDims)
1168 slmStoreDataShape[dim] = 1;
1169 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1170 SmallVector<Value> slmStoreData;
1171 for (auto localResult : localReductions) {
1172 if (isScalarResult) {
1173 // Scalar result: broadcast scalar to vector<1x...x1> for SLM store
1174 slmStoreData.push_back(vector::BroadcastOp::create(
1175 rewriter, loc, slmStoreDataType, localResult));
1176 } else {
1177 slmStoreData.push_back(vector::ShapeCastOp::create(
1178 rewriter, loc, slmStoreDataType, localResult));
1179 }
1180 }
1181 // for reduction dimension, SLM stores partial results from each subgroup
1182 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1183 originalSrcShape.end());
1184 SmallVector<int> slmSgData(sgData.begin(), sgData.end());
1185 SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
1186 for (int dim : reductionDims) {
1187 slmShape[dim] = sgLayout[dim];
1188 slmSgData[dim] = 1;
1189 }
1190 xegpu::LayoutAttr slmStoreLayout =
1191 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1192
1193 // Allocate SLM
1194 auto bitWidth = elemTy.getIntOrFloatBitWidth();
1195 auto bytesPerElement = bitWidth / 8;
1196 auto slmSize = computeProduct(slmShape) * bytesPerElement;
1197 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1198 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1199
1200 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1201 elemTy, nullptr);
1202 auto memDesc =
1203 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1204
1205 // Step 3: Store local results to SLM
1206 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1207 rewriter.getIndexType(), nullptr);
1208
1209 auto slmStoreCoords =
1210 slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1211 if (failed(slmStoreCoords))
1212 return failure();
1213 for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
1214 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1215 xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
1216 coordOfr,
1217 /*layout=*/nullptr);
1218 }
1219
1220 gpu::BarrierOp::create(rewriter, loc);
1221
1222 // Step 4: Load from SLM for final reduction
1223 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1224 for (int64_t dim : reductionDims) {
1225 slmLoadDataShape[dim] = slmShape[dim];
1226 slmSgData[dim] = slmShape[dim];
1227 }
1228 xegpu::LayoutAttr slmLoadLayout =
1229 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1230 auto slmLoadCoords =
1231 slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1232 if (failed(slmLoadCoords))
1233 return failure();
1234
1235 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1236 SmallVector<Value> slmLoadData;
1237 for (auto coord : *slmLoadCoords) {
1238 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1239 slmLoadData.push_back(xegpu::LoadMatrixOp::create(
1240 rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
1241 /*layout=*/nullptr));
1242 }
1243
1244 // Step 5: Perform final reduction with neutral accumulator and add the
1245 // original accumulator at the end
1246 Value neutralFinalAcc = xegpu::createReductionNeutralValue(
1247 rewriter, loc, sgDstType, op.getKind());
1248
1249 SmallVector<Value> finalResults;
1250 for (size_t i = 0; i < slmLoadData.size(); ++i) {
1251 auto loaded = slmLoadData[i];
1252 auto finalReduce = vector::MultiDimReductionOp::create(
1253 rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
1254 reductionDims);
1255 finalResults.push_back(vector::makeArithReduction(
1256 rewriter, loc, op.getKind(), finalReduce.getResult(),
1257 adaptor.getAcc()[i]));
1258 }
1259 rewriter.replaceOpWithMultiple(op, {finalResults});
1260 return success();
1261 }
1262};
1263
1264// This pattern transforms vector.transpose ops to work at subgroup level.
1265struct WgToSgVectorTransposeOp
1266 : public OpConversionPattern<vector::TransposeOp> {
1267 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1268
1269 LogicalResult
1270 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1271 ConversionPatternRewriter &rewriter) const override {
1272 VectorType resultType = op.getResultVectorType();
1273
1274 ArrayRef<int64_t> wgShape = resultType.getShape();
1275 xegpu::DistributeLayoutAttr layout =
1276 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1277 if (!layout || !layout.isForWorkgroup())
1278 return failure();
1279 xegpu::DistributeLayoutAttr sourceLayout =
1280 xegpu::getTemporaryLayout(op->getOpOperand(0));
1281 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1282 return failure();
1283
1284 SmallVector<int64_t> sourceSgLayout =
1285 sourceLayout.getEffectiveSgLayoutAsInt();
1286 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1287
1288 ArrayRef<int64_t> permutation = op.getPermutation();
1289 size_t permutationSize = permutation.size();
1290 if (sourceSgLayout.size() != permutationSize ||
1291 resultSgLayout.size() != permutationSize) {
1292 return rewriter.notifyMatchFailure(
1293 op, "Layouts and permutation must have the same rank");
1294 }
1295
1296 // Check that sgLayout, sgData & order are properly transposed for source
1297 // and result
1298 if (!layout.isTransposeOf(sourceLayout, permutation,
1299 xegpu::LayoutKind::Subgroup))
1300 return rewriter.notifyMatchFailure(
1301 op, "Result layout is not a valid transpose of source layout "
1302 "according to permutation");
1303
1304 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1305 VectorType newResultType =
1306 VectorType::get(sgShape, resultType.getElementType());
1307
1308 SmallVector<Value> newTransposeOps;
1309 for (auto src : adaptor.getVector()) {
1310 auto newTranspose = vector::TransposeOp::create(
1311 rewriter, op.getLoc(), newResultType, src, permutation);
1312 newTransposeOps.push_back(newTranspose.getResult());
1313 }
1314 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1315 return success();
1316 }
1317};
1318
1319// Distribute vector mask ops to work at subgroup level.
1320template <typename MaskOpType>
1321struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1322 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1323
1324 LogicalResult matchAndRewrite(
1325 MaskOpType op,
1326 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1327 ConversionPatternRewriter &rewriter) const override {
1328 xegpu::DistributeLayoutAttr layout =
1329 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1330 if (!layout || !layout.isForWorkgroup())
1331 return failure();
1332
1333 Location loc = op.getLoc();
1334 VectorType type = op.getResult().getType();
1335 auto wgShape = type.getShape();
1336
1337 SmallVector<Value> wgMaskDimSizes;
1338 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1339 for (int64_t maskSize : op.getMaskDimSizes()) {
1340 wgMaskDimSizes.push_back(
1341 arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1342 }
1343 } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1344 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1345 }
1346
1347 Value sgId =
1348 gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1349 auto sgOffsets =
1350 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1351 if (failed(sgOffsets))
1352 return failure();
1353
1354 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1355 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1356
1357 // In each dimension, each subgroup computes its local mask size as:
1358 // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
1359 SmallVector<Value> newCreateMaskOps;
1360 for (auto offsetSet : *sgOffsets) {
1361 SmallVector<Value> maskOperands;
1362
1363 for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1364 Value dimSizeVal =
1365 arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1366 Value offset = offsetSet[i];
1367 Value adjustedMaskSize =
1368 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1369 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1370 Value nonNegative =
1371 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1372 Value sgMaskSize =
1373 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1374 maskOperands.push_back(sgMaskSize);
1375 }
1376
1377 auto newCreateMaskOp =
1378 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1379 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1380 }
1381
1382 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1383 return success();
1384 }
1385};
1386
1387using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1388using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1389
1390// This pattern transforms vector.bitcast ops to work at subgroup level.
1391struct WgToSgVectorBitCastOp : public OpConversionPattern<vector::BitCastOp> {
1392 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1393
1394 LogicalResult
1395 matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter) const override {
1397 VectorType resultType = op.getResultVectorType();
1398
1399 ArrayRef<int64_t> wgShape = resultType.getShape();
1400 xegpu::DistributeLayoutAttr layout =
1401 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1402 if (!layout || !layout.isForWorkgroup())
1403 return failure();
1404
1405 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1406 VectorType newResultType =
1407 VectorType::get(sgShape, resultType.getElementType());
1408
1409 SmallVector<Value> newBitCastOps;
1410 for (auto src : adaptor.getSource()) {
1411 auto newBitCast =
1412 vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
1413 newBitCastOps.push_back(newBitCast.getResult());
1414 }
1415
1416 rewriter.replaceOpWithMultiple(op, {newBitCastOps});
1417 return success();
1418 }
1419};
1420
1421// This pattern transforms vector.interleave ops to work at subgroup level.
1422struct WgToSgVectorInterleaveOp
1423 : public OpConversionPattern<vector::InterleaveOp> {
1424 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1425
1426 LogicalResult
1427 matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1428 ConversionPatternRewriter &rewriter) const override {
1429 VectorType resultType = op.getResultVectorType();
1430
1431 ArrayRef<int64_t> wgShape = resultType.getShape();
1432 xegpu::DistributeLayoutAttr layout =
1433 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1434 if (!layout || !layout.isForWorkgroup())
1435 return failure();
1436
1437 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1438 VectorType newResultType =
1439 VectorType::get(sgShape, resultType.getElementType());
1440
1441 SmallVector<Value> newInterleaveOps;
1442 // Interleave operates pairwise: each lhs value is interleaved with
1443 // corresponding rhs value
1444 for (auto [lhs, rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
1445 auto newInterleave = vector::InterleaveOp::create(
1446 rewriter, op.getLoc(), newResultType, lhs, rhs);
1447 newInterleaveOps.push_back(newInterleave.getResult());
1448 }
1449
1450 rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
1451 return success();
1452 }
1453};
1454
1455// This pattern transforms vector.deinterleave ops to work at subgroup level.
1456struct WgToSgVectorDeinterleaveOp
1457 : public OpConversionPattern<vector::DeinterleaveOp> {
1458 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1459
1460 LogicalResult
1461 matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter) const override {
1463 SmallVector<Value> newRes1Ops;
1464 SmallVector<Value> newRes2Ops;
1465
1466 for (auto src : adaptor.getSource()) {
1467 auto newDeinterleave =
1468 vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
1469 newRes1Ops.push_back(newDeinterleave.getRes1());
1470 newRes2Ops.push_back(newDeinterleave.getRes2());
1471 }
1472
1473 SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1474 rewriter.replaceOpWithMultiple(op, results);
1475 return success();
1476 }
1477};
1478
1479} // namespace
1480
1481namespace mlir {
1482namespace xegpu {
1484 Operation *topLevelOp) {
1485 // Pass through all types by default.
1486 converter.addConversion([](Type type) -> Type { return type; });
1487
1488 // For TensorDescType, convert WG-level tensor descs to N SG-level descs.
1489 converter.addConversion(
1490 [](xegpu::TensorDescType type,
1491 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1492 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
1493 if (!layout || !layout.isForWorkgroup())
1494 return std::nullopt;
1495
1496 Type elemTy = type.getElementType();
1497 ArrayRef<int64_t> shape = type.getShape();
1498
1499 int count;
1500 SmallVector<int64_t> subShape;
1501 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1502
1503 layout = layout.dropSgLayoutAndData();
1504
1505 auto newTy = xegpu::TensorDescType::get(
1506 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1507 result.append(count, newTy);
1508 return success();
1509 });
1510
1511 // Context-aware VectorType conversion based on sg_layout/sg_data
1512 // (1:1 shape-changing or 1:N).
1513 auto getSubShapeAndCount = [](VectorType vecTy,
1514 xegpu::DistributeLayoutAttr layout)
1515 -> std::pair<SmallVector<int64_t>, int> {
1516 if (!layout.isForWorkgroup())
1517 return {{}, 0};
1518 return getSgShapeAndCount(vecTy.getShape(), layout);
1519 };
1520 auto loopArgTypes =
1521 xegpu::precomputeLoopBlockArgTypes(topLevelOp, getSubShapeAndCount);
1522 xegpu::addVectorTypeConversion(converter, getSubShapeAndCount,
1523 std::move(loopArgTypes));
1524}
1525
1527 patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1528 WgToSgDpasMxOp, WgToSgPrefetchNdOp, WgToSgElementwiseOp,
1529 WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1530 WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1531 WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1532 WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1533 WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1534 WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1535 WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
1536 patterns.getContext());
1537}
1538} // namespace xegpu
1539} // namespace mlir
1540
1541namespace {
1542struct XeGPUWgToSgDistributePass
1543 : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1544 void runOnOperation() override;
1545};
1546} // namespace
1547
1548void XeGPUWgToSgDistributePass::runOnOperation() {
1549
1550 Operation *op = getOperation();
1552 signalPassFailure();
1553 return;
1554 }
1555
1556 // Collect existing UnrealizedConversionCastOps. These must be preserved.
1557 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1558 getOperation()->walk(
1559 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1560
1561 // Perform workgroup to subgroup distribution for TensorDesc and Vector
1562 // values, as well as XeGPU, Arith, and Vector operations. Uses a
1563 // context-aware type converter that inspects Values to retrieve the
1564 // distribute layout attribute for 1:N type conversion.
1565 MLIRContext *ctx = &getContext();
1566 RewritePatternSet patterns(ctx);
1567 ConversionTarget target(*ctx);
1568 TypeConverter converter;
1569 // Source (N:1) and target (1:1) materializations using
1570 // UnrealizedConversionCastOp.
1571 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
1572 Location loc) -> Value {
1573 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
1574 .getResult(0);
1575 };
1576 converter.addSourceMaterialization(materializeCast);
1577 converter.addTargetMaterialization(materializeCast);
1579 getOperation());
1580
1581 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1582 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1583 return createOp.getType();
1584 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1585 return loadOp.getTensorDescType();
1586 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1587 return storeOp.getTensorDescType();
1588 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1589 return prefetchOp.getTensorDescType();
1590 return xegpu::TensorDescType();
1591 };
1592
1593 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1594 return !layout || !layout.isForWorkgroup();
1595 };
1596
1597 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1598 xegpu::StoreNdOp, xegpu::PrefetchNdOp>(
1599 [=](Operation *op) -> bool {
1600 auto tdescTy = getTensorDescType(op);
1601 auto layout =
1602 dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1603 return isLegal(layout);
1604 });
1605
1606 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1607 auto layout = op.getLayoutCdAttr();
1608 return isLegal(layout);
1609 });
1610
1611 target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
1612 [=](xegpu::DpasMxOp op) -> bool {
1613 auto layout = op.getLayoutCdAttr();
1614 return isLegal(layout);
1615 });
1616
1617 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1618 [=](xegpu::LoadMatrixOp op) -> bool {
1619 return isLegal(op.getLayoutAttr());
1620 });
1621
1622 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1623 [=](xegpu::StoreMatrixOp op) -> bool {
1624 return isLegal(op.getLayoutAttr());
1625 });
1626
1627 target.addDynamicallyLegalOp<arith::ConstantOp>(
1628 [=](arith::ConstantOp op) -> bool {
1629 auto vecType = dyn_cast<VectorType>(op.getType());
1630 if (!vecType)
1631 return true;
1632
1633 auto layout =
1634 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
1635 return isLegal(layout);
1636 });
1637
1638 target.addDynamicallyLegalOp<
1639 vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1640 vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
1641 vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp,
1642 vector::DeinterleaveOp>([=](Operation *op) -> bool {
1643 // Check for either a SliceAttr or LayoutAttr on the result.
1644 auto layout =
1645 xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
1646 return isLegal(layout);
1647 });
1648
1649 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1650 [=](xegpu::LoadGatherOp op) -> bool {
1651 auto layout = op.getLayoutAttr();
1652 return isLegal(layout);
1653 });
1654
1655 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1656 [=](xegpu::StoreScatterOp op) -> bool {
1657 auto layout = op.getLayoutAttr();
1658 return isLegal(layout);
1659 });
1660
1661 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1662 [=](xegpu::ConvertLayoutOp op) -> bool {
1663 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1664 });
1665
1666 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1667 [=](Operation *op) -> std::optional<bool> {
1668 // Only handle elementwise mappable ops
1670 return true;
1671
1672 VectorType resultType =
1673 dyn_cast<VectorType>(op->getResult(0).getType());
1674 if (!resultType)
1675 return true;
1676
1677 // Check if all operands are vectors of the same shape
1678 // TODO: Support other types.
1679 for (Value operand : op->getOperands()) {
1680 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1681 if (!operandType || operandType.getShape() != resultType.getShape()) {
1682 return true;
1683 }
1684 }
1685
1686 xegpu::DistributeLayoutAttr layout =
1688 return isLegal(layout);
1689 });
1690
1691 target.addLegalOp<UnrealizedConversionCastOp>();
1692
1693 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1694
1696 target);
1698 if (failed(
1699 applyPartialConversion(getOperation(), target, std::move(patterns))))
1700 return signalPassFailure();
1701
1702 // Fold cancelling cast chains and erase dead casts.
1703 xegpu::cleanupUnrealizedConversionCasts(getOperation(), existingCasts);
1704 xegpu::removeTemporaryLayoutAttrs(getOperation());
1705}
return success()
lhs
b getContext())
#define mul(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void populateXeGPUWgToSgDistributeTypeConversions(TypeConverter &converter, Operation *topLevelOp)
Define the type conversions needed for XeGPU workgroup to subgroup distribution.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DenseMap< Value, SmallVector< Type > > precomputeLoopBlockArgTypes(Operation *topLevelOp, SubShapeAndCountFn getSubShapeAndCount)
Pre-computes distributed VectorType mappings for every value carried through an SCF loop under topLev...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
void addVectorTypeConversion(TypeConverter &converter, SubShapeAndCountFn getSubShapeAndCount, DenseMap< Value, SmallVector< Type > > loopArgTypes)
Adds a context-aware VectorType conversion to converter (1:1 shape-changing or 1:N,...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
Get and set the distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void cleanupUnrealizedConversionCasts(Operation *root, const llvm::SmallSetVector< UnrealizedConversionCastOp, 8 > &existingCasts)
Cleans up UnrealizedConversionCastOps inserted during SCF structural type conversion and/or XeGPU unr...
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
int64_t computeProduct(ArrayRef< int64_t > basis)
Returns the product of all elements in basis.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.