29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Looks up an xegpu::RangeAttr named "sg_id_range" on an enclosing op of
// `op` and returns it when present.
// NOTE(review): several original lines are elided in this view — the loop
// that walks `parent` up the ancestor chain and the return statements are
// not visible; confirm against the full source.
39static xegpu::RangeAttr getRangeSpecAttr(
Operation *op) {
42 if (
// Safe against a null attribute: dyn_cast_if_present tolerates absence.
auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->
getAttr(
"sg_id_range")))
// Computes the per-subgroup shape for a workgroup-level shape distributed
// under `layout`, returning the shape together with an int (presumably the
// number of distribution units — TODO confirm, the computation of `count`
// is elided in this view).
// If the layout cannot compute a distributed shape, falls back to returning
// the (elided) default `sgShape`; otherwise returns the layout's effective
// sg_data.
50static std::pair<SmallVector<int64_t>,
int>
52 xegpu::DistributeLayoutAttr layout) {
55 auto distributedShape = layout.computeDistributedShape(
57 if (
failed(distributedShape))
58 return std::make_pair(sgShape, count);
59 auto sgData = layout.getEffectiveSgDataAsInt();
61 return std::make_pair(sgData, count);
// Helper shared by the WgToSg* patterns below: for a supported op kind,
// computes the list of per-subgroup offset vectors produced by distributing
// the op's workgroup-level offsets across subgroups.
// Restricted via enable_if to the nd-desc/load/store/prefetch and
// load/store-matrix ops.
70 typename = std::enable_if_t<llvm::is_one_of<
71 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
72 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
74genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
79 if (origOffsets.empty())
// Matrix ops carry the layout directly; nd ops carry it on the tensor desc.
83 xegpu::DistributeLayoutAttr layout;
84 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
85 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
86 layout = op.getLayoutAttr();
88 layout = op.getDescLayoutAttr();
// Only workgroup-scoped layouts are distributed by this pass.
92 if (!layout || !layout.isForWorkgroup())
96 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
// Optional "sg_id_range" restricts which subgroup ids participate; when
// present, sg_layout must cover exactly that range and the id is rebased.
99 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
101 int64_t startOfRange = sgIdRange.getStart().getInt();
102 int64_t endOfRange = sgIdRange.getEnd().getInt();
104 if (layout.getNumSubgroups() != endOfRange - startOfRange)
105 return rewriter.notifyMatchFailure(
106 op,
"sg_layout size must match the sg_id_range");
108 if (startOfRange > 0) {
109 Value startOfRangeVal =
// Rebase the subgroup id so distribution math starts at zero.
111 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
118 auto maybeDescOffsets =
119 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
120 if (
failed(maybeDescOffsets))
// One offsets vector per distribution unit handled by this subgroup.
125 for (
const auto &sgOffsets : *maybeDescOffsets) {
128 offsetsList.push_back(std::move(newOffsets));
/// Converts a workgroup-level xegpu.create_nd_tdesc (with explicit offsets)
/// into one create_nd_tdesc per subgroup distribution unit, each with a
/// subgroup-sized tensor-desc type and per-subgroup offsets from
/// genOffsetsList.
180struct WgToSgCreateNdOp :
public OpConversionPattern<xegpu::CreateNdDescOp> {
181 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
184 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
186 SmallVector<SmallVector<OpFoldResult>> offsetsList;
187 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
190 MLIRContext *ctx = op.getContext();
191 xegpu::TensorDescType tdescTy = op.getType();
192 ArrayRef<int64_t> wgShape = tdescTy.getShape();
193 Type elemTy = tdescTy.getElementType();
194 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
195 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// New desc type drops sg_layout/sg_data: the result is subgroup-scoped.
197 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
198 layout.dropSgLayoutAndData());
200 SmallVector<Value> newOps;
201 for (
auto offsets : offsetsList) {
202 auto newOp = xegpu::CreateNdDescOp::create(
203 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
204 op.getMixedSizes(), op.getMixedStrides());
206 newOps.push_back(newOp);
// 1:N replacement — one result value per distribution unit.
208 rewriter.replaceOpWithMultiple(op, {newOps});
/// Converts a workgroup-level xegpu.create_nd_tdesc that carries NO offsets:
/// simply replicates the op `count` times with a subgroup-sized desc type
/// (offsets are expected to be supplied by the consuming load/store ops).
216struct WgToSgCreateNdOpNoOffset
217 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
218 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
221 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter)
const override {
// Guard: this pattern only handles the offset-less form.
225 if (!op.getMixedOffsets().empty())
228 Location loc = op.getLoc();
229 MLIRContext *ctx = op.getContext();
230 xegpu::TensorDescType tdescTy = op.getType();
231 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
232 if (!layout || !layout.isForWorkgroup())
235 Type elemTy = tdescTy.getElementType();
236 ArrayRef<int64_t> wgShape = tdescTy.getShape();
238 SmallVector<int64_t> sgShape;
240 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
241 xegpu::TensorDescType newTdescTy =
242 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
243 layout.dropSgLayoutAndData());
// All replicas are identical since there are no per-subgroup offsets.
245 SmallVector<Value> newCreateNdOps(count);
246 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
247 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
248 op.getSource(), op.getMixedSizes(),
249 op.getMixedStrides());
252 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
/// Converts an offset-less workgroup xegpu.load_nd: emits one load_nd per
/// already-distributed tensor descriptor, with the result vector shaped
/// after each subgroup descriptor.
258struct WgToSgLoadNdOp :
public OpConversionPattern<xegpu::LoadNdOp> {
259 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
261 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
// Loads that carry their own offsets are handled by the WithOffset pattern.
263 if (!op.getMixedOffsets().empty())
266 SmallVector<Value> newLoadOps;
267 for (
auto src : adaptor.getTensorDesc()) {
268 xegpu::TensorDescType tdescTy =
269 dyn_cast<xegpu::TensorDescType>(src.getType());
270 ArrayRef<int64_t> srcShape = tdescTy.getShape();
// Result type mirrors the (already subgroup-sized) descriptor shape.
271 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
272 auto newLoadOp = xegpu::LoadNdOp::create(
273 rewriter, op.getLoc(), newResTy, src,
275 newLoadOps.push_back(newLoadOp);
277 rewriter.replaceOpWithMultiple(op, {newLoadOps});
278 return mlir::success();
/// Converts an offset-less workgroup xegpu.store_nd: pairs each distributed
/// value with its distributed descriptor and emits one store_nd per pair,
/// preserving the cache-hint attributes.
285struct WgToSgStoreNdOp :
public OpConversionPattern<xegpu::StoreNdOp> {
286 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
288 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter)
const override {
// Stores that carry their own offsets go through the WithOffset pattern.
290 if (!op.getMixedOffsets().empty())
293 for (
auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
294 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
295 op.getL2HintAttr(), op.getL3HintAttr());
// store_nd has no results, so the original op is simply erased.
297 rewriter.eraseOp(op);
/// Converts a workgroup xegpu.load_nd that carries explicit offsets: emits
/// one load_nd per (descriptor, per-subgroup offsets) pair computed by
/// genOffsetsList, dropping sg_layout/sg_data from the layout.
304struct WgToSgLoadNdOpWithOffset :
public OpConversionPattern<xegpu::LoadNdOp> {
305 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
307 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter)
const override {
310 SmallVector<SmallVector<OpFoldResult>> offsetsList;
311 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
314 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
// Result layout becomes subgroup-scoped.
316 layout = layout.dropSgLayoutAndData();
317 SmallVector<Value> newOps;
318 for (
auto [tdesc, offsets] :
319 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
320 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
321 VectorType newResTy =
322 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
// The two nullptrs are optional operands/attrs of load_nd — TODO confirm
// which (packed/transpose?) against the op definition.
323 auto newOp = xegpu::LoadNdOp::create(
324 rewriter, op.getLoc(), newResTy, tdesc, offsets,
325 nullptr,
nullptr, op.getL1HintAttr(),
326 op.getL2HintAttr(), op.getL3HintAttr(), layout);
327 newOps.push_back(newOp);
329 rewriter.replaceOpWithMultiple(op, {newOps});
/// Converts a workgroup xegpu.store_nd that carries explicit offsets: one
/// store per (value, descriptor, per-subgroup offsets) triple, with the
/// subgroup-scoped layout attached.
337struct WgToSgStoreNdOpWithOffset
338 :
public OpConversionPattern<xegpu::StoreNdOp> {
339 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
341 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter)
const override {
343 SmallVector<SmallVector<OpFoldResult>> offsetsList;
344 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
347 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
349 layout = layout.dropSgLayoutAndData();
350 for (
auto [v, tdesc, offsets] :
351 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
352 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
353 op.getL1HintAttr(), op.getL2HintAttr(),
354 op.getL3HintAttr(), layout);
// No results to forward; erase the workgroup-level store.
356 rewriter.eraseOp(op);
/// Converts a workgroup xegpu.prefetch_nd that carries explicit offsets:
/// one prefetch per (descriptor, per-subgroup offsets) pair, cache hints
/// preserved, layout made subgroup-scoped.
364struct WgToSgPrefetchNdOpWithOffset
365 :
public OpConversionPattern<xegpu::PrefetchNdOp> {
366 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
368 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter)
const override {
370 SmallVector<SmallVector<OpFoldResult>> offsetsList;
371 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
374 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
376 layout = layout.dropSgLayoutAndData();
377 for (
auto [tdesc, offsets] :
378 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
379 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
380 op.getL1HintAttr(), op.getL2HintAttr(),
381 op.getL3HintAttr(), layout);
// Prefetch has no results; erase the original.
383 rewriter.eraseOp(op);
/// Converts xegpu.update_nd_offset: replays the same offset update on each
/// distributed tensor descriptor (the offsets themselves are uniform across
/// subgroups, so they are reused verbatim).
392struct WgToSgUpdateNdOffsetOp
393 :
public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
394 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
396 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
397 ConversionPatternRewriter &rewriter)
const override {
398 llvm::SmallVector<Value> newUpdateTileOffsetOps;
399 for (
auto tDesc : adaptor.getTensorDesc()) {
400 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
401 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
402 op.getConstOffsets());
403 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
406 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
/// Converts a workgroup xegpu.dpas (2D result only): forms the cross
/// product of distributed LHS and RHS fragments, emitting one subgroup dpas
/// per (A, B) pair with per-operand layouts demoted to subgroup scope.
/// NOTE(review): the accumulator indexing (`i`) initialization and the
/// condition guarding it are elided in this view — confirm the acc pairing
/// logic against the full source.
412struct WgToSgDpasOp :
public OpConversionPattern<xegpu::DpasOp> {
413 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
415 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
416 ConversionPatternRewriter &rewriter)
const override {
417 Location loc = op.getLoc();
418 VectorType resultTy = op.getResult().getType();
// Only rank-2 results are supported by this pattern.
419 if (resultTy.getRank() != 2)
422 auto layoutCd = op.getLayoutCdAttr();
423 auto layoutA = op.getLayoutAAttr();
424 auto layoutB = op.getLayoutBAttr();
425 if (!layoutCd || !layoutA || !layoutB)
428 SmallVector<Value> newDpasOps;
429 for (
auto aVec : adaptor.getLhs()) {
430 for (
auto bVec : adaptor.getRhs()) {
432 llvm::SmallVector<Value> operands({aVec, bVec});
435 tmpC = adaptor.getAcc()[i++];
436 operands.push_back(tmpC);
// Result tile is MxN: rows from the A fragment, cols from the B fragment.
439 ArrayRef<int64_t> aVecShape =
440 llvm::cast<VectorType>(aVec.getType()).getShape();
441 ArrayRef<int64_t> bVecShape =
442 llvm::cast<VectorType>(bVec.getType()).getShape();
443 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
444 resultTy.getElementType());
445 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
446 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
447 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
448 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
450 newDpasOps.push_back(newDpasOp);
453 rewriter.replaceOpWithMultiple(op, {newDpasOps});
/// Converts an offset-less workgroup xegpu.prefetch_nd: one prefetch per
/// distributed tensor descriptor.
459struct WgToSgPrefetchNdOp :
public OpConversionPattern<xegpu::PrefetchNdOp> {
460 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
462 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
463 ConversionPatternRewriter &rewriter)
const override {
// Guard: both dynamic and constant offsets must be absent; otherwise the
// WithOffset pattern is responsible.
465 int64_t offsetSize =
static_cast<int64_t
>(op.getOffsets().size());
466 if ((offsetSize != 0) || op.getConstOffsetsAttr())
469 for (
auto src : adaptor.getTensorDesc())
470 xegpu::PrefetchNdOp::create(
473 rewriter.eraseOp(op);
/// Converts vector.broadcast under a workgroup layout: broadcasts each
/// distributed source fragment to the subgroup result shape, replicating
/// the set `count / numSourceFragments` times so the total result count
/// matches the layout's distribution count.
479struct WgToSgVectorBroadcastOp
480 :
public OpConversionPattern<vector::BroadcastOp> {
481 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
484 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter)
const override {
487 VectorType resultType = op.getResult().getType();
488 ArrayRef<int64_t> wgShape = resultType.getShape();
490 xegpu::DistributeLayoutAttr layout =
492 if (!layout || !layout.isForWorkgroup())
495 SmallVector<int64_t> sgShape;
497 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
498 VectorType newResultType =
499 VectorType::get(sgShape, resultType.getElementType());
501 SmallVector<Value> newBroadcastOps;
502 auto distSource = adaptor.getOperands().front();
// NOTE(review): integer division — assumes distSource.size() divides
// `count` evenly; confirm upstream guarantees this.
503 int numDistributions = count / distSource.size();
504 for (
int i = 0; i < numDistributions; ++i) {
505 for (
auto operand : distSource) {
506 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
507 newResultType, operand);
509 newBroadcastOps.push_back(newBroadcast.getResult());
512 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
// Generic pattern for elementwise ops (matches any op via
// MatchAnyOpTypeTag): rebuilds the op once per distributed operand variant
// with a subgroup-shaped result type, carrying over all attributes.
// NOTE(review): the struct declaration and the matching preconditions
// (result-type extraction, elementwise-trait check) are elided in this
// view — confirm against the full source.
519 WgToSgElementwiseOp(MLIRContext *ctx)
520 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
523 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
524 ConversionPatternRewriter &rewriter)
const override {
530 assert(resultType &&
"Expected result to be a VectorType");
532 ArrayRef<int64_t> wgShape = resultType.getShape();
534 xegpu::DistributeLayoutAttr layout =
536 if (!layout || !layout.isForWorkgroup())
539 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// Every operand must have been distributed into the same number of
// fragments, otherwise the i-th fragments cannot be paired.
541 size_t numVariants = operands.empty() ? 0 : operands.front().size();
543 if (llvm::any_of(operands, [&](
const ValueRange &operandVec) {
544 return operandVec.size() != numVariants;
548 SmallVector<Value> newResults;
549 VectorType newResultType =
550 VectorType::get(sgShape, resultType.getElementType());
552 for (
size_t i = 0; i < numVariants; ++i) {
553 SmallVector<Value> opOperands;
554 for (
auto &operandVec : operands)
555 opOperands.push_back(operandVec[i]);
// Rebuild the op generically via OperationState so this works for any
// elementwise op kind.
558 state.addOperands(opOperands);
559 state.addTypes(newResultType);
560 state.addAttributes(op->
getAttrs());
561 Operation *newOp = rewriter.create(state);
563 newResults.push_back(newOp->
getResult(0));
566 rewriter.replaceOpWithMultiple(op, {newResults});
/// Converts xegpu.convert_layout. Two paths:
///  1. If input and target layouts are compatible at subgroup granularity,
///     drop sg_layout/sg_data and (when both survive) re-emit per-fragment
///     convert_layout ops — no data movement.
///  2. Otherwise round-trip through shared local memory (SLM): store every
///     fragment with the input layout's coordinates, barrier, then load
///     back with the target layout's coordinates.
597struct WgToSgConvertLayoutOp
598 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
599 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
602 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
603 ConversionPatternRewriter &rewriter)
const override {
604 Location loc = op.getLoc();
605 auto inputLayout = op.getInputLayout();
606 auto targetLayout = op.getTargetLayout();
608 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
609 !targetLayout.isForWorkgroup())
610 return rewriter.notifyMatchFailure(
611 op,
"Input and target layouts must have subgroup layout");
613 Type resultType = op.getResult().getType();
// Scalar results need no redistribution — forward the source directly.
615 rewriter.replaceOp(op, op.getSource());
616 assert(!inputLayout.dropSgLayoutAndData() &&
617 !targetLayout.dropSgLayoutAndData() &&
618 "unexpected layout attributes for scalar type");
622 ArrayRef<int64_t> wgShape = cast<VectorType>(resultType).getShape();
623 SmallVector<int64_t> inputSgLayout =
624 inputLayout.getEffectiveSgLayoutAsInt();
625 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
626 SmallVector<int64_t> targetSgLayout =
627 targetLayout.getEffectiveSgLayoutAsInt();
628 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
631 SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
// Fast path: layouts agree at subgroup granularity, no SLM needed.
632 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
633 xegpu::LayoutKind::Subgroup)) {
634 inputLayout = inputLayout.dropSgLayoutAndData();
635 targetLayout = targetLayout.dropSgLayoutAndData();
637 SmallVector<Value> newOps(adaptor.getSource());
638 if (inputLayout && targetLayout) {
639 for (
auto [i, src] : llvm::enumerate(adaptor.getSource())) {
640 auto newOp = xegpu::ConvertLayoutOp::create(
641 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
645 rewriter.replaceOpWithMultiple(op, {newOps});
// Slow path: redistribute through SLM.
650 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
652 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
// NOTE(review): assumes bitWidth is a multiple of 8 — sub-byte element
// types would truncate here; confirm upstream legality checks.
656 auto bytesPerElement = bitWidth / 8;
// Memory space 3 = workgroup-shared memory.
660 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
661 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
663 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
666 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
668 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
669 rewriter.getIndexType(),
nullptr);
// Store all fragments to SLM at input-layout coordinates.
672 auto storeCoords = inputLayout.computeDistributedCoords(
673 rewriter, loc, sgId.getResult(), wgShape);
678 for (
auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
679 SmallVector<OpFoldResult> storeMatrixOffsets;
680 for (Value coord : coords) {
681 storeMatrixOffsets.push_back(coord);
683 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
684 storeMatrixOffsets,
nullptr );
// Barrier: all stores must land before any subgroup reads back.
687 gpu::BarrierOp::create(rewriter, loc);
690 auto loadCoords = targetLayout.computeDistributedCoords(
691 rewriter, loc, sgId.getResult(), wgShape);
695 VectorType loadType = VectorType::get(targetSgData, elemTy);
698 SmallVector<Value> finalResults;
699 for (
auto coords : *loadCoords) {
700 SmallVector<OpFoldResult> loadMatrixOffsets;
701 for (Value coord : coords) {
702 loadMatrixOffsets.push_back(coord);
704 auto loadOp = xegpu::LoadMatrixOp::create(
705 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
706 targetLayout.dropSgLayoutAndData());
708 finalResults.push_back(loadOp.getResult());
711 rewriter.replaceOpWithMultiple(op, {finalResults});
/// Cleans up unrealized_conversion_cast ops left over by the 1:N type
/// conversion: folds N:1 casts whose input types already match the result
/// types, and forwards 1:N casts as a multi-value replacement. Anything
/// else fails the match.
747struct UnrealizedConversionCastOpPattern
748 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
749 using OpConversionPattern<
750 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
753 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
754 ConversionPatternRewriter &rewriter)
const override {
757 auto inputTy = dyn_cast<VectorType>(inputs[0].
getType());
758 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
// Only homogeneous vector casts are handled: all results and all inputs
// must share a single type.
760 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
761 !llvm::all_equal(
ValueRange(inputs).getTypes()))
// Case "N:1": one original operand, distributed inputs already match the
// result types — replace results with the inputs directly.
769 if (op.getNumOperands() == 1 &&
770 llvm::equal(
ValueRange(inputs).getTypes(), op->getResultTypes())) {
771 rewriter.replaceOp(op, inputs);
// Case "1:N": single result replaced by the full set of inputs.
782 if (op.getNumResults() == 1 &&
784 rewriter.replaceOpWithMultiple(op, {inputs});
788 return mlir::failure();
/// Converts arith.constant vectors under a workgroup layout.
/// Three cases:
///  - splat: replicate a subgroup-shaped splat constant `count` times;
///  - sgShape == wgShape: keep the constant as-is;
///  - non-splat index vectors (1D/2D) with constant row/column strides:
///    materialize a base tile plus a per-subgroup offset computed from the
///    distributed coordinates and the strides, then add via broadcast.
793struct WgToSgArithConstantOp :
public OpConversionPattern<arith::ConstantOp> {
794 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
797 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
798 ConversionPatternRewriter &rewriter)
const override {
799 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
800 auto vecType = dyn_cast<VectorType>(op.getType());
801 if (!vecAttr || !vecType)
804 xegpu::DistributeLayoutAttr layout =
806 if (!layout || !layout.isForWorkgroup())
809 ArrayRef<int64_t> wgShape = vecType.getShape();
810 SmallVector<int64_t> sgShape;
812 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
814 auto newType = VectorType::get(sgShape, vecType.getElementType());
815 Location loc = op.getLoc();
816 auto eltType = vecType.getElementType();
// Case 1: splat — every subgroup gets the same constant value.
818 if (vecAttr.isSplat()) {
820 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
822 SmallVector<Value> newConstOps;
823 for (
int i = 0; i < count; ++i) {
824 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
825 newConstOps.push_back(cstOp);
827 rewriter.replaceOpWithMultiple(op, {newConstOps});
// Case 2: no distribution actually happens — keep the constant.
829 }
else if (sgShape == wgShape) {
832 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
833 rewriter.replaceOp(op, newConstOp);
// Case 3: non-splat. Only index element type and rank <= 2 supported.
839 if (!eltType.isIndex())
840 return rewriter.notifyMatchFailure(
841 op,
"Unsupported element type for non-splat constant op.");
843 if (wgShape.size() > 2)
844 return rewriter.notifyMatchFailure(
845 op,
"Only 1D & 2D vector constant supported");
847 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
848 int64_t rowStride = 0, colStride = 0;
// 1D vectors are treated as a single row.
849 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
850 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
854 colStride = cast<IntegerAttr>(values[1]).getInt() -
855 cast<IntegerAttr>(values[0]).getInt();
858 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
859 cast<IntegerAttr>(values[0]).getInt();
// Verify the whole constant is an affine grid: every adjacent pair must
// reproduce the row/column stride, else bail out.
862 for (int64_t r = 0; r < rows; ++r) {
863 for (int64_t c = 0; c < cols; ++c) {
864 int64_t idx = r * cols + c;
866 if (c > 0 && cols > 1) {
867 int64_t prevIdx = r * cols + (c - 1);
868 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
869 cast<IntegerAttr>(values[prevIdx]).getInt();
870 if (diff != colStride)
871 return rewriter.notifyMatchFailure(
872 op,
"Non-constant column stride in constant op.");
875 if (r > 0 && rows > 1) {
876 int64_t prevIdx = (r - 1) * cols + c;
877 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
878 cast<IntegerAttr>(values[prevIdx]).getInt();
879 if (diff != rowStride)
880 return rewriter.notifyMatchFailure(
881 op,
"Non-constant row stride in constant op.");
// Base tile: top-left sgShape-sized corner of the workgroup constant.
889 SmallVector<Attribute> baseTileValues;
890 int baseTileCols = sgShape[sgShape.size() - 1];
891 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
892 for (int64_t r = 0; r < baseTileRows; ++r) {
893 for (int64_t c = 0; c < baseTileCols; ++c) {
894 baseTileValues.push_back(values[r * cols + c]);
900 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
904 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
906 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
910 SmallVector<Value, 2> strideConsts;
911 strideConsts.push_back(
915 strideConsts.begin(),
918 SmallVector<Value> newConstOps;
919 for (
auto offsets : *sgOffsets) {
// Linear offset = sum_i(offsets[i] * stride[i]); broadcast and add to
// the base tile to get this subgroup's constant fragment.
922 for (
size_t i = 0; i < strideConsts.size(); ++i) {
924 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
925 offsets[i], strideConsts[i]);
926 mulOffset = arith::AddIOp::create(
927 rewriter, loc, rewriter.getIndexType(), mulOffset,
mul);
930 auto bcastOffset = vector::BroadcastOp::create(
931 rewriter, loc, baseConstVec.getType(), mulOffset);
933 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
934 newConstOps.push_back(finalConst);
936 rewriter.replaceOpWithMultiple(op, {newConstOps});
/// Converts xegpu.load (gather form with offsets): emits one subgroup load
/// per already-distributed (offsets, mask) pair, requiring that offsets and
/// mask were distributed to matching vector shapes.
944struct WgToSgLoadGatherOpWithOffset
945 :
public OpConversionPattern<xegpu::LoadGatherOp> {
946 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
948 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
949 ConversionPatternRewriter &rewriter)
const override {
951 Location loc = op.getLoc();
952 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
955 ArrayRef<int64_t> wgShape = resultType.getShape();
957 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
959 if (!layout || !layout.isForWorkgroup())
962 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// Offsets and mask must already be subgroup-distributed vectors of the
// same shape; otherwise this pattern cannot pair them.
965 auto offsetsVecType =
966 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
968 dyn_cast<VectorType>(adaptor.getMask().front().getType());
969 if (!offsetsVecType || !maskVecType ||
970 offsetsVecType.getShape() != maskVecType.getShape()) {
971 return rewriter.notifyMatchFailure(op,
972 "offsets have not been distributed");
975 SmallVector<Value> newLoadOps;
// Default chunk size is 1 element per lane when unspecified.
977 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
978 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
979 for (
auto [offsets, mask] :
980 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
981 auto newLayout = layout.dropSgLayoutAndData();
982 auto newLoadOp = xegpu::LoadGatherOp::create(
983 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
984 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
986 newLoadOps.push_back(newLoadOp);
988 rewriter.replaceOpWithMultiple(op, {newLoadOps});
/// Converts xegpu.store (scatter form with offsets): one subgroup store per
/// distributed (value, offsets, mask) triple; offsets and mask must have
/// matching distributed vector shapes.
995struct WgToSgStoreScatterOpWithOffset
996 :
public OpConversionPattern<xegpu::StoreScatterOp> {
997 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
999 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
1000 ConversionPatternRewriter &rewriter)
const override {
1002 Location loc = op.getLoc();
1003 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1007 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1009 if (!layout || !layout.isForWorkgroup())
// Same distribution precondition as the gather pattern above.
1013 auto offsetsVecType =
1014 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1016 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1017 if (!offsetsVecType || !maskVecType ||
1018 offsetsVecType.getShape() != maskVecType.getShape()) {
1019 return rewriter.notifyMatchFailure(op,
1020 "offsets have not been distributed");
// Chunk size defaults to 1 when absent.
1023 auto chunkSizeOpt = op.getChunkSize();
1024 int64_t chunkSize = chunkSizeOpt ?
static_cast<int64_t
>(*chunkSizeOpt) : 1;
1025 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1026 for (
auto [val, offs, mask] : llvm::zip(
1027 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1028 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1029 mask, chunkSizeAttr, op.getL1HintAttr(),
1030 op.getL2HintAttr(), op.getL3HintAttr(),
1031 layout.dropSgLayoutAndData());
// store has no results; erase the original op.
1033 rewriter.eraseOp(op);
/// Converts xegpu.load_matrix: one subgroup-sized load per distributed
/// offsets vector from genOffsetsList, reading from the same mem-desc.
1038struct WgToSgLoadMatrixOp :
public OpConversionPattern<xegpu::LoadMatrixOp> {
1039 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1041 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1042 ConversionPatternRewriter &rewriter)
const override {
1044 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1045 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
1048 ArrayRef<int64_t> wgShape = op.getDataShape();
1049 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1050 assert(valueTy &&
"the value type must be vector type!");
1051 Type elemTy = valueTy.getElementType();
1053 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1054 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1055 VectorType newResTy = VectorType::get(sgShape, elemTy);
1056 SmallVector<Value> newOps;
1057 for (
auto offsets : offsetsList) {
1058 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1059 op.getMemDesc(), offsets,
1060 layout.dropSgLayoutAndData());
1061 newOps.push_back(newOp);
1063 rewriter.replaceOpWithMultiple(op, {newOps});
/// Converts xegpu.store_matrix: one store per (distributed data fragment,
/// per-subgroup offsets) pair into the shared mem-desc.
1069struct WgToSgStoreMatrixOp :
public OpConversionPattern<xegpu::StoreMatrixOp> {
1070 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1072 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1073 ConversionPatternRewriter &rewriter)
const override {
1075 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1076 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
1079 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1080 for (
auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1081 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1082 offsets, layout.dropSgLayoutAndData());
// store has no results; the original op is erased, not replaced.
1083 rewriter.eraseOp(op);
/// Converts vector.step under a workgroup layout: a subgroup-local step
/// vector plus the subgroup's distributed start offset (broadcast and
/// added) yields each subgroup's slice of the workgroup-level step.
1089struct WgToSgVectorStepOp :
public OpConversionPattern<vector::StepOp> {
1090 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1092 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1093 ConversionPatternRewriter &rewriter)
const override {
1094 xegpu::DistributeLayoutAttr layout =
1096 if (!layout || !layout.isForWorkgroup())
1099 Location loc = op.getLoc();
1100 VectorType type = op.getResult().getType();
1101 auto wgShape = type.getShape();
1102 std::optional<SmallVector<int64_t>> sgShape =
1103 getSgShapeAndCount(wgShape, layout).first;
1108 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1110 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1114 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1115 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1116 SmallVector<Value> newOps;
1117 for (
auto offsets : *sgOffsets) {
// vector.step is 1-D, so only the first coordinate contributes.
1120 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1122 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1123 newOps.push_back(finalSteps);
1126 rewriter.replaceOpWithMultiple(op, {newOps});
/// Converts vector.shape_cast under a workgroup layout: re-emits a
/// subgroup-sized shape_cast per distributed source fragment. Unit-dim
/// expanding casts are only accepted when all users are broadcasts and the
/// source layout is a slice of the result layout, so distribution stays
/// coherent.
1132struct WgToSgVectorShapeCastOp
1133 :
public OpConversionPattern<vector::ShapeCastOp> {
1134 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1137 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1138 ConversionPatternRewriter &rewriter)
const override {
1140 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1144 ArrayRef<int64_t> wgShape = resultType.getShape();
1145 xegpu::DistributeLayoutAttr layout =
1147 if (!layout || !layout.isForWorkgroup())
1152 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1156 ArrayRef<int64_t> srcShape = srcType.getShape();
1158 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1159 SmallVector<int64_t> expandedUnitDims;
1161 xegpu::DistributeLayoutAttr sourceLayout =
// Restriction: unit-dim-expanding casts are only supported when every
// user is a broadcast (which ignores the expanded unit dims).
1164 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1165 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1166 return isa<vector::BroadcastOp>(user);
1170 if (!usedByBroadcastOp(op))
1171 return rewriter.notifyMatchFailure(
1172 op,
"ShapeCast ops that expand unit dimensions and are used by "
1173 "non-broadcast operations are not supported.");
1175 if (!sourceLayout.isSliceOf(layout))
1176 return rewriter.notifyMatchFailure(
1177 op,
"The ShapeCast op only expands dimensions, the input layout "
1178 "must be a slice of the result layout.");
// Sanity: unit dims must already carry sg_data == 1 in the layout.
1180 assert(layoutToDistribute.isEqualTo(
1181 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1182 "The sg_data for unit dimensions should be set as 1");
1185 SmallVector<int64_t> sgShape =
1186 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1187 VectorType newResultType =
1188 VectorType::get(sgShape, resultType.getElementType());
1190 SmallVector<Value> newShapeCastOps;
1191 for (
auto src : adaptor.getSource()) {
1192 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1193 newResultType, src);
1194 newShapeCastOps.push_back(newShapeCast.getResult());
1197 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1234struct WgToSgMultiDimReductionOp
1235 :
public OpConversionPattern<vector::MultiDimReductionOp> {
1236 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1239 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1240 ConversionPatternRewriter &rewriter)
const override {
1241 Location loc = op.getLoc();
1243 VectorType srcType = op.getSourceVectorType();
1244 Type resultTy = op.getResult().getType();
1245 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
1246 bool isScalarResult = !dstVecType;
1248 auto originalSrcShape = srcType.getShape();
1249 Type elemTy = srcType.getElementType();
1251 xegpu::DistributeLayoutAttr layout =
1253 if (!layout || !layout.isForWorkgroup())
1256 auto reductionDims = llvm::to_vector(op.getReductionDims());
1259 SmallVector<int64_t> sgLayout;
1260 SmallVector<int64_t> sgData;
1261 xegpu::DistributeLayoutAttr parentLayout;
1262 if (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1263 parentLayout = sliceAttr.getParent();
1264 sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
1265 sgData = parentLayout.getEffectiveSgDataAsInt();
1267 return rewriter.notifyMatchFailure(
1268 op,
"Reduction should have SliceAttr layout");
1271 SmallVector<Value> localReductions;
1272 auto sgSrcs = adaptor.getSource();
1273 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1274 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1275 sgSrcType.getShape().end());
1282 auto originalDstShape = dstVecType.getShape();
1283 SmallVector<int64_t> sgDstShape =
1284 getSgShapeAndCount(originalDstShape, layout).first;
1285 sgDstType = VectorType::get(sgDstShape, elemTy);
1290 for (
auto sgSrc : sgSrcs) {
1293 rewriter, loc, sgDstType, op.getKind());
1295 auto localReduce = vector::MultiDimReductionOp::create(
1296 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1298 localReductions.push_back(localReduce.getResult());
1302 SmallVector<int64_t> crossSgReductionDims;
1303 for (int64_t reductionDim : reductionDims) {
1304 bool needsCrossSubgroupReduction =
1305 (sgLayout[reductionDim] > 1) &&
1306 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1308 if (needsCrossSubgroupReduction) {
1309 crossSgReductionDims.push_back(reductionDim);
1314 if (crossSgReductionDims.empty()) {
1315 SmallVector<Value> results;
1316 for (
auto localResult : localReductions) {
1318 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1319 results.push_back(finalResult);
1321 rewriter.replaceOpWithMultiple(op, {results});
1326 auto slmStoreDataShape = sgSrcShape;
1327 for (int64_t dim : reductionDims)
1328 slmStoreDataShape[dim] = 1;
1329 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1330 SmallVector<Value> slmStoreData;
1331 for (
auto localResult : localReductions) {
1332 if (isScalarResult) {
1334 slmStoreData.push_back(vector::BroadcastOp::create(
1335 rewriter, loc, slmStoreDataType, localResult));
1337 slmStoreData.push_back(vector::ShapeCastOp::create(
1338 rewriter, loc, slmStoreDataType, localResult));
1342 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1343 originalSrcShape.end());
1344 SmallVector<int> slmSgData(sgData.begin(), sgData.end());
1345 SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
1346 for (
int dim : reductionDims) {
1347 slmShape[dim] = sgLayout[dim];
1350 xegpu::LayoutAttr slmStoreLayout =
1351 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1355 auto bytesPerElement = bitWidth / 8;
1357 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1358 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1360 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1363 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1366 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1367 rewriter.getIndexType(),
nullptr);
1369 auto slmStoreCoords =
1370 slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1371 if (
failed(slmStoreCoords))
1373 for (
auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
1374 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1375 xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
1380 gpu::BarrierOp::create(rewriter, loc);
1383 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1384 for (int64_t dim : reductionDims) {
1385 slmLoadDataShape[dim] = slmShape[dim];
1386 slmSgData[dim] = slmShape[dim];
1388 xegpu::LayoutAttr slmLoadLayout =
1389 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1390 auto slmLoadCoords =
1391 slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1392 if (
failed(slmLoadCoords))
1395 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1396 SmallVector<Value> slmLoadData;
1397 for (
auto coord : *slmLoadCoords) {
1398 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1399 slmLoadData.push_back(xegpu::LoadMatrixOp::create(
1400 rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
1407 rewriter, loc, sgDstType, op.getKind());
1409 SmallVector<Value> finalResults;
1410 for (
size_t i = 0; i < slmLoadData.size(); ++i) {
1411 auto loaded = slmLoadData[i];
1412 auto finalReduce = vector::MultiDimReductionOp::create(
1413 rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
1416 rewriter, loc, op.getKind(), finalReduce.getResult(),
1417 adaptor.getAcc()[i]));
1419 rewriter.replaceOpWithMultiple(op, {finalResults});
1425struct WgToSgVectorTransposeOp
1426 :
public OpConversionPattern<vector::TransposeOp> {
1427 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1430 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1431 ConversionPatternRewriter &rewriter)
const override {
1432 VectorType resultType = op.getResultVectorType();
1434 ArrayRef<int64_t> wgShape = resultType.getShape();
1435 xegpu::DistributeLayoutAttr layout =
1437 if (!layout || !layout.isForWorkgroup())
1440 xegpu::DistributeLayoutAttr sourceLayout =
1442 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1445 SmallVector<int64_t> sourceSgLayout =
1446 sourceLayout.getEffectiveSgLayoutAsInt();
1447 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1449 ArrayRef<int64_t> permutation = op.getPermutation();
1450 size_t permutationSize = permutation.size();
1451 if (sourceSgLayout.size() != permutationSize ||
1452 resultSgLayout.size() != permutationSize) {
1453 return rewriter.notifyMatchFailure(
1454 op,
"Layouts and permutation must have the same rank");
1459 if (!layout.isTransposeOf(sourceLayout, permutation,
1460 xegpu::LayoutKind::Subgroup))
1461 return rewriter.notifyMatchFailure(
1462 op,
"Result layout is not a valid transpose of source layout "
1463 "according to permutation");
1465 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1466 VectorType newResultType =
1467 VectorType::get(sgShape, resultType.getElementType());
1469 SmallVector<Value> newTransposeOps;
1470 for (
auto src : adaptor.getVector()) {
1471 auto newTranspose = vector::TransposeOp::create(
1472 rewriter, op.getLoc(), newResultType, src, permutation);
1473 newTransposeOps.push_back(newTranspose.getResult());
1475 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1481template <
typename MaskOpType>
1482struct WgToSgVectorMaskOp :
public OpConversionPattern<MaskOpType> {
1483 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1485 LogicalResult matchAndRewrite(
1487 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1488 ConversionPatternRewriter &rewriter)
const override {
1489 xegpu::DistributeLayoutAttr layout =
1491 if (!layout || !layout.isForWorkgroup())
1494 Location loc = op.getLoc();
1495 VectorType type = op.getResult().getType();
1496 auto wgShape = type.getShape();
1498 SmallVector<Value> wgMaskDimSizes;
1499 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1500 for (int64_t maskSize : op.getMaskDimSizes()) {
1501 wgMaskDimSizes.push_back(
1504 }
else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1505 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1509 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1511 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1515 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1516 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1520 SmallVector<Value> newCreateMaskOps;
1521 for (
auto offsetSet : *sgOffsets) {
1522 SmallVector<Value> maskOperands;
1524 for (
auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1527 Value offset = offsetSet[i];
1528 Value adjustedMaskSize =
1529 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1532 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1534 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1535 maskOperands.push_back(sgMaskSize);
1538 auto newCreateMaskOp =
1539 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1540 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1543 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1548using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1549using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1556 .
add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1557 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1558 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1559 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1560 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1561 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1562 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1563 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1564 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1565 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1572struct XeGPUWgToSgDistributePass
1573 :
public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1574 void runOnOperation()
override;
1578void XeGPUWgToSgDistributePass::runOnOperation() {
1580 Operation *op = getOperation();
1582 signalPassFailure();
1587 SmallVector<Operation *> existingCastOps;
1588 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1589 existingCastOps.push_back(castOp.getOperation());
1599 TypeConverter converter;
1600 converter.addConversion([&](Type type) -> Type {
return type; });
1601 converter.addConversion(
1602 [&](RankedTensorType type,
1603 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1607 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1608 type.getEncoding());
1610 return std::nullopt;
1612 Type elemTy = type.getElementType();
1613 ArrayRef<int64_t> shape = type.getShape();
1616 SmallVector<int64_t> subShape;
1617 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1619 auto newTy = VectorType::get(subShape, elemTy);
1620 result.append(count, newTy);
1631 RewritePatternSet patterns(ctx);
1632 ConversionTarget
target(*ctx);
1633 TypeConverter converter;
1634 converter.addConversion([&](Type type) -> Type {
return type; });
1635 converter.addConversion(
1636 [&](xegpu::TensorDescType type,
1637 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1638 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
1641 if (!layout || !layout.isForWorkgroup())
1642 return std::nullopt;
1644 Type elemTy = type.getElementType();
1645 ArrayRef<int64_t> shape = type.getShape();
1648 SmallVector<int64_t> subShape;
1649 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1651 layout = layout.dropSgLayoutAndData();
1653 auto newTy = xegpu::TensorDescType::get(
1654 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1655 result.append(count, newTy);
1659 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1660 if (
auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1661 return createOp.getType();
1662 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1663 return loadOp.getTensorDescType();
1664 if (
auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1665 return storeOp.getTensorDescType();
1666 if (
auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1668 if (
auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1669 return prefetchOp.getTensorDescType();
1670 return xegpu::TensorDescType();
1673 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) ->
bool {
1674 return !layout || !layout.isForWorkgroup();
1677 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1678 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1679 xegpu::PrefetchNdOp>([=](Operation *op) ->
bool {
1680 auto tdescTy = getTensorDescType(op);
1681 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1682 return isLegal(layout);
1685 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) ->
bool {
1686 auto layout = op.getLayoutCdAttr();
1687 return isLegal(layout);
1690 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1691 [=](xegpu::LoadMatrixOp op) ->
bool {
1692 return isLegal(op.getLayoutAttr());
1695 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1696 [=](xegpu::StoreMatrixOp op) ->
bool {
1697 return isLegal(op.getLayoutAttr());
1700 target.addDynamicallyLegalOp<arith::ConstantOp>(
1701 [=](arith::ConstantOp op) ->
bool {
1702 auto vecType = dyn_cast<VectorType>(op.getType());
1708 return isLegal(layout);
1711 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1712 vector::TransposeOp, vector::BroadcastOp,
1713 vector::MultiDimReductionOp,
1714 vector::ConstantMaskOp, vector::CreateMaskOp>(
1715 [=](Operation *op) ->
bool {
1719 return isLegal(layout);
1722 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1723 [=](xegpu::LoadGatherOp op) ->
bool {
1724 auto layout = op.getLayoutAttr();
1725 return isLegal(layout);
1728 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1729 [=](xegpu::StoreScatterOp op) ->
bool {
1730 auto layout = op.getLayoutAttr();
1731 return isLegal(layout);
1734 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1735 [=](xegpu::ConvertLayoutOp op) ->
bool {
1736 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1739 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1740 [=](Operation *op) -> std::optional<bool> {
1745 VectorType resultType =
1753 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1754 if (!operandType || operandType.getShape() != resultType.getShape()) {
1759 xegpu::DistributeLayoutAttr layout =
1761 return isLegal(layout);
1764 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1765 [=](UnrealizedConversionCastOp op) {
1766 return llvm::is_contained(existingCastOps, op.getOperation());
1769 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
1775 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1776 return signalPassFailure();
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.