29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Looks up the "sg_id_range" attribute (an xegpu::RangeAttr) that restricts
// which subgroup ids a region of IR is distributed over.
// NOTE(review): `parent` is declared on lines not visible in this chunk —
// presumably an ancestor operation of `op`; confirm against the full file.
39static xegpu::RangeAttr getRangeSpecAttr(
Operation *op) {
42 if (
auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->
getAttr(
"sg_id_range")))
// Computes the per-subgroup data shape and the number of distribution units
// for a workgroup-level shape under `layout`.
// NOTE(review): declarations of sgShape/distUnit/count and the derived
// sg-data computation are on lines not visible in this chunk.
50static std::pair<SmallVector<int64_t>,
int>
52 xegpu::DistributeLayoutAttr layout) {
// Only workgroup-scoped layouts participate in wg->sg distribution.
55 if (layout && layout.isForWorkgroup()) {
// Prefer the layout's explicit sg_data when present ...
57 if (!layout.getEffectiveSgDataAsInt().empty())
58 sgShape = layout.getEffectiveSgDataAsInt();
// ... otherwise use the derived per-subgroup data shape.
60 sgShape = *maybeDerivedSgData;
// Clamp each distribution-unit dim so it never exceeds the wg shape.
65 for (
size_t i = 0; i < distUnit.size(); ++i)
66 distUnit[i] = std::min(
shape[i], distUnit[i]);
69 return std::make_pair(sgShape, count);
// Helper shared by the *WithOffset patterns below: for an Nd/matrix op that
// carries explicit offsets, computes one offset vector per subgroup
// distribution unit and appends them to `offsetsList`.
// Restricted (via enable_if) to the XeGPU ops that carry nd offsets.
78 typename = std::enable_if_t<llvm::is_one_of<
79 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
80 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
82genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Nothing to do for ops without explicit offsets.
87 if (origOffsets.empty())
91 xegpu::DistributeLayoutAttr layout;
// Matrix ops expose the layout directly; Nd ops via their tensor-desc.
92 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
93 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
94 layout = op.getLayoutAttr();
96 layout = op.getDescLayoutAttr();
100 if (!layout || !layout.isForWorkgroup())
// Linear subgroup id (no upper-bound attribute supplied).
104 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
// Honor an enclosing "sg_id_range": the layout must cover exactly the
// range, and the sg id is rebased to the range start.
107 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
109 int64_t startOfRange = sgIdRange.getStart().getInt();
110 int64_t endOfRange = sgIdRange.getEnd().getInt();
112 if (layout.getNumSubgroups() != endOfRange - startOfRange)
113 return rewriter.notifyMatchFailure(
114 op,
"sg_layout size must match the sg_id_range");
116 if (startOfRange > 0) {
117 Value startOfRangeVal =
119 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
// Let the layout compute the per-subgroup coordinates.
126 auto maybeDescOffsets =
127 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
128 if (
failed(maybeDescOffsets))
// NOTE(review): newOffsets is built on lines not visible here, presumably
// by adding sgOffsets to the op's original offsets.
133 for (
const auto &sgOffsets : *maybeDescOffsets) {
136 offsetsList.push_back(std::move(newOffsets));
// Pattern: 1:N-converts a workgroup-level xegpu.create_nd_tdesc (with
// offsets) into one subgroup-level descriptor per distribution unit.
188struct WgToSgCreateNdOp :
public OpConversionPattern<xegpu::CreateNdDescOp> {
189 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
192 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter)
const override {
194 SmallVector<SmallVector<OpFoldResult>> offsetsList;
195 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
198 MLIRContext *ctx = op.getContext();
199 xegpu::TensorDescType tdescTy = op.getType();
200 ArrayRef<int64_t> wgShape = tdescTy.getShape();
201 Type elemTy = tdescTy.getElementType();
202 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
203 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// New per-subgroup descriptor type: sg shape, with sg layout/data dropped.
205 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
206 layout.dropSgLayoutAndData());
// One create_nd_tdesc per computed subgroup offset.
208 SmallVector<Value> newOps;
209 for (
auto offsets : offsetsList) {
210 auto newOp = xegpu::CreateNdDescOp::create(
211 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
212 op.getMixedSizes(), op.getMixedStrides());
214 newOps.push_back(newOp);
216 rewriter.replaceOpWithMultiple(op, {newOps});
// Pattern: variant of WgToSgCreateNdOp for create_nd_tdesc WITHOUT explicit
// offsets — simply replicates `count` identical subgroup descriptors.
224struct WgToSgCreateNdOpNoOffset
225 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
226 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
229 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter)
const override {
// Bail out if the op carries offsets (handled by WgToSgCreateNdOp).
233 if (!op.getMixedOffsets().empty())
236 Location loc = op.getLoc();
237 MLIRContext *ctx = op.getContext();
238 xegpu::TensorDescType tdescTy = op.getType();
239 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
240 if (!layout || !layout.isForWorkgroup())
243 Type elemTy = tdescTy.getElementType();
244 ArrayRef<int64_t> wgShape = tdescTy.getShape();
246 SmallVector<int64_t> sgShape;
248 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
249 xegpu::TensorDescType newTdescTy =
250 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
251 layout.dropSgLayoutAndData());
// Emit `count` identical descriptors (no per-subgroup offsets needed).
253 SmallVector<Value> newCreateNdOps(count);
254 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
255 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
256 op.getSource(), op.getMixedSizes(),
257 op.getMixedStrides());
260 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
// Pattern: distributes an offset-less xegpu.load_nd — one load per
// already-distributed tensor descriptor from the adaptor.
266struct WgToSgLoadNdOp :
public OpConversionPattern<xegpu::LoadNdOp> {
267 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
269 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
270 ConversionPatternRewriter &rewriter)
const override {
// Loads with explicit offsets are handled by WgToSgLoadNdOpWithOffset.
271 if (!op.getMixedOffsets().empty())
274 SmallVector<Value> newLoadOps;
275 for (
auto src : adaptor.getTensorDesc()) {
276 xegpu::TensorDescType tdescTy =
277 dyn_cast<xegpu::TensorDescType>(src.getType());
278 ArrayRef<int64_t> srcShape = tdescTy.getShape();
// Result vector takes the (already subgroup-sized) descriptor shape.
279 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
280 auto newLoadOp = xegpu::LoadNdOp::create(
281 rewriter, op.getLoc(), newResTy, src,
283 newLoadOps.push_back(newLoadOp);
285 rewriter.replaceOpWithMultiple(op, {newLoadOps});
286 return mlir::success();
// Pattern: distributes an offset-less xegpu.store_nd — one store per
// (value, tensor-desc) pair from the adaptor; original op is erased.
293struct WgToSgStoreNdOp :
public OpConversionPattern<xegpu::StoreNdOp> {
294 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
296 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter)
const override {
// Stores with explicit offsets are handled by WgToSgStoreNdOpWithOffset.
298 if (!op.getMixedOffsets().empty())
301 for (
auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
302 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
303 op.getL2HintAttr(), op.getL3HintAttr());
305 rewriter.eraseOp(op);
// Pattern: distributes xegpu.load_nd WITH explicit offsets — pairs each
// distributed descriptor with its computed per-subgroup offsets.
312struct WgToSgLoadNdOpWithOffset :
public OpConversionPattern<xegpu::LoadNdOp> {
313 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
315 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter)
const override {
318 SmallVector<SmallVector<OpFoldResult>> offsetsList;
319 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// Propagate the layout with sg layout/data removed.
322 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
324 layout = layout.dropSgLayoutAndData();
325 SmallVector<Value> newOps;
326 for (
auto [tdesc, offsets] :
327 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
328 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
329 VectorType newResTy =
330 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
331 auto newOp = xegpu::LoadNdOp::create(
332 rewriter, op.getLoc(), newResTy, tdesc, offsets,
333 nullptr,
nullptr, op.getL1HintAttr(),
334 op.getL2HintAttr(), op.getL3HintAttr(), layout);
335 newOps.push_back(newOp);
337 rewriter.replaceOpWithMultiple(op, {newOps});
// Pattern: distributes xegpu.store_nd WITH explicit offsets — one store per
// (value, descriptor, offsets) triple; original op is erased.
345struct WgToSgStoreNdOpWithOffset
346 :
public OpConversionPattern<xegpu::StoreNdOp> {
347 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
349 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter)
const override {
351 SmallVector<SmallVector<OpFoldResult>> offsetsList;
352 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// Propagate the layout with sg layout/data removed.
355 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
357 layout = layout.dropSgLayoutAndData();
358 for (
auto [v, tdesc, offsets] :
359 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
360 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
361 op.getL1HintAttr(), op.getL2HintAttr(),
362 op.getL3HintAttr(), layout);
364 rewriter.eraseOp(op);
// Pattern: distributes xegpu.prefetch_nd WITH explicit offsets — one
// prefetch per (descriptor, offsets) pair; original op is erased.
372struct WgToSgPrefetchNdOpWithOffset
373 :
public OpConversionPattern<xegpu::PrefetchNdOp> {
374 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
376 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
377 ConversionPatternRewriter &rewriter)
const override {
378 SmallVector<SmallVector<OpFoldResult>> offsetsList;
379 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// Propagate the layout with sg layout/data removed.
382 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
384 layout = layout.dropSgLayoutAndData();
385 for (
auto [tdesc, offsets] :
386 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
387 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
388 op.getL1HintAttr(), op.getL2HintAttr(),
389 op.getL3HintAttr(), layout);
391 rewriter.eraseOp(op);
// Pattern: distributes xegpu.update_nd_offset — applies the same offset
// update to each distributed descriptor.
400struct WgToSgUpdateNdOffsetOp
401 :
public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
402 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
404 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter)
const override {
406 llvm::SmallVector<Value> newUpdateTileOffsetOps;
407 for (
auto tDesc : adaptor.getTensorDesc()) {
408 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
409 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
410 op.getConstOffsets());
411 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
414 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
// Pattern: distributes xegpu.dpas — emits one subgroup dpas per (A, B)
// fragment pair, dropping sg layout/data from the A/B/C layout attrs.
420struct WgToSgDpasOp :
public OpConversionPattern<xegpu::DpasOp> {
421 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
423 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter)
const override {
425 Location loc = op.getLoc();
426 VectorType resultTy = op.getResult().getType();
// Only 2-D results are supported.
427 if (resultTy.getRank() != 2)
430 auto layoutCd = op.getLayoutCdAttr();
431 auto layoutA = op.getLayoutAAttr();
432 auto layoutB = op.getLayoutBAttr();
433 if (!layoutCd || !layoutA || !layoutB)
// Nested loop: every lhs fragment is paired with every rhs fragment.
436 SmallVector<Value> newDpasOps;
437 for (
auto aVec : adaptor.getLhs()) {
438 for (
auto bVec : adaptor.getRhs()) {
440 llvm::SmallVector<Value> operands({aVec, bVec});
// Optional accumulator fragment (guard is on lines not visible here).
443 tmpC = adaptor.getAcc()[i++];
444 operands.push_back(tmpC);
// Result fragment shape: rows of A x cols of B.
447 ArrayRef<int64_t> aVecShape =
448 llvm::cast<VectorType>(aVec.getType()).getShape();
449 ArrayRef<int64_t> bVecShape =
450 llvm::cast<VectorType>(bVec.getType()).getShape();
451 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
452 resultTy.getElementType());
453 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
454 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
455 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
456 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
458 newDpasOps.push_back(newDpasOp);
461 rewriter.replaceOpWithMultiple(op, {newDpasOps});
// Pattern: distributes an offset-less xegpu.prefetch_nd — one prefetch per
// distributed descriptor; original op is erased.
467struct WgToSgPrefetchNdOp :
public OpConversionPattern<xegpu::PrefetchNdOp> {
468 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
470 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
471 ConversionPatternRewriter &rewriter)
const override {
// Prefetches with offsets go to WgToSgPrefetchNdOpWithOffset instead.
473 int64_t offsetSize =
static_cast<int64_t
>(op.getOffsets().size());
474 if ((offsetSize != 0) || op.getConstOffsetsAttr())
477 for (
auto src : adaptor.getTensorDesc())
478 xegpu::PrefetchNdOp::create(
481 rewriter.eraseOp(op);
// Pattern: distributes vector.broadcast — broadcasts each distributed
// operand to the subgroup-sized result type.
487struct WgToSgVectorBroadcastOp
488 :
public OpConversionPattern<vector::BroadcastOp> {
489 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
492 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
493 ConversionPatternRewriter &rewriter)
const override {
495 VectorType resultType = op.getResult().getType();
496 ArrayRef<int64_t> wgShape = resultType.getShape();
498 xegpu::DistributeLayoutAttr layout =
500 if (!layout || !layout.isForWorkgroup())
503 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
504 VectorType newResultType =
505 VectorType::get(sgShape, resultType.getElementType());
// The wg shape must tile evenly across subgroups for this pattern.
507 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
510 SmallVector<Value> newBroadcastOps;
511 for (
auto operand : adaptor.getOperands().front()) {
512 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
513 newResultType, operand);
515 newBroadcastOps.push_back(newBroadcast.getResult());
517 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
// Pattern: generic 1:N distribution for elementwise ops (matches any op;
// constraints checked inside). Rebuilds the op once per operand variant.
// NOTE(review): several interior lines (trait checks, OperationState setup)
// are not visible in this chunk.
523struct WgToSgElementwiseOp :
public ConversionPattern {
525 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
529 ConversionPatternRewriter &rewriter)
const override {
535 assert(resultType &&
"Expected result to be a VectorType");
539 xegpu::DistributeLayoutAttr layout =
541 if (!layout || !layout.isForWorkgroup())
// Every operand must have been distributed into the same number of parts.
546 size_t numVariants = operands.empty() ? 0 : operands.front().size();
548 if (llvm::any_of(operands, [&](
const ValueRange &operandVec) {
549 return operandVec.size() != numVariants;
554 VectorType newResultType =
555 VectorType::get(sgShape, resultType.getElementType());
// Clone the op once per variant, taking the i-th piece of each operand.
557 for (
size_t i = 0; i < numVariants; ++i) {
559 for (
auto &operandVec : operands)
560 opOperands.push_back(operandVec[i]);
566 Operation *newOp = rewriter.create(state);
568 newResults.push_back(newOp->
getResult(0));
571 rewriter.replaceOpWithMultiple(op, {newResults});
// Pattern: distributes xegpu.convert_layout. Two paths:
//  (1) compatible layouts — rewrite to per-subgroup convert_layout ops;
//  (2) incompatible — round-trip through shared local memory (alloca +
//      store_matrix / barrier / load_matrix) to redistribute data.
// NOTE(review): several interior lines (slmShape/slmSize computation, the
// offsets assembly, `newOps`/`finalResults` declarations) are not visible.
602struct WgToSgConvertLayoutOp
603 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
604 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
608 ConversionPatternRewriter &rewriter)
const override {
611 VectorType resultType = op.getResult().getType();
613 auto inputLayout = op.getInputLayout();
614 auto targetLayout = op.getTargetLayout();
616 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
617 !targetLayout.isForWorkgroup())
618 return rewriter.notifyMatchFailure(
619 op,
"Input and target layouts must have subgroup layout");
622 inputLayout.getEffectiveSgLayoutAsInt();
625 targetLayout.getEffectiveSgLayoutAsInt();
// Path (1): layouts agree at sg granularity — drop sg layout/data and
// emit one convert_layout per distributed source fragment.
629 if (inputLayout.isCompatibleWith(targetLayout,
631 inputLayout = inputLayout.dropSgLayoutAndData();
632 targetLayout = targetLayout.dropSgLayoutAndData();
635 if (inputLayout && targetLayout) {
636 for (
auto [i, src] : llvm::enumerate(adaptor.getSource())) {
637 auto newOp = xegpu::ConvertLayoutOp::create(
638 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
642 rewriter.replaceOpWithMultiple(op, {newOps});
// Path (2): redistribute through shared local memory (memory space 3).
647 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
653 auto bytesPerElement = bitWidth / 8;
657 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
658 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
660 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
663 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
665 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
666 rewriter.getIndexType(),
nullptr);
// Each subgroup stores its fragment at coords from the INPUT layout...
669 auto storeCoords = inputLayout.computeDistributedCoords(
670 rewriter, loc, sgId.getResult(), wgShape);
671 if (failed(storeCoords))
675 for (
auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
677 for (
Value coord : coords) {
678 storeMatrixOffsets.push_back(coord);
680 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
681 storeMatrixOffsets,
nullptr );
// ...synchronize, then reload at coords from the TARGET layout.
684 gpu::BarrierOp::create(rewriter, loc);
687 auto loadCoords = targetLayout.computeDistributedCoords(
688 rewriter, loc, sgId.getResult(), wgShape);
689 if (failed(loadCoords))
692 VectorType loadType = VectorType::get(targetSgData, elemTy);
696 for (
auto coords : *loadCoords) {
698 for (
Value coord : coords) {
699 loadMatrixOffsets.push_back(coord);
701 auto loadOp = xegpu::LoadMatrixOp::create(
702 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
703 targetLayout.dropSgLayoutAndData());
705 finalResults.push_back(loadOp.getResult());
708 rewriter.replaceOpWithMultiple(op, {finalResults});
// Pattern: folds away unrealized_conversion_cast ops introduced by the 1:N
// conversion when input and result vector types already agree.
744struct UnrealizedConversionCastOpPattern
745 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
746 using OpConversionPattern<
747 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
750 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
OneToNOpAdaptor adaptor,
751 ConversionPatternRewriter &rewriter)
const override {
// Only handle all-equal vector types on both sides.
754 auto inputTy = dyn_cast<VectorType>(inputs[0].
getType());
755 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
757 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
758 !llvm::all_equal(
ValueRange(inputs).getTypes()))
// N:1 direction: a single operand cast whose types already match.
766 if (op.getNumOperands() == 1 &&
767 llvm::equal(
ValueRange(inputs).getTypes(), op->getResultTypes())) {
768 rewriter.replaceOp(op, inputs);
// 1:N direction: forward all distributed inputs to the single result.
779 if (op.getNumResults() == 1 &&
781 rewriter.replaceOpWithMultiple(op, {inputs});
785 return mlir::failure();
// Pattern: distributes arith.constant dense vectors. Three cases:
//  - splat: re-emit a splat of the subgroup shape;
//  - sg shape == wg shape: keep the constant as-is;
//  - non-splat index vectors with constant row/col strides: materialize a
//    base tile and add a per-subgroup linear offset at runtime.
// NOTE(review): several interior lines (sgAttr/tileAttr construction, the
// stride-constant values, `sgOffsets`) are not visible in this chunk.
790struct WgToSgArithConstantOp :
public OpConversionPattern<arith::ConstantOp> {
791 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
794 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
795 ConversionPatternRewriter &rewriter)
const override {
796 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
797 auto vecType = dyn_cast<VectorType>(op.getType());
798 if (!vecAttr || !vecType)
801 xegpu::DistributeLayoutAttr layout =
803 if (!layout || !layout.isForWorkgroup())
806 ArrayRef<int64_t> wgShape = vecType.getShape();
807 SmallVector<int64_t> sgShape;
809 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
811 auto newType = VectorType::get(sgShape, vecType.getElementType());
812 Location loc = op.getLoc();
813 auto eltType = vecType.getElementType();
// Case 1: splat constant — same value, subgroup shape.
815 if (vecAttr.isSplat()) {
817 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
819 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
820 rewriter.replaceOp(op, cstOp);
822 }
// Case 2: no actual distribution needed.
else if (sgShape == wgShape) {
825 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
826 rewriter.replaceOp(op, newConstOp);
// Case 3: non-splat — only index-typed 1D/2D vectors with uniform strides.
832 if (!eltType.isIndex())
833 return rewriter.notifyMatchFailure(
834 op,
"Unsupported element type for non-splat constant op.");
836 if (wgShape.size() > 2)
837 return rewriter.notifyMatchFailure(
838 op,
"Only 1D & 2D vector constant supported");
840 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
841 int64_t rowStride = 0, colStride = 0;
842 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
843 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
// Derive strides from the first two elements of each axis.
847 colStride = cast<IntegerAttr>(values[1]).getInt() -
848 cast<IntegerAttr>(values[0]).getInt();
851 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
852 cast<IntegerAttr>(values[0]).getInt();
// Verify the strides hold across the entire constant.
855 for (int64_t r = 0; r < rows; ++r) {
856 for (int64_t c = 0; c < cols; ++c) {
857 int64_t idx = r * cols + c;
859 if (c > 0 && cols > 1) {
860 int64_t prevIdx = r * cols + (c - 1);
861 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
862 cast<IntegerAttr>(values[prevIdx]).getInt();
863 if (diff != colStride)
864 return rewriter.notifyMatchFailure(
865 op,
"Non-constant column stride in constant op.");
868 if (r > 0 && rows > 1) {
869 int64_t prevIdx = (r - 1) * cols + c;
870 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
871 cast<IntegerAttr>(values[prevIdx]).getInt();
872 if (diff != rowStride)
873 return rewriter.notifyMatchFailure(
874 op,
"Non-constant row stride in constant op.");
// Extract the top-left base tile of subgroup shape.
882 SmallVector<Attribute> baseTileValues;
883 int baseTileCols = sgShape[sgShape.size() - 1];
884 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
885 for (int64_t r = 0; r < baseTileRows; ++r) {
886 for (int64_t c = 0; c < baseTileCols; ++c) {
887 baseTileValues.push_back(values[r * cols + c]);
893 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
// Runtime: offset the base tile by sgOffset . strides per subgroup.
897 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
899 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
903 SmallVector<Value, 2> strideConsts;
904 strideConsts.push_back(
908 strideConsts.begin(),
911 SmallVector<Value> newConstOps;
912 for (
auto offsets : *sgOffsets) {
915 for (
size_t i = 0; i < strideConsts.size(); ++i) {
917 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
918 offsets[i], strideConsts[i]);
919 mulOffset = arith::AddIOp::create(
920 rewriter, loc, rewriter.getIndexType(), mulOffset,
mul);
923 auto bcastOffset = vector::BroadcastOp::create(
924 rewriter, loc, baseConstVec.getType(), mulOffset);
926 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
927 newConstOps.push_back(finalConst);
929 rewriter.replaceOpWithMultiple(op, {newConstOps});
// Pattern: distributes xegpu.load (gather form with explicit offsets) —
// one load per (offsets, mask) pair; requires offsets/mask already
// distributed to matching shapes.
937struct WgToSgLoadGatherOpWithOffset
938 :
public OpConversionPattern<xegpu::LoadGatherOp> {
939 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
941 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter)
const override {
944 if (!op.getOffsets())
947 Location loc = op.getLoc();
948 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
951 ArrayRef<int64_t> wgShape = resultType.getShape();
953 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
955 if (!layout || !layout.isForWorkgroup())
958 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// Offsets and mask must already be distributed consistently.
961 auto offsetsVecType =
962 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
964 dyn_cast<VectorType>(adaptor.getMask().front().getType());
965 if (!offsetsVecType || !maskVecType ||
966 offsetsVecType.getShape() != maskVecType.getShape()) {
967 return rewriter.notifyMatchFailure(op,
968 "offsets have not been distributed");
971 SmallVector<Value> newLoadOps;
// Default chunk size is 1 when unspecified.
973 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
974 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
975 for (
auto [offsets, mask] :
976 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
977 auto newLayout = layout.dropSgLayoutAndData();
978 auto newLoadOp = xegpu::LoadGatherOp::create(
979 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
980 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
982 newLoadOps.push_back(newLoadOp);
984 rewriter.replaceOpWithMultiple(op, {newLoadOps});
// Pattern: distributes xegpu.store (scatter form with explicit offsets) —
// one store per (value, offsets, mask) triple; original op is erased.
991struct WgToSgStoreScatterOpWithOffset
992 :
public OpConversionPattern<xegpu::StoreScatterOp> {
993 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
995 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
996 ConversionPatternRewriter &rewriter)
const override {
998 if (!op.getOffsets())
1001 Location loc = op.getLoc();
1002 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1006 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1008 if (!layout || !layout.isForWorkgroup())
// Offsets and mask must already be distributed consistently.
1012 auto offsetsVecType =
1013 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1015 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1016 if (!offsetsVecType || !maskVecType ||
1017 offsetsVecType.getShape() != maskVecType.getShape()) {
1018 return rewriter.notifyMatchFailure(op,
1019 "offsets have not been distributed");
// Default chunk size is 1 when unspecified.
1022 auto chunkSizeOpt = op.getChunkSize();
1023 int64_t chunkSize = chunkSizeOpt ?
static_cast<int64_t
>(*chunkSizeOpt) : 1;
1024 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1025 for (
auto [val, offs, mask] : llvm::zip(
1026 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1027 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1028 mask, chunkSizeAttr, op.getL1HintAttr(),
1029 op.getL2HintAttr(), op.getL3HintAttr(),
1030 layout.dropSgLayoutAndData());
1032 rewriter.eraseOp(op);
// Pattern: distributes xegpu.load_matrix — one subgroup-shaped load per
// computed per-subgroup offset vector.
1037struct WgToSgLoadMatrixOp :
public OpConversionPattern<xegpu::LoadMatrixOp> {
1038 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1040 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1041 ConversionPatternRewriter &rewriter)
const override {
1043 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1044 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
1047 ArrayRef<int64_t> wgShape = op.getDataShape();
1048 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1049 assert(valueTy &&
"the value type must be vector type!");
1050 Type elemTy = valueTy.getElementType();
1052 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1053 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1054 VectorType newResTy = VectorType::get(sgShape, elemTy);
1055 SmallVector<Value> newOps;
1056 for (
auto offsets : offsetsList) {
1057 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1058 op.getMemDesc(), offsets,
1059 layout.dropSgLayoutAndData());
1060 newOps.push_back(newOp);
1062 rewriter.replaceOpWithMultiple(op, {newOps});
// Pattern: distributes xegpu.store_matrix — one store per (data fragment,
// per-subgroup offsets) pair; original op is erased.
1068struct WgToSgStoreMatrixOp :
public OpConversionPattern<xegpu::StoreMatrixOp> {
1069 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1071 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter)
const override {
1074 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1075 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
1078 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1079 for (
auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1080 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1081 offsets, layout.dropSgLayoutAndData());
1082 rewriter.eraseOp(op);
// Pattern: distributes vector.step — emits a subgroup-sized step vector and
// adds each subgroup's broadcast start offset to it.
1088struct WgToSgVectorStepOp :
public OpConversionPattern<vector::StepOp> {
1089 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1091 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter)
const override {
1093 xegpu::DistributeLayoutAttr layout =
1095 if (!layout || !layout.isForWorkgroup())
1098 Location loc = op.getLoc();
1099 VectorType type = op.getResult().getType();
1100 auto wgShape = type.getShape();
1101 std::optional<SmallVector<int64_t>> sgShape =
1102 getSgShapeAndCount(wgShape, layout).first;
// Subgroup id (no upper-bound attribute supplied).
1107 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1109 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1113 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1114 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1115 SmallVector<Value> newOps;
1116 for (
auto offsets : *sgOffsets) {
// result = local step vector + broadcast(subgroup start offset).
1119 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1121 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1122 newOps.push_back(finalSteps);
1125 rewriter.replaceOpWithMultiple(op, {newOps});
// Pattern: distributes vector.shape_cast. Unit-dim-expanding casts are only
// supported when all users are broadcasts and the source layout is a slice
// of the result layout; the layout is then adjusted for the unit dims.
// NOTE(review): the conditions guarding the unit-dim path (computation of
// expandedUnitDims, source-layout retrieval) span lines not visible here.
1131struct WgToSgVectorShapeCastOp
1132 :
public OpConversionPattern<vector::ShapeCastOp> {
1133 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1136 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1137 ConversionPatternRewriter &rewriter)
const override {
1139 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1143 ArrayRef<int64_t> wgShape = resultType.getShape();
1144 xegpu::DistributeLayoutAttr layout =
1146 if (!layout || !layout.isForWorkgroup())
1151 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1155 ArrayRef<int64_t> srcShape = srcType.getShape();
1157 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1158 SmallVector<int64_t> expandedUnitDims;
1160 xegpu::DistributeLayoutAttr sourceLayout =
// Unit-dim expansion is only tolerated when consumed by broadcasts.
1163 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1164 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1165 return isa<vector::BroadcastOp>(user);
1169 if (!usedByBroadcastOp(op))
1170 return rewriter.notifyMatchFailure(
1171 op,
"ShapeCast ops that expand unit dimensions and are used by "
1172 "non-broadcast operations are not supported.");
1174 if (!sourceLayout.isSliceOf(layout))
1175 return rewriter.notifyMatchFailure(
1176 op,
"The ShapeCast op only expands dimensions, the result layout "
1177 "must be a slice of the input layout, or vice versa.");
// Pin the expanded unit dims in both sg data and sg layout.
1178 layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
1179 layoutToDistribute =
1180 layoutToDistribute.setUnitDimLayout(expandedUnitDims);
1183 SmallVector<int64_t> sgShape =
1184 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1185 VectorType newResultType =
1186 VectorType::get(sgShape, resultType.getElementType());
1188 SmallVector<Value> newShapeCastOps;
1189 for (
auto src : adaptor.getSource()) {
1190 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1191 newResultType, src);
1192 newShapeCastOps.push_back(newShapeCast.getResult());
1195 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
// Materializes the neutral (identity) element for a vector.multi_reduction
// combining kind: 0 for ADD/XOR/OR, 1 for MUL/AND (visible constant lines
// are elided in this chunk), type extrema for MIN*/MAX*, +/-inf for float
// min/max kinds.
1200static Value createAccumulator(ConversionPatternRewriter &rewriter,
1202 vector::CombiningKind kind) {
1203 Type elemTy = type.getElementType();
1206 case vector::CombiningKind::ADD:
1207 case vector::CombiningKind::XOR:
1208 case vector::CombiningKind::OR:
1209 return arith::ConstantOp::create(
1210 rewriter, loc, type,
1213 case vector::CombiningKind::MUL:
1214 case vector::CombiningKind::AND:
1215 return arith::ConstantOp::create(
1216 rewriter, loc, type,
// Signed min-reduction: neutral element is the signed max.
1219 case vector::CombiningKind::MINSI:
1221 if (
auto intTy = dyn_cast<IntegerType>(elemTy)) {
1222 auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
1223 return arith::ConstantOp::create(
1224 rewriter, loc, type,
1226 rewriter.getIntegerAttr(elemTy, maxVal)));
// Unsigned min-reduction: neutral element is the unsigned max.
1230 case vector::CombiningKind::MINUI:
1231 if (
auto intTy = dyn_cast<IntegerType>(elemTy)) {
1232 auto maxVal = APInt::getMaxValue(intTy.getWidth());
1233 return arith::ConstantOp::create(
1234 rewriter, loc, type,
1236 rewriter.getIntegerAttr(elemTy, maxVal)));
// Signed max-reduction: neutral element is the signed min.
1240 case vector::CombiningKind::MAXSI:
1241 if (
auto intTy = dyn_cast<IntegerType>(elemTy)) {
1242 auto minVal = APInt::getSignedMinValue(intTy.getWidth());
1243 return arith::ConstantOp::create(
1244 rewriter, loc, type,
1246 rewriter.getIntegerAttr(elemTy, minVal)));
// Unsigned max-reduction: neutral element is 0 (constant on elided line).
1250 case vector::CombiningKind::MAXUI:
1251 return arith::ConstantOp::create(
1252 rewriter, loc, type,
// Float min kinds: neutral element is +inf.
1255 case vector::CombiningKind::MINNUMF:
1256 case vector::CombiningKind::MINIMUMF:
1258 if (
auto floatTy = dyn_cast<FloatType>(elemTy)) {
1259 auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
1260 return arith::ConstantOp::create(
1261 rewriter, loc, type,
// Float max kinds: neutral element is -inf (second arg = negative).
1266 case vector::CombiningKind::MAXNUMF:
1267 case vector::CombiningKind::MAXIMUMF:
1269 if (
auto floatTy = dyn_cast<FloatType>(elemTy)) {
1270 auto negInf = APFloat::getInf(floatTy.getFloatSemantics(),
true);
1271 return arith::ConstantOp::create(
1272 rewriter, loc, type,
// Pattern: distributes vector.multi_reduction. Each subgroup reduces its
// local fragment against a neutral accumulator; if any reduction dim spans
// multiple subgroups, partial results are exchanged through shared local
// memory (the tail of this pattern extends beyond this chunk).
1312struct WgToSgMultiDimReductionOp
1313 :
public OpConversionPattern<vector::MultiDimReductionOp> {
1314 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1317 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1318 ConversionPatternRewriter &rewriter)
const override {
1319 Location loc = op.getLoc();
1321 VectorType srcType = op.getSourceVectorType();
1322 VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1326 auto originalSrcShape = srcType.getShape();
1327 auto originalDstShape = dstType.getShape();
1328 int srcVecRank = originalSrcShape.size();
1330 xegpu::DistributeLayoutAttr layout =
1332 if (!layout || !layout.isForWorkgroup())
1335 auto reductionDims = llvm::to_vector(op.getReductionDims());
// The reduction result's layout must be a SliceAttr; read sg layout/data
// from its parent (the pre-reduction layout).
1338 SmallVector<int64_t> sgLayout;
1339 SmallVector<int64_t> sgData;
1340 if (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1341 sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
1342 sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
1344 return rewriter.notifyMatchFailure(
1345 op,
"Reduction should have SliceAttr layout");
1347 Type elemTy = dstType.getElementType();
// Step 1: per-subgroup local reduction with a neutral accumulator.
1350 SmallVector<Value> localReductions;
1351 SmallVector<int64_t> sgDstShape =
1352 getSgShapeAndCount(originalDstShape, layout).first;
1353 auto sgSrcs = adaptor.getSource();
1354 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1355 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1356 sgSrcType.getShape().end());
1358 VectorType newDstType = VectorType::get(sgDstShape, elemTy);
1359 for (
auto sgSrc : sgSrcs) {
1361 auto neutralLocalAcc =
1362 createAccumulator(rewriter, loc, newDstType, op.getKind());
1364 auto localReduce = vector::MultiDimReductionOp::create(
1365 rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
1367 localReductions.push_back(localReduce.getResult());
// Step 2: determine which reduction dims cross subgroup boundaries.
1371 SmallVector<int64_t> crossSgReductionDims;
1372 for (int64_t reductionDim : reductionDims) {
1373 bool needsCrossSubgroupReduction =
1374 (sgLayout[reductionDim] > 1) &&
1375 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1377 if (needsCrossSubgroupReduction) {
1378 crossSgReductionDims.push_back(reductionDim);
// Fast path: everything reduced locally; just fold in the original acc.
1383 if (crossSgReductionDims.empty()) {
1384 SmallVector<Value> results;
1385 for (
auto localResult : localReductions) {
1387 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1388 results.push_back(finalResult);
1390 rewriter.replaceOpWithMultiple(op, {results});
// Slow path: exchange partial results through shared local memory.
// Reduced dims collapse to extent 1 in the stored fragment ...
1395 auto slmStoreDataShape = sgSrcShape;
1396 for (int64_t dim : reductionDims)
1397 slmStoreDataShape[dim] = 1;
1398 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1399 Value slmStoreData = vector::ShapeCastOp::create(
1400 rewriter, loc, slmStoreDataType, localReductions[0]);
// ... and to one slot per subgroup in the SLM buffer shape.
1402 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1403 originalSrcShape.end());
1405 for (int64_t dim : reductionDims)
1406 slmShape[dim] = sgLayout[dim];
1410 auto bytesPerElement = bitWidth / 8;
1412 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1413 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1415 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1418 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
// Current implementation only handles a single local fragment.
1421 if (localReductions.size() > 1) {
1422 return rewriter.notifyMatchFailure(
1424 "Multiple local reductions not supported in current implementation.");
1428 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1429 rewriter.getIndexType(),
nullptr);
1432 SmallVector<Value> sgLayoutValues;
1433 for (int64_t dim : sgLayout)
1434 sgLayoutValues.push_back(
1437 auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
1441 SmallVector<Value> sgIds = *sgIdsResult;
1443 auto getSlmOffsets = [&](int64_t reductionDimStride) {
1444 SmallVector<OpFoldResult> offsets;
1445 offsets.reserve(srcVecRank);
1446 for (
int i = 0; i < srcVecRank; ++i) {
1447 Value dimVal = sgIds[i];
1448 int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
1449 ? reductionDimStride
1454 arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
1455 offsets.push_back(offsetVal);
1460 SmallVector<OpFoldResult> slmStoreOffsets =
1463 xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
1464 memDesc.getResult(), slmStoreOffsets,
1467 gpu::BarrierOp::create(rewriter, loc);
1470 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1471 for (int64_t dim : reductionDims)
1472 slmLoadDataShape[dim] = slmShape[dim];
1474 SmallVector<OpFoldResult> slmLoadOffsets =
1477 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1478 auto slmLoadOp = xegpu::LoadMatrixOp::create(
1479 rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
1483 auto neutralFinalAcc =
1484 createAccumulator(rewriter, loc, newDstType, op.getKind());
1486 auto finalReduce = vector::MultiDimReductionOp::create(
1487 rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(),
1488 neutralFinalAcc, reductionDims);
1492 finalReduce.getResult(),
1493 adaptor.getAcc()[0]);
1495 rewriter.replaceOp(op, finalResult);
1501struct WgToSgVectorTransposeOp
1502 :
public OpConversionPattern<vector::TransposeOp> {
1503 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1506 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1507 ConversionPatternRewriter &rewriter)
const override {
1508 VectorType resultType = op.getResultVectorType();
1510 ArrayRef<int64_t> wgShape = resultType.getShape();
1511 xegpu::DistributeLayoutAttr layout =
1513 if (!layout || !layout.isForWorkgroup())
1516 xegpu::DistributeLayoutAttr sourceLayout =
1518 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1521 SmallVector<int64_t> sourceSgLayout =
1522 sourceLayout.getEffectiveSgLayoutAsInt();
1523 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1525 ArrayRef<int64_t> permutation = op.getPermutation();
1526 size_t permutationSize = permutation.size();
1527 if (sourceSgLayout.size() != permutationSize ||
1528 resultSgLayout.size() != permutationSize) {
1529 return rewriter.notifyMatchFailure(
1530 op,
"Layouts and permutation must have the same rank");
1535 if (!layout.isTransposeOf(sourceLayout, permutation,
1536 xegpu::LayoutKind::Subgroup))
1537 return rewriter.notifyMatchFailure(
1538 op,
"Result layout is not a valid transpose of source layout "
1539 "according to permutation");
1541 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1542 VectorType newResultType =
1543 VectorType::get(sgShape, resultType.getElementType());
1545 SmallVector<Value> newTransposeOps;
1546 for (
auto src : adaptor.getVector()) {
1547 auto newTranspose = vector::TransposeOp::create(
1548 rewriter, op.getLoc(), newResultType, src, permutation);
1549 newTransposeOps.push_back(newTranspose.getResult());
1551 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1557template <
typename MaskOpType>
1558struct WgToSgVectorMaskOp :
public OpConversionPattern<MaskOpType> {
1559 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1561 LogicalResult matchAndRewrite(
1563 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1564 ConversionPatternRewriter &rewriter)
const override {
1565 xegpu::DistributeLayoutAttr layout =
1567 if (!layout || !layout.isForWorkgroup())
1570 Location loc = op.getLoc();
1571 VectorType type = op.getResult().getType();
1572 auto wgShape = type.getShape();
1574 SmallVector<Value> wgMaskDimSizes;
1575 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1576 for (int64_t maskSize : op.getMaskDimSizes()) {
1577 wgMaskDimSizes.push_back(
1580 }
else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1581 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1585 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1587 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1591 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1592 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1596 SmallVector<Value> newCreateMaskOps;
1597 for (
auto offsetSet : *sgOffsets) {
1598 SmallVector<Value> maskOperands;
1600 for (
auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1603 Value offset = offsetSet[i];
1604 Value adjustedMaskSize =
1605 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1608 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1610 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1611 maskOperands.push_back(sgMaskSize);
1614 auto newCreateMaskOp =
1615 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1616 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1619 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1624using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1625using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1632 .
add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1633 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1634 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1635 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1636 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1637 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1638 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1639 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1640 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1641 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1648struct XeGPUWgToSgDistributePass
1650 void runOnOperation()
override;
1654void XeGPUWgToSgDistributePass::runOnOperation() {
1656 Operation *op = getOperation();
1658 signalPassFailure();
1663 SmallVector<Operation *> existingCastOps;
1664 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1665 existingCastOps.push_back(castOp.getOperation());
1675 TypeConverter converter;
1676 converter.addConversion([&](Type type) -> Type {
return type; });
1677 converter.addConversion(
1678 [&](RankedTensorType type,
1679 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1684 dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
1686 return std::nullopt;
1688 Type elemTy = type.getElementType();
1689 ArrayRef<int64_t> shape = type.getShape();
1692 SmallVector<int64_t> subShape;
1693 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1695 auto newTy = VectorType::get(subShape, elemTy);
1696 result.append(count, newTy);
1707 RewritePatternSet patterns(ctx);
1708 ConversionTarget
target(*ctx);
1709 TypeConverter converter;
1710 converter.addConversion([&](Type type) -> Type {
return type; });
1711 converter.addConversion(
1712 [&](xegpu::TensorDescType type,
1713 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1714 Type elemTy = type.getElementType();
1715 ArrayRef<int64_t> shape = type.getShape();
1718 SmallVector<int64_t> subShape;
1719 xegpu::LayoutAttr layout = type.getLayoutAttr();
1720 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1723 layout = layout.dropSgLayoutAndData();
1725 auto newTy = xegpu::TensorDescType::get(
1726 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1727 result.append(count, newTy);
1731 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1732 if (
auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1733 return createOp.getType();
1734 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1735 return loadOp.getTensorDescType();
1736 if (
auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1737 return storeOp.getTensorDescType();
1738 if (
auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1740 if (
auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1741 return prefetchOp.getTensorDescType();
1742 return xegpu::TensorDescType();
1745 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) ->
bool {
1746 return !layout || !layout.isForWorkgroup();
1749 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1750 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1751 xegpu::PrefetchNdOp>([=](Operation *op) ->
bool {
1752 auto tdescTy = getTensorDescType(op);
1753 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1754 return isLegal(layout);
1757 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) ->
bool {
1758 auto layout = op.getLayoutCdAttr();
1759 return isLegal(layout);
1762 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1763 [=](xegpu::LoadMatrixOp op) ->
bool {
1764 return isLegal(op.getLayoutAttr());
1767 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1768 [=](xegpu::StoreMatrixOp op) ->
bool {
1769 return isLegal(op.getLayoutAttr());
1772 target.addDynamicallyLegalOp<arith::ConstantOp>(
1773 [=](arith::ConstantOp op) ->
bool {
1774 auto vecType = dyn_cast<VectorType>(op.getType());
1780 return isLegal(layout);
1783 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1784 vector::TransposeOp, vector::BroadcastOp,
1785 vector::MultiDimReductionOp,
1786 vector::ConstantMaskOp, vector::CreateMaskOp>(
1787 [=](Operation *op) ->
bool {
1791 return isLegal(layout);
1794 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1795 [=](xegpu::LoadGatherOp op) ->
bool {
1796 auto layout = op.getLayoutAttr();
1797 return isLegal(layout);
1800 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1801 [=](xegpu::StoreScatterOp op) ->
bool {
1802 auto layout = op.getLayoutAttr();
1803 return isLegal(layout);
1806 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1807 [=](xegpu::ConvertLayoutOp op) ->
bool {
1808 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1811 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1812 [=](Operation *op) -> std::optional<bool> {
1817 VectorType resultType =
1825 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1826 if (!operandType || operandType.getShape() != resultType.getShape()) {
1831 xegpu::DistributeLayoutAttr layout =
1833 return isLegal(layout);
1836 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1837 [=](UnrealizedConversionCastOp op) {
1838 return llvm::is_contained(existingCastOps, op.getOperation());
1841 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
1847 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1848 return signalPassFailure();
1851 getOperation()->walk([](Operation *op) {
1852 if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
1855 SmallVector<StringAttr> attrsToRemove;
1857 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1858 attrsToRemove.push_back(namedAttr.getName());
1860 for (
auto attrName : attrsToRemove)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)