#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
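/// Returns the "sg_id_range" attribute from the nearest enclosing operation
/// that carries one, or a null attribute if none is found.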
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  Operation *parent = op->getParentOp();
  while (parent) {
    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
            parent->getAttr("sg_id_range")))
      return attr;
    parent = parent->getParentOp();
  }
  return xegpu::RangeAttr();
}
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp the distribution unit to the original shape.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}
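/// Computes the per-tile global offsets for an op carrying workgroup-level
/// offsets: the subgroup id (optionally rebased by an sg_id_range specifier)
/// is mapped to distributed coordinates, which are then added to the
/// original offsets, right-aligned.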
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp, xegpu::LoadMatrixOp,
              xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // Not applicable to ops without offset operands.
  if (origOffsets.empty())
    return failure();

  xegpu::DistributeLayoutAttr layout;
  if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
                std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
    layout = op.getLayoutAttr();
  } else {
    layout = op.getDescLayoutAttr();
  }
  // Not applicable to ops without a workgroup layout.
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // If a range specifier is present, verify it and rebase the subgroup id.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  ArrayRef<int64_t> wgShape = op.getDataShape();
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Turn each set of distributed coordinates into global offsets by adding
  // the original workgroup-level offsets, aligned on the right.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }

  return success();
}
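/// Distributes xegpu.create_nd_tdesc from workgroup to subgroup level: one
/// descriptor is created per tile owned by the calling subgroup, with
/// sg_layout/sg_data dropped from the layout. Illustrative example (shapes
/// made up): a 24x24 descriptor with sg_layout = [4, 4], sg_data = [2, 2]
/// becomes several 2x2 descriptors at subgroup-specific offsets.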
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    auto newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
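/// Variant of the above for xegpu.create_nd_tdesc without explicit offsets:
/// the descriptor is replicated `count` times with the subgroup-level shape,
/// and offsets are expected to be supplied by the consuming ops.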
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match the no-offset form.
    if (!op.getMixedOffsets().empty())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    int count;
    SmallVector<int64_t> sgShape;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps(count);
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};
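/// Distributes xegpu.load_nd (without offsets): one load per distributed
/// tensor descriptor, producing subgroup-sized vectors.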
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};
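/// Distributes xegpu.store_nd (without offsets): one store per distributed
/// value/descriptor pair.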
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
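/// Distributes xegpu.load_nd with offsets, using genOffsetsList to compute
/// per-tile offsets for the calling subgroup.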
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    SmallVector<Value> newOps;
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          /*packed=*/nullptr, /*transpose=*/nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
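/// Distributes xegpu.store_nd with offsets, mirroring the load pattern above.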
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
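/// Distributes xegpu.prefetch_nd with offsets.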
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
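/// Distributes xegpu.update_nd_offset: the same offset update is applied to
/// every distributed descriptor.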
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};
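/// Distributes xegpu.dpas: forms one subgroup-level dpas per (lhs, rhs) tile
/// pair (with a matching accumulator tile, if present) and drops
/// sg_layout/sg_data from the A/B/CD layouts.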
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto layoutCd = op.getLayoutCdAttr();
    auto layoutA = op.getLayoutAAttr();
    auto layoutB = op.getLayoutBAttr();
    if (!layoutCd || !layoutA || !layoutB)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
        newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
        newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
        newDpasOps.push_back(newDpasOp);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};
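/// Distributes xegpu.prefetch_nd without offsets: one prefetch per
/// distributed tensor descriptor.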
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
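/// Distributes vector.broadcast, provided the workgroup shape is evenly
/// distributable under the layout.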
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
                                     layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};
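/// Distributes elementwise ops (any op with the ElementwiseMappable traits
/// and a single vector result): the op is re-created once per distributed
/// operand tuple with the subgroup-level result type.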
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with the elementwise trait and a single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");

    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    size_t numVariants = operands.empty() ? 0 : operands.front().size();

    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy all attributes, dropping sg_layout/sg_data from any layout.
      state.addAttributes(xegpu::dropSgLayoutAndDataOnAttrs(op->getAttrs()));
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};
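/// Distributes xegpu.convert_layout. Currently only supported when input and
/// target agree on sg_layout/sg_data/order, so only the remaining fields
/// (e.g. inst_data) survive at subgroup level.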
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto input = op.getInputLayout();
    auto target = op.getTargetLayout();

    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
    DenseI32ArrayAttr inputOrder = input.getOrder();
    SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
    DenseI32ArrayAttr targetOrder = target.getOrder();

    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
      return failure();

    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();

    SmallVector<Value> newOps(adaptor.getSource());
    if (input && target) {
      // Keep the ConvertLayoutOp for the remaining fields, e.g. inst_data.
      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
        auto newOp = xegpu::ConvertLayoutOp::create(
            rewriter, op.getLoc(), src.getType(), src, input, target);
        newOps[i] = newOp;
      }
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
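/// Folds away unrealized_conversion_casts materialized by the 1:N type
/// conversion, handling both the 1:N ("unpack") and N:1 ("pack") directions.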
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<
      mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();

    // "Unpack" case: a single workgroup value cast to N subgroup values.
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      rewriter.replaceOp(op, inputs);
      return success();
    }

    // "Pack" case: N subgroup values cast back to a single workgroup value.
    if (op.getNumResults() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getOperandTypes())) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }

    return mlir::failure();
  }
};
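/// Distributes arith.constant: splat constants are simply re-typed; non-splat
/// constants are only supported for index vectors with constant row/column
/// strides, rebuilt as a base tile plus a per-subgroup offset.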
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    auto setLayout = [&](Value val) {
      xegpu::setDistributeLayoutAttr(cast<OpResult>(val),
                                     layout.dropSgLayoutAndData());
    };

    if (vecAttr.isSplat()) {
      // Splat: every subgroup gets the same constant.
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      setLayout(cstOp->getResult(0));
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) {
      // The whole vector is shared by all subgroups.
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      setLayout(newConstOp->getResult(0));
      rewriter.replaceOp(op, newConstOp);
      return success();
    }

    // Non-splat constants: only strided index vectors are supported.
    if (!eltType.isIndex())
      return rewriter.notifyMatchFailure(
          op, "Unsupported element type for non-splat constant op.");

    if (wgShape.size() > 2)
      return rewriter.notifyMatchFailure(
          op, "Only 1D & 2D vector constant supported");

    SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
    int64_t rowStride = 0, colStride = 0;
    int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
    int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

    // Derive the strides from the first elements.
    if (cols > 1) {
      colStride = cast<IntegerAttr>(values[1]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();
    }
    if (rows > 1) {
      rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();
    }

    // Verify that all elements follow the strided pattern.
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        int64_t idx = r * cols + c;
        if (c > 0 && cols > 1) {
          int64_t prevIdx = r * cols + (c - 1);
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != colStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant column stride in constant op.");
        }
        if (r > 0 && rows > 1) {
          int64_t prevIdx = (r - 1) * cols + c;
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != rowStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant row stride in constant op.");
        }
      }
    }

    // Build the base tile (the constant owned by the subgroup at offset 0).
    SmallVector<Attribute> baseTileValues;
    int baseTileCols = sgShape[sgShape.size() - 1];
    int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
    for (int64_t r = 0; r < baseTileRows; ++r) {
      for (int64_t c = 0; c < baseTileCols; ++c) {
        baseTileValues.push_back(values[r * cols + c]);
      }
    }
    auto tileAttr = DenseElementsAttr::get(newType, baseTileValues);
    auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

    // Compute each subgroup's distributed coordinates.
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<Value, 2> strideConsts;
    strideConsts.push_back(
        arith::ConstantIndexOp::create(rewriter, loc, colStride));
    if (rows > 1)
      strideConsts.insert(
          strideConsts.begin(),
          arith::ConstantIndexOp::create(rewriter, loc, rowStride));

    SmallVector<Value> newConstOps;
    for (auto offsets : *sgOffsets) {
      // Linearize offset * stride over all dimensions.
      Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
      for (size_t i = 0; i < strideConsts.size(); ++i) {
        Value mul =
            arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                  offsets[i], strideConsts[i]);
        mulOffset = arith::AddIOp::create(
            rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
      }
      // Broadcast to the base tile shape and add.
      auto bcastOffset = vector::BroadcastOp::create(
          rewriter, loc, baseConstVec.getType(), mulOffset);
      auto finalConst =
          arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
      setLayout(baseConstVec);
      setLayout(bcastOffset);
      setLayout(finalConst);
      newConstOps.push_back(finalConst);
    }
    rewriter.replaceOpWithMultiple(op, {newConstOps});
    return success();
  }
};
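/// Distributes xegpu.load (gather) with offsets; offsets and mask must
/// already be distributed to matching shapes.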
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // The offsets need to be distributed to match the mask shape.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    SmallVector<Value> newLoadOps;
    auto chunkSizeAttr =
        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLayout = layout.dropSgLayoutAndData();
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          newLayout);
      newLoadOp.setAnchorLayout(newLayout);
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};
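/// Distributes xegpu.store (scatter) with offsets, mirroring the gather
/// pattern above.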
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    if (!valueType)
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // The offsets need to be distributed to match the mask shape.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      auto store = xegpu::StoreScatterOp::create(
          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          layout.dropSgLayoutAndData());
      // Skip the dest operand (operand 1); it is not a distributed vector.
      for (OpOperand &operand : store->getOpOperands()) {
        if (operand.getOperandNumber() == 1)
          continue;
        xegpu::setTemporaryLayout(operand, layout.dropSgLayoutAndData());
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
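/// Distributes xegpu.load_matrix using genOffsetsList for per-tile offsets.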
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    ArrayRef<int64_t> wgShape = op.getDataShape();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");
    Type elemTy = valueTy.getElementType();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResTy = VectorType::get(sgShape, elemTy);
    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
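/// Distributes xegpu.store_matrix using genOffsetsList for per-tile offsets.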
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    rewriter.eraseOp(op);
    return success();
  }
};
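/// Distributes vector.step: a subgroup-local step vector is shifted by the
/// distributed offset of each tile.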
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    if (!sgShape)
      return failure();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *sgOffsets) {
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      xegpu::setDistributeLayoutAttr(steps->getResult(0),
                                     layout.dropSgLayoutAndData());
      xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
                                     layout.dropSgLayoutAndData());
      xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
                                     layout.dropSgLayoutAndData());
      newOps.push_back(finalSteps);
    }

    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
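/// Distributes vector.shape_cast. Shape casts that only expand unit
/// dimensions get special handling: they are allowed only when feeding
/// broadcasts and when the source layout is a slice of the result layout.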
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
    if (!srcType)
      return failure();

    ArrayRef<int64_t> srcShape = srcType.getShape();
    llvm::SetVector<int64_t> expandedUnitDims;

    // Returns true if dst is src with only extra unit dimensions inserted,
    // recording the positions of those unit dimensions.
    auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
                                       ArrayRef<int64_t> dst) -> bool {
      size_t srcIdx = 0;
      for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
        if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
          ++srcIdx;
        else if (dst[dstIdx] == 1)
          expandedUnitDims.insert(dstIdx);
        else
          return false;
      return srcIdx == src.size();
    };

    xegpu::DistributeLayoutAttr layoutToDistribute = layout;

    if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
      xegpu::DistributeLayoutAttr sourceLayout =
          xegpu::getDistributeLayoutAttr(op.getSource());
      // Only supported when every user is a broadcast.
      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
          return isa<vector::BroadcastOp>(user);
        });
      };
      if (!usedByBroadcastOp(op))
        return rewriter.notifyMatchFailure(
            op, "ShapeCast ops that expand unit dimensions and are used by "
                "non-broadcast operations are not supported.");
      if (!sourceLayout.isSliceOf(layout))
        return rewriter.notifyMatchFailure(
            op, "The ShapeCast op only expands dimensions, the result layout "
                "must be a slice of the input layout, or vice versa.");
      layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
      layoutToDistribute =
          layoutToDistribute.setUnitDimLayout(expandedUnitDims);
    }

    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgShape, layoutToDistribute).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
                                     layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
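/// Returns a constant vector holding the neutral element of `kind` (0 for
/// add/xor/or, 1 for mul/and, the appropriate min/max limits otherwise),
/// used as the accumulator for partial reductions.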
static Value createAccumulator(ConversionPatternRewriter &rewriter,
                               Location loc, VectorType type,
                               vector::CombiningKind kind) {
  Type elemTy = type.getElementType();
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
  case vector::CombiningKind::OR:
    return arith::ConstantOp::create(
        rewriter, loc, type,
        DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
  case vector::CombiningKind::MUL:
  case vector::CombiningKind::AND:
    return arith::ConstantOp::create(
        rewriter, loc, type,
        DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy)));
  case vector::CombiningKind::MINSI:
    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
      auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getIntegerAttr(elemTy, maxVal)));
    }
    break;
  case vector::CombiningKind::MINUI:
    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
      auto maxVal = APInt::getMaxValue(intTy.getWidth());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getIntegerAttr(elemTy, maxVal)));
    }
    break;
  case vector::CombiningKind::MAXSI:
    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
      auto minVal = APInt::getSignedMinValue(intTy.getWidth());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getIntegerAttr(elemTy, minVal)));
    }
    break;
  case vector::CombiningKind::MAXUI:
    return arith::ConstantOp::create(
        rewriter, loc, type,
        DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
      auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getFloatAttr(elemTy, posInf)));
    }
    break;
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
      auto negInf = APFloat::getInf(floatTy.getFloatSemantics(),
                                    /*Negative=*/true);
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getFloatAttr(elemTy, negInf)));
    }
    break;
  default:
    break;
  }
  // Unsupported combining kind / element type combination.
  return Value();
}
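/// Linearizes the per-dimension subgroup ids selected by `dims` into a
/// single index value, using `sgLayout` to derive the strides.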
static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
                                      Location loc, ArrayRef<Value> sgIds,
                                      ArrayRef<int64_t> dims,
                                      ArrayRef<int64_t> sgLayout) {
  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
  int64_t stride = 1;
  for (int64_t dim : llvm::reverse(dims)) {
    Value dimVal = sgIds[dim];
    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
    linearizedOffset =
        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
    stride *= sgLayout[dim];
  }
  return linearizedOffset;
}
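/// Distributes vector.multi_reduction. Each subgroup first reduces its own
/// tile; if a reduction dimension spans multiple subgroups, partial results
/// are exchanged through SLM (store, barrier, load) and reduced again.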
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto originalSrcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    SmallVector<int64_t> sgLayout;
    SmallVector<int64_t> sgData;
    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
    } else {
      return rewriter.notifyMatchFailure(
          op, "Reduction should have SliceAttr layout");
    }

    Type elemTy = dstType.getElementType();

    // Step 1: each subgroup reduces its own tile against a neutral
    // accumulator.
    SmallVector<Value> localReductions;
    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(originalSrcShape, layout).first;
    VectorType newDstType = VectorType::get(sgShape, elemTy);
    for (auto sgSrc : adaptor.getSource()) {
      auto neutralLocalAcc =
          createAccumulator(rewriter, loc, newDstType, op.getKind());
      auto localReduce = vector::MultiDimReductionOp::create(
          rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
          reductionDims);
      localReductions.push_back(localReduce.getResult());
    }

    // Step 2: find the reduction dims that span multiple subgroups.
    SmallVector<int64_t> crossSgReductionDims;
    for (int64_t reductionDim : reductionDims) {
      bool needsCrossSubgroupReduction =
          (sgLayout[reductionDim] > 1) &&
          (sgData[reductionDim] < originalSrcShape[reductionDim]);
      if (needsCrossSubgroupReduction) {
        crossSgReductionDims.push_back(reductionDim);
      }
    }

    // No cross-subgroup reduction needed: fold in the original accumulator.
    if (crossSgReductionDims.empty()) {
      SmallVector<Value> results;
      for (auto localResult : localReductions) {
        Value finalResult = vector::makeArithReduction(
            rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
        if (auto defOp = finalResult.getDefiningOp())
          xegpu::setDistributeLayoutAttr(defOp->getResult(0),
                                         layout.dropSgLayoutAndData());
        results.push_back(finalResult);
      }
      rewriter.replaceOpWithMultiple(op, {results});
      return success();
    }

    // Step 3: exchange partial results through SLM. Reshape the local result
    // to a 1xN row for the store.
    int64_t localElements = computeProduct(sgShape);
    SmallVector<int64_t> storeShape2D = {1, localElements};
    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
    auto storeShapeCast = vector::ShapeCastOp::create(
        rewriter, loc, storeType2D, localReductions[0]);
    Value storeData = storeShapeCast.getResult();

    // One SLM row per participating subgroup along the reduced dims.
    int64_t totalReductionSubgroups = 1;
    for (int64_t dim : crossSgReductionDims) {
      totalReductionSubgroups *= sgLayout[dim];
    }

    int64_t totalResultElements =
        localElements * computeProduct(sgLayout) / totalReductionSubgroups;

    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
                                       totalResultElements};

    // Allocate SLM (workgroup-shared, address space 3) as raw bytes.
    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
    auto slmSize = slmElements * bytesPerElement;
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
                                               slmShape2D, elemTy,
                                               /*mem_layout=*/nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(),
                                          /*upper_bound=*/nullptr);

    // Delinearize the subgroup id into per-dimension indices.
    SmallVector<Value> sgLayoutValues;
    for (int64_t dim : sgLayout)
      sgLayoutValues.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, dim));
    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
                                                sgLayoutValues);
    if (failed(sgIdsResult))
      return failure();
    SmallVector<Value> sgIds = *sgIdsResult;

    // Row: the subgroup's position along the reduced dims.
    Value rowOffsetStore = linearizeSubgroupIndices(
        rewriter, loc, sgIds, crossSgReductionDims, sgLayout);

    // Column: the subgroup's position along the non-reduced dims, scaled by
    // the number of local elements.
    SmallVector<int64_t> nonReductionDims;
    for (size_t i = 0; i < sgLayout.size(); ++i) {
      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
        nonReductionDims.push_back(static_cast<int64_t>(i));
      }
    }
    Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
                                               nonReductionDims, sgLayout);
    Value localElementsVal =
        arith::ConstantIndexOp::create(rewriter, loc, localElements);
    colOffset =
        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);

    SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};

    auto storeMatrixLayout = xegpu::SliceAttr::get(
        rewriter.getContext(),
        xegpu::LayoutAttr::get(rewriter.getContext(), nullptr,
                               /* remaining layout operands elided */),
        dyn_cast<xegpu::SliceAttr>(layout).getDims());
    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
                                 storeOffsets2D, storeMatrixLayout);

    // Make all partial results visible before loading.
    gpu::BarrierOp::create(rewriter, loc);

    // Step 4: each subgroup loads its column of partial results and reduces
    // across the rows.
    SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
    VectorType loadType2D = VectorType::get(loadShape2D, elemTy);

    Value rowOffsetLoad = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};

    auto loadOp = xegpu::LoadMatrixOp::create(
        rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
        /* layout operand elided */ nullptr);

    SmallVector<int64_t> finalReductionDims = {0};
    SmallVector<int64_t> finalResultShape = {localElements};
    VectorType finalResultType = VectorType::get(finalResultShape, elemTy);

    auto neutralFinalAcc =
        createAccumulator(rewriter, loc, finalResultType, op.getKind());

    auto finalReduce = vector::MultiDimReductionOp::create(
        rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
        neutralFinalAcc, finalReductionDims);

    // Fold in the original accumulator, reshaping it if needed.
    Value originalAcc = adaptor.getAcc()[0];
    Value accToAdd = originalAcc;
    if (originalAcc.getType() != finalReduce.getResult().getType()) {
      auto originalAccType = cast<VectorType>(originalAcc.getType());
      auto finalResultType =
          cast<VectorType>(finalReduce.getResult().getType());
      if (originalAccType.getNumElements() ==
          finalResultType.getNumElements()) {
        auto shapeCast = vector::ShapeCastOp::create(
            rewriter, loc, finalResultType, originalAcc);
        accToAdd = shapeCast.getResult();
      }
    }

    Value finalResult = vector::makeArithReduction(
        rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);

    if (auto defOp = finalResult.getDefiningOp())
      xegpu::setDistributeLayoutAttr(defOp->getResult(0),
                                     layout.dropSgLayoutAndData());

    rewriter.replaceOp(op, finalResult);
    return success();
  }
};
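/// Distributes vector.transpose: requires the result layout to be the
/// permutation of the source layout, then transposes each tile.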
struct WgToSgVectorTransposeOp
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!sourceLayout || !sourceLayout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sourceSgLayout =
        sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
    DenseI32ArrayAttr resultOrder = layout.getOrder();

    if (!sourceOrder || !resultOrder) {
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");
    }

    ArrayRef<int64_t> permutation = op.getPermutation();
    size_t permutationSize = permutation.size();
    if (sourceSgLayout.size() != permutationSize ||
        resultSgLayout.size() != permutationSize) {
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");
    }

    // Check that sg_layout, sg_data and order are properly transposed.
    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    SmallVector<Value> newTransposeOps;
    for (auto src : adaptor.getVector()) {
      auto newTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), newResultType, src, permutation);
      xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
                                     layout.dropSgLayoutAndData());
      newTransposeOps.push_back(newTranspose.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
    return success();
  }
};
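/// Distributes vector.constant_mask / vector.create_mask: the workgroup-level
/// mask size is clamped into each subgroup tile, producing per-tile
/// vector.create_mask ops.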
template <typename MaskOpType>
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
  using OpConversionPattern<MaskOpType>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MaskOpType op,
      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    SmallVector<Value> wgMaskDimSizes;
    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
      for (int64_t maskSize : op.getMaskDimSizes()) {
        wgMaskDimSizes.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
      }
    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
      wgMaskDimSizes = llvm::to_vector(op.getOperands());
    }

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());

    // Clamp the wg-level mask size into each tile:
    // sgMaskSize = min(max(wgMaskSize - offset, 0), sgDimSize).
    SmallVector<Value> newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector<Value> maskOperands;

      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
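/// Appends patterns for XeGPU workgroup-to-subgroup distribution into
/// `patterns`.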
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
          patterns.getContext());
}
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
void XeGPUWgToSgDistributePass::runOnOperation() {
  // Track pre-existing UnrealizedConversionCastOps so that legality checks
  // can distinguish them from casts materialized by this conversion.
  SmallVector<Operation *> existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  // Step 1: propagate the distribution through SCF structural ops, using
  // RankedTensorType (with the layout as encoding) as a carrier type.
  {
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
          Type elemTy = type.getElementType();
          ArrayRef<int64_t> shape = type.getShape();

          int count;
          SmallVector<int64_t> subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape,
              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });

    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);
  }
  // Step 2: perform the actual 1:N conversion on XeGPU and vector ops.
  MLIRContext *ctx = &getContext();
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        int count;
        SmallVector<int64_t> subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };
  // An op is legal once its layout no longer carries workgroup-level
  // distribution info.
  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = op.getLayoutCdAttr();
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        if (!vecType)
          return true;
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                               vector::TransposeOp, vector::BroadcastOp,
                               vector::MultiDimReductionOp,
                               vector::ConstantMaskOp, vector::CreateMaskOp>(
      [=](Operation *op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only elementwise ops with a single vector result participate.
        if (!OpTrait::hasElementwiseMappableTraits(op) ||
            op->getNumResults() != 1)
          return true;
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // All operands must be vectors of the same shape as the result.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  RewritePatternSet patterns(ctx);
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();

  // Finally, drop sg_layout/sg_data from layout attributes left on results.
  // For SCF structural ops the attribute is removed entirely, since in the
  // 1:N case the layouts for the new results are missing.
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      // Assumption: the temporary-layout name helper also covers results.
      std::string name = xegpu::getTemporaryLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}