29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
/// Looks for an `xegpu::RangeAttr` stored under the attribute name
/// "sg_id_range". The attribute is read from `parent` (not from `op`
/// directly) and matched with `dyn_cast_if_present`, so a missing attribute
/// or one of a different kind does not satisfy the condition.
39static xegpu::RangeAttr getRangeSpecAttr(
Operation *op) {
42 if (
auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->
getAttr(
"sg_id_range")))
/// Returns a (shape, count) pair for a value distributed according to
/// `layout`.
///  - When `layout.computeDistributedShape(...)` fails, the pair is built
///    from `sgShape` and `count` as they stand at that point.
///  - Otherwise the shape component is `layout.getEffectiveSgDataAsInt()`.
50static std::pair<SmallVector<int64_t>,
int>
52 xegpu::DistributeLayoutAttr layout) {
55 auto distributedShape = layout.computeDistributedShape(
57 if (
failed(distributedShape))
58 return std::make_pair(sgShape, count);
59 auto sgData = layout.getEffectiveSgDataAsInt();
61 return std::make_pair(sgData, count);
/// Fills `offsetsList` with one offset vector per entry returned by
/// `layout.computeDistributedCoords(...)` for the current subgroup id.
///
/// The template is restricted (via `std::enable_if_t`) to LoadNdOp,
/// StoreNdOp, PrefetchNdOp, LoadMatrixOp and StoreMatrixOp.
///
/// When an "sg_id_range" is consulted, the number of subgroups in the layout
/// must equal `endOfRange - startOfRange`; otherwise a match failure is
/// reported. A positive range start is subtracted from the subgroup id before
/// the coordinates are computed.
68template <
typename OpType,
69 typename = std::enable_if_t<llvm::is_one_of<
70 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
71 xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
73genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
78 if (origOffsets.empty())
82 xegpu::DistributeLayoutAttr layout;
// The matrix ops expose the layout via getLayoutAttr(); the other op kinds
// use getDescLayoutAttr().
83 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
84 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
85 layout = op.getLayoutAttr();
87 layout = op.getDescLayoutAttr();
91 if (!layout || !layout.isForWorkgroup())
95 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
98 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
100 int64_t startOfRange = sgIdRange.getStart().getInt();
101 int64_t endOfRange = sgIdRange.getEnd().getInt();
// The range length has to agree with the number of subgroups in the layout.
103 if (layout.getNumSubgroups() != endOfRange - startOfRange)
104 return rewriter.notifyMatchFailure(
105 op,
"sg_layout size must match the sg_id_range");
// Rebase the subgroup id so that it is relative to the start of the range.
107 if (startOfRange > 0) {
108 Value startOfRangeVal =
110 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
117 auto maybeDescOffsets =
118 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
119 if (
failed(maybeDescOffsets))
124 for (
const auto &sgOffsets : *maybeDescOffsets) {
127 offsetsList.push_back(std::move(newOffsets));
/// Conversion pattern for `xegpu::CreateNdDescOp`.
///
/// Requires the tensor descriptor layout to be an `xegpu::LayoutAttr` that
/// `isForWorkgroup()`. Builds a new tensor descriptor type from the shape
/// returned by `getSgShapeAndCount` and the layout with its sg layout/data
/// dropped, creates `count` new CreateNdDescOps that reuse the original
/// source, sizes and strides, and replaces the op with those results.
181struct WgToSgCreateNdOp :
public OpConversionPattern<xegpu::CreateNdDescOp> {
182 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
185 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter)
const override {
188 Location loc = op.getLoc();
190 xegpu::TensorDescType tdescTy = op.getType();
191 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
192 if (!layout || !layout.isForWorkgroup())
195 Type elemTy = tdescTy.getElementType();
196 ArrayRef<int64_t> wgShape = tdescTy.getShape();
198 SmallVector<int64_t> sgShape;
200 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
// Same element type and encoding as the original descriptor; only the shape
// and the layout change.
201 xegpu::TensorDescType newTdescTy =
202 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
203 layout.dropSgLayoutAndData());
// One new descriptor op per distribution; every one uses the same operands.
205 SmallVector<Value> newCreateNdOps(count);
206 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
207 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
208 op.getSource(), op.getMixedSizes(),
209 op.getMixedStrides());
212 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
/// Conversion pattern for `xegpu::LoadNdOp`.
///
/// Offsets are produced by `genOffsetsList`. For every (converted tensor
/// descriptor, offsets) pair a new LoadNdOp is created whose result vector
/// type takes its shape and element type from the converted descriptor. The
/// cache hints are copied from the original op and the layout has its sg
/// layout/data dropped. The original op is replaced by all new results.
218struct WgToSgLoadNdOp :
public OpConversionPattern<xegpu::LoadNdOp> {
219 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
221 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter)
const override {
224 SmallVector<SmallVector<OpFoldResult>> offsetsList;
225 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
228 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
230 layout = layout.dropSgLayoutAndData();
231 SmallVector<Value> newOps;
232 for (
auto [tdesc, offsets] :
233 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
// The result type mirrors the shape of the converted tensor descriptor.
234 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
235 VectorType newResTy =
236 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
237 auto newOp = xegpu::LoadNdOp::create(
238 rewriter, op.getLoc(), newResTy, tdesc, offsets,
239 nullptr,
nullptr, op.getL1HintAttr(),
240 op.getL2HintAttr(), op.getL3HintAttr(), layout);
241 newOps.push_back(newOp);
243 rewriter.replaceOpWithMultiple(op, {newOps});
/// Conversion pattern for `xegpu::StoreNdOp`.
///
/// Offsets are produced by `genOffsetsList`. One new StoreNdOp is created per
/// (converted value, converted tensor descriptor, offsets) triple, keeping
/// the cache hints of the original op and using the layout with its sg
/// layout/data dropped. The original op is erased afterwards.
250struct WgToSgStoreNdOp :
public OpConversionPattern<xegpu::StoreNdOp> {
251 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
253 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
255 SmallVector<SmallVector<OpFoldResult>> offsetsList;
256 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
259 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
261 layout = layout.dropSgLayoutAndData();
262 for (
auto [v, tdesc, offsets] :
263 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
264 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
265 op.getL1HintAttr(), op.getL2HintAttr(),
266 op.getL3HintAttr(), layout);
// The store has no results, so the original op is simply erased.
268 rewriter.eraseOp(op);
/// Conversion pattern for `xegpu::PrefetchNdOp`.
///
/// Offsets are produced by `genOffsetsList`. One new PrefetchNdOp is created
/// per (converted tensor descriptor, offsets) pair, keeping the cache hints
/// of the original op and using the layout with its sg layout/data dropped.
/// The original op is erased afterwards.
275struct WgToSgPrefetchNdOp :
public OpConversionPattern<xegpu::PrefetchNdOp> {
276 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
278 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter)
const override {
280 SmallVector<SmallVector<OpFoldResult>> offsetsList;
281 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
284 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
286 layout = layout.dropSgLayoutAndData();
287 for (
auto [tdesc, offsets] :
288 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
289 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
290 op.getL1HintAttr(), op.getL2HintAttr(),
291 op.getL3HintAttr(), layout);
// The prefetch has no results, so the original op is simply erased.
293 rewriter.eraseOp(op);
/// Conversion pattern for `xegpu::DpasOp`.
///
/// Checks that the result vector has rank 2 and that the A, B and C/D layout
/// attributes are all present. A new DpasOp is created for every combination
/// of a converted LHS value and a converted RHS value; accumulator values are
/// consumed from `adaptor.getAcc()` with a running index. Each new op gets
/// the three layouts with sg layout/data dropped. The original op is replaced
/// by all new results.
300struct WgToSgDpasOp :
public OpConversionPattern<xegpu::DpasOp> {
301 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
303 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter)
const override {
305 Location loc = op.getLoc();
306 VectorType resultTy = op.getResult().getType();
307 if (resultTy.getRank() != 2)
310 auto layoutCd = op.getLayoutCdAttr();
311 auto layoutA = op.getLayoutAAttr();
312 auto layoutB = op.getLayoutBAttr();
313 if (!layoutCd || !layoutA || !layoutB)
316 SmallVector<Value> newDpasOps;
317 for (
auto aVec : adaptor.getLhs()) {
318 for (
auto bVec : adaptor.getRhs()) {
320 llvm::SmallVector<Value> operands({aVec, bVec});
323 tmpC = adaptor.getAcc()[i++];
324 operands.push_back(tmpC);
// Result shape is (dim 0 of the A piece) x (dim 1 of the B piece).
327 ArrayRef<int64_t> aVecShape =
328 cast<VectorType>(aVec.getType()).getShape();
329 ArrayRef<int64_t> bVecShape =
330 cast<VectorType>(bVec.getType()).getShape();
331 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
332 resultTy.getElementType());
333 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
334 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
335 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
336 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
338 newDpasOps.push_back(newDpasOp);
341 rewriter.replaceOpWithMultiple(op, {newDpasOps});
/// Conversion pattern for `xegpu::DpasMxOp`.
///
/// Checks that the result vector has rank 2 and that the A, B, C/D, A-scale
/// and B-scale layout attributes are all present. A new DpasMxOp is created
/// for every combination of a converted A value and a converted B value.
/// The accumulator and the scales are optional: a null `Value()` is passed
/// when the original op does not have them. Scale A is indexed by the A
/// position, scale B by the B position, and the accumulator by a running
/// index. All five layouts are passed with sg layout/data dropped.
347struct WgToSgDpasMxOp :
public OpConversionPattern<xegpu::DpasMxOp> {
348 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
350 matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
351 ConversionPatternRewriter &rewriter)
const override {
353 Location loc = op.getLoc();
354 VectorType resultTy = op.getResult().getType();
356 if (resultTy.getRank() != 2)
359 auto layoutCd = op.getLayoutCdAttr();
360 auto layoutA = op.getLayoutAAttr();
361 auto layoutB = op.getLayoutBAttr();
362 auto layoutAScale = op.getLayoutAScaleAttr();
363 auto layoutBScale = op.getLayoutBScaleAttr();
365 if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
369 SmallVector<Value> newDpasMxOps;
370 for (
auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
371 for (
auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
372 Value accVal = (op.getAcc()) ? adaptor.getAcc()[index_c++] : Value();
374 (op.getScaleA()) ? adaptor.getScaleA()[index_a] : Value();
376 (op.getScaleB()) ? adaptor.getScaleB()[index_b] : Value();
// Result shape is (dim 0 of the A piece) x (dim 1 of the B piece).
378 ArrayRef<int64_t> aVecShape =
379 cast<VectorType>(aVec.getType()).getShape();
380 ArrayRef<int64_t> bVecShape =
381 cast<VectorType>(bVec.getType()).getShape();
382 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
383 resultTy.getElementType());
384 auto newDpasMxOp = xegpu::DpasMxOp::create(
385 rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
386 layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
387 layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
388 layoutBScale.dropSgLayoutAndData());
390 newDpasMxOps.push_back(newDpasMxOp);
393 rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
/// Conversion pattern for `vector::BroadcastOp`.
///
/// Requires a layout that `isForWorkgroup()`. The new result type uses the
/// shape returned by `getSgShapeAndCount` and the original element type.
/// The converted source values are broadcast `count / distSource.size()`
/// times each round, so the total number of new ops is that quotient times
/// the number of converted source values.
399struct WgToSgVectorBroadcastOp
400 :
public OpConversionPattern<vector::BroadcastOp> {
401 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
404 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter)
const override {
407 VectorType resultType = op.getResult().getType();
408 ArrayRef<int64_t> wgShape = resultType.getShape();
410 xegpu::DistributeLayoutAttr layout =
412 if (!layout || !layout.isForWorkgroup())
415 SmallVector<int64_t> sgShape;
417 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
418 VectorType newResultType =
419 VectorType::get(sgShape, resultType.getElementType());
421 SmallVector<Value> newBroadcastOps;
422 auto distSource = adaptor.getOperands().front();
// Integer division: how many times the full set of source pieces is repeated.
423 int numDistributions = count / distSource.size();
424 for (
int i = 0; i < numDistributions; ++i) {
425 for (
auto operand : distSource) {
426 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
427 newResultType, operand);
429 newBroadcastOps.push_back(newBroadcast.getResult());
432 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
// Constructs the pattern on top of `ConversionPattern` using
// `MatchAnyOpTypeTag()`, i.e. without binding it to one specific op type.
// The literal `1` is forwarded as the second constructor argument.
439 WgToSgElementwiseOp(MLIRContext *ctx)
440 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
// Rewrites a generic operation with a vector result.
//
// Requires a layout that `isForWorkgroup()`. Every operand range must hold
// the same number of converted values (`numVariants`, taken from the first
// operand range, or 0 when there are no operands). For each variant index a
// new operation is built from an `OperationState` that receives the i-th
// value of every operand range, the new result type, and all attributes of
// the original op. The original op is replaced by the first result of every
// new operation.
443 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
444 ConversionPatternRewriter &rewriter)
const override {
450 assert(resultType &&
"Expected result to be a VectorType");
452 ArrayRef<int64_t> wgShape = resultType.getShape();
454 xegpu::DistributeLayoutAttr layout =
456 if (!layout || !layout.isForWorkgroup())
459 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
461 size_t numVariants = operands.empty() ? 0 : operands.front().size();
// Detect operand ranges whose size differs from the first one.
463 if (llvm::any_of(operands, [&](
const ValueRange &operandVec) {
464 return operandVec.size() != numVariants;
468 SmallVector<Value> newResults;
469 VectorType newResultType =
470 VectorType::get(sgShape, resultType.getElementType());
472 for (
size_t i = 0; i < numVariants; ++i) {
// Gather the i-th converted value of every operand.
473 SmallVector<Value> opOperands;
474 for (
auto &operandVec : operands)
475 opOperands.push_back(operandVec[i]);
478 state.addOperands(opOperands);
479 state.addTypes(newResultType);
480 state.addAttributes(op->
getAttrs());
481 Operation *newOp = rewriter.create(state);
483 newResults.push_back(newOp->
getResult(0));
486 rewriter.replaceOpWithMultiple(op, {newResults});
/// Conversion pattern for `xegpu::ConvertLayoutOp`.
///
/// Both the input and the target layout must be present and
/// `isForWorkgroup()`, otherwise a match failure is reported.
///
/// Visible handling:
///  - When `inputLayout.isCompatibleWith(targetLayout, ...)` holds, both
///    layouts have their sg layout/data dropped; if both are still non-null,
///    one new ConvertLayoutOp is created per converted source value.
///  - Otherwise the data goes through an `i8` memref in memory space 3 that
///    is allocated with `memref::AllocaOp` and wrapped in a memory
///    descriptor: every converted source value is stored at the coordinates
///    computed from the input layout, a `gpu::BarrierOp` is emitted, and the
///    results are loaded back at the coordinates computed from the target
///    layout.
517struct WgToSgConvertLayoutOp
518 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
519 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
522 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
523 ConversionPatternRewriter &rewriter)
const override {
525 auto inputLayout = op.getInputLayout();
526 auto targetLayout = op.getTargetLayout();
528 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
529 !targetLayout.isForWorkgroup())
530 return rewriter.notifyMatchFailure(
531 op,
"Input and target layouts must have subgroup layout");
533 Type resultType = op.getResult().getType();
535 rewriter.replaceOp(op, op.getSource());
536 assert(!inputLayout.dropSgLayoutAndData() &&
537 !targetLayout.dropSgLayoutAndData() &&
538 "unexpected layout attributes for scalar type");
544 inputLayout.getEffectiveSgLayoutAsInt();
547 targetLayout.getEffectiveSgLayoutAsInt();
552 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
554 inputLayout = inputLayout.dropSgLayoutAndData();
555 targetLayout = targetLayout.dropSgLayoutAndData();
// A ConvertLayoutOp is only re-created when both reduced layouts are non-null.
558 if (inputLayout && targetLayout) {
559 for (
auto [i, src] : llvm::enumerate(adaptor.getSource())) {
560 auto newOp = xegpu::ConvertLayoutOp::create(
561 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
565 rewriter.replaceOpWithMultiple(op, {newOps});
570 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
576 auto bytesPerElement = bitWidth / 8;
// Byte-typed buffer (i8 elements) in memory space 3, sized by `slmSize`.
580 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
581 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
583 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
586 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
588 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
589 rewriter.getIndexType(),
nullptr);
// Store side: coordinates follow the input layout.
592 auto storeCoords = inputLayout.computeDistributedCoords(
593 rewriter, loc, sgId.getResult(), wgShape);
594 if (failed(storeCoords))
598 for (
auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
600 for (
Value coord : coords) {
601 storeMatrixOffsets.push_back(coord);
603 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
604 storeMatrixOffsets,
nullptr );
// Barrier between the stores above and the loads below.
607 gpu::BarrierOp::create(rewriter, loc);
// Load side: coordinates follow the target layout.
610 auto loadCoords = targetLayout.computeDistributedCoords(
611 rewriter, loc, sgId.getResult(), wgShape);
612 if (failed(loadCoords))
615 VectorType loadType = VectorType::get(targetSgData, elemTy);
619 for (
auto coords : *loadCoords) {
621 for (
Value coord : coords) {
622 loadMatrixOffsets.push_back(coord);
624 auto loadOp = xegpu::LoadMatrixOp::create(
625 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
626 targetLayout.dropSgLayoutAndData());
628 finalResults.push_back(loadOp.getResult());
631 rewriter.replaceOpWithMultiple(op, {finalResults});
/// Conversion pattern for `mlir::UnrealizedConversionCastOp`.
///
/// The first input and the first result must both be vector types, and all
/// result types as well as all input types must be equal among themselves.
/// Visible rewrites:
///  - a single-operand cast whose input types equal the result types is
///    replaced directly by its inputs;
///  - a single-result cast (under an additional condition) is replaced via
///    `replaceOpWithMultiple` with the inputs.
/// Anything else ends in `mlir::failure()`.
667struct UnrealizedConversionCastOpPattern
668 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
669 using OpConversionPattern<
670 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
673 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
674 ConversionPatternRewriter &rewriter)
const override {
677 auto inputTy = dyn_cast<VectorType>(inputs[0].
getType());
678 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
680 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
681 !llvm::all_equal(
ValueRange(inputs).getTypes()))
689 if (op.getNumOperands() == 1 &&
690 llvm::equal(
ValueRange(inputs).getTypes(), op->getResultTypes())) {
691 rewriter.replaceOp(op, inputs);
702 if (op.getNumResults() == 1 &&
704 rewriter.replaceOpWithMultiple(op, {inputs});
708 return mlir::failure();
/// Conversion pattern for `arith::ConstantOp` with a vector result.
///
/// Requires a `DenseElementsAttr` value, a vector result type and a layout
/// that `isForWorkgroup()`. Visible cases:
///  - Splat constant: `count` new constants of the new vector type are
///    created.
///  - `sgShape == wgShape`: a constant with the original type and value is
///    re-created and replaces the op.
///  - Other constants: only index element types and at most two dimensions
///    are accepted. The values must have one constant column stride and one
///    constant row stride, otherwise a match failure is reported. A base
///    tile is cut from the top-left `sgShape` corner of the values; for each
///    set of distributed coordinates the coordinates are multiplied by the
///    strides, summed, broadcast to the tile type and added to the base
///    tile constant.
713struct WgToSgArithConstantOp :
public OpConversionPattern<arith::ConstantOp> {
714 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
717 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
718 ConversionPatternRewriter &rewriter)
const override {
719 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
720 auto vecType = dyn_cast<VectorType>(op.getType());
721 if (!vecAttr || !vecType)
724 xegpu::DistributeLayoutAttr layout =
726 if (!layout || !layout.isForWorkgroup())
729 ArrayRef<int64_t> wgShape = vecType.getShape();
730 SmallVector<int64_t> sgShape;
732 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
734 auto newType = VectorType::get(sgShape, vecType.getElementType());
735 Location loc = op.getLoc();
736 auto eltType = vecType.getElementType();
// Splat case: the same attribute is used for every new constant.
738 if (vecAttr.isSplat()) {
740 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
742 SmallVector<Value> newConstOps;
743 for (
int i = 0; i < count; ++i) {
744 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
745 newConstOps.push_back(cstOp);
747 rewriter.replaceOpWithMultiple(op, {newConstOps});
749 }
else if (sgShape == wgShape) {
752 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
753 rewriter.replaceOp(op, newConstOp);
759 if (!eltType.isIndex())
760 return rewriter.notifyMatchFailure(
761 op,
"Unsupported element type for non-splat constant op.");
763 if (wgShape.size() > 2)
764 return rewriter.notifyMatchFailure(
765 op,
"Only 1D & 2D vector constant supported");
767 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
768 int64_t rowStride = 0, colStride = 0;
// A 1-D vector is treated as a single row.
769 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
770 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
774 colStride = cast<IntegerAttr>(values[1]).getInt() -
775 cast<IntegerAttr>(values[0]).getInt();
778 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
779 cast<IntegerAttr>(values[0]).getInt();
// Verify that neighbouring elements differ by the same stride everywhere.
782 for (int64_t r = 0; r < rows; ++r) {
783 for (int64_t c = 0; c < cols; ++c) {
784 int64_t idx = r * cols + c;
786 if (c > 0 && cols > 1) {
787 int64_t prevIdx = r * cols + (c - 1);
788 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
789 cast<IntegerAttr>(values[prevIdx]).getInt();
790 if (diff != colStride)
791 return rewriter.notifyMatchFailure(
792 op,
"Non-constant column stride in constant op.");
795 if (r > 0 && rows > 1) {
796 int64_t prevIdx = (r - 1) * cols + c;
797 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
798 cast<IntegerAttr>(values[prevIdx]).getInt();
799 if (diff != rowStride)
800 return rewriter.notifyMatchFailure(
801 op,
"Non-constant row stride in constant op.");
// Base tile: the first baseTileRows x baseTileCols values, read with the
// row pitch (`cols`) of the original constant.
809 SmallVector<Attribute> baseTileValues;
810 int baseTileCols = sgShape[sgShape.size() - 1];
811 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
812 for (int64_t r = 0; r < baseTileRows; ++r) {
813 for (int64_t c = 0; c < baseTileCols; ++c) {
814 baseTileValues.push_back(values[r * cols + c]);
820 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
824 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
826 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
830 SmallVector<Value, 2> strideConsts;
831 strideConsts.push_back(
835 strideConsts.begin(),
838 SmallVector<Value> newConstOps;
839 for (
auto offsets : *sgOffsets) {
842 for (
size_t i = 0; i < strideConsts.size(); ++i) {
844 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
845 offsets[i], strideConsts[i]);
846 mulOffset = arith::AddIOp::create(
847 rewriter, loc, rewriter.getIndexType(), mulOffset,
mul);
// Broadcast the accumulated scalar offset and add it to the base tile.
850 auto bcastOffset = vector::BroadcastOp::create(
851 rewriter, loc, baseConstVec.getType(), mulOffset);
853 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
854 newConstOps.push_back(finalConst);
856 rewriter.replaceOpWithMultiple(op, {newConstOps});
/// Conversion pattern for `xegpu::LoadGatherOp`.
///
/// Requires a layout that `isForWorkgroup()`. The first converted offsets
/// value and the first converted mask value must both be vectors of the same
/// shape; otherwise a match failure is reported. One new LoadGatherOp is
/// created per (offsets, mask) pair, using the original source, the chunk
/// size (1 when the original op has none) and the original cache hints. The
/// original op is replaced by all new results.
864struct WgToSgLoadGatherOp :
public OpConversionPattern<xegpu::LoadGatherOp> {
865 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
867 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
868 ConversionPatternRewriter &rewriter)
const override {
870 Location loc = op.getLoc();
871 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
874 ArrayRef<int64_t> wgShape = resultType.getShape();
876 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
878 if (!layout || !layout.isForWorkgroup())
881 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
884 auto offsetsVecType =
885 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
887 dyn_cast<VectorType>(adaptor.getMask().front().getType());
888 if (!offsetsVecType || !maskVecType ||
889 offsetsVecType.getShape() != maskVecType.getShape()) {
890 return rewriter.notifyMatchFailure(op,
891 "offsets have not been distributed");
894 SmallVector<Value> newLoadOps;
896 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
897 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
898 for (
auto [offsets, mask] :
899 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
900 auto newLayout = layout.dropSgLayoutAndData();
901 auto newLoadOp = xegpu::LoadGatherOp::create(
902 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
903 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
905 newLoadOps.push_back(newLoadOp);
907 rewriter.replaceOpWithMultiple(op, {newLoadOps});
/// Conversion pattern for `xegpu::StoreScatterOp`.
///
/// Requires a layout that `isForWorkgroup()`. The first converted offsets
/// value and the first converted mask value must both be vectors of the same
/// shape; otherwise a match failure is reported. One new StoreScatterOp is
/// created per (value, offsets, mask) triple, writing to the original
/// destination with the chunk size (1 when the original op has none), the
/// original cache hints and the layout with sg layout/data dropped. The
/// original op is erased afterwards.
914struct WgToSgStoreScatterOp
915 :
public OpConversionPattern<xegpu::StoreScatterOp> {
916 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
918 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter)
const override {
921 Location loc = op.getLoc();
922 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
926 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
928 if (!layout || !layout.isForWorkgroup())
932 auto offsetsVecType =
933 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
935 dyn_cast<VectorType>(adaptor.getMask().front().getType());
936 if (!offsetsVecType || !maskVecType ||
937 offsetsVecType.getShape() != maskVecType.getShape()) {
938 return rewriter.notifyMatchFailure(op,
939 "offsets have not been distributed");
// Default the chunk size to 1 when the op does not specify one.
942 auto chunkSizeOpt = op.getChunkSize();
943 int64_t chunkSize = chunkSizeOpt ?
static_cast<int64_t
>(*chunkSizeOpt) : 1;
944 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
945 for (
auto [val, offs, mask] : llvm::zip(
946 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
947 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
948 mask, chunkSizeAttr, op.getL1HintAttr(),
949 op.getL2HintAttr(), op.getL3HintAttr(),
950 layout.dropSgLayoutAndData());
952 rewriter.eraseOp(op);
/// Conversion pattern for `xegpu::LoadMatrixOp`.
///
/// Offsets are produced by `genOffsetsList`. The new result type combines the
/// shape returned by `getSgShapeAndCount` for the op's data shape with the
/// original element type. One new LoadMatrixOp is created per offset vector,
/// reading from the original memory descriptor with the layout's sg
/// layout/data dropped. The original op is replaced by all new results.
957struct WgToSgLoadMatrixOp :
public OpConversionPattern<xegpu::LoadMatrixOp> {
958 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
960 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter)
const override {
963 SmallVector<SmallVector<OpFoldResult>> offsetsList;
964 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
967 ArrayRef<int64_t> wgShape = op.getDataShape();
968 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
969 assert(valueTy &&
"the value type must be vector type!");
970 Type elemTy = valueTy.getElementType();
972 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
973 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
974 VectorType newResTy = VectorType::get(sgShape, elemTy);
975 SmallVector<Value> newOps;
976 for (
auto offsets : offsetsList) {
977 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
978 op.getMemDesc(), offsets,
979 layout.dropSgLayoutAndData());
980 newOps.push_back(newOp);
982 rewriter.replaceOpWithMultiple(op, {newOps});
/// Conversion pattern for `xegpu::StoreMatrixOp`.
///
/// Offsets are produced by `genOffsetsList`. One new StoreMatrixOp is created
/// per (converted data value, offsets) pair, writing to the original memory
/// descriptor with the layout's sg layout/data dropped. The original op is
/// erased afterwards.
988struct WgToSgStoreMatrixOp :
public OpConversionPattern<xegpu::StoreMatrixOp> {
989 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
991 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter)
const override {
994 SmallVector<SmallVector<OpFoldResult>> offsetsList;
995 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
998 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
999 for (
auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1000 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1001 offsets, layout.dropSgLayoutAndData());
1002 rewriter.eraseOp(op);
/// Conversion pattern for `vector::StepOp`.
///
/// Requires a layout that `isForWorkgroup()`. A single StepOp of the new
/// (per-subgroup) vector type is created once. For every set of distributed
/// coordinates, the first coordinate is broadcast to that vector type and
/// added to the step vector. The original op is replaced by the resulting
/// values.
1008struct WgToSgVectorStepOp :
public OpConversionPattern<vector::StepOp> {
1009 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1011 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1012 ConversionPatternRewriter &rewriter)
const override {
1013 xegpu::DistributeLayoutAttr layout =
1015 if (!layout || !layout.isForWorkgroup())
1018 Location loc = op.getLoc();
1019 VectorType type = op.getResult().getType();
1020 auto wgShape = type.getShape();
1021 std::optional<SmallVector<int64_t>> sgShape =
1022 getSgShapeAndCount(wgShape, layout).first;
1027 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1029 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
// The step vector is shared by all results; only the added offset differs.
1033 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1034 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1035 SmallVector<Value> newOps;
1036 for (
auto offsets : *sgOffsets) {
1039 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1041 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1042 newOps.push_back(finalSteps);
1045 rewriter.replaceOpWithMultiple(op, {newOps});
/// Conversion pattern for `vector::ShapeCastOp`.
///
/// Requires a result layout that `isForWorkgroup()`. When the source layout
/// is checked, it must be a slice of the result layout; otherwise a match
/// failure is reported. An assertion checks that setting the unit-dim data
/// for `expandedUnitDims` leaves `layoutToDistribute` unchanged. One new
/// ShapeCastOp to the per-subgroup result type is created for each converted
/// source value, and the original op is replaced by those results.
1051struct WgToSgVectorShapeCastOp
1052 :
public OpConversionPattern<vector::ShapeCastOp> {
1053 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1056 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter)
const override {
1059 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1063 ArrayRef<int64_t> wgShape = resultType.getShape();
1064 xegpu::DistributeLayoutAttr layout =
1066 if (!layout || !layout.isForWorkgroup())
1071 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1075 ArrayRef<int64_t> srcShape = srcType.getShape();
1077 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1078 SmallVector<int64_t> expandedUnitDims;
1080 xegpu::DistributeLayoutAttr sourceLayout =
1083 if (!sourceLayout.isSliceOf(layout))
1084 return rewriter.notifyMatchFailure(
1085 op,
"The ShapeCast op only expands dimensions, the input layout "
1086 "must be a slice of the result layout.");
1088 assert(layoutToDistribute.isEqualTo(
1089 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1090 "The sg_data for unit dimensions should be set as 1");
1093 SmallVector<int64_t> sgShape =
1094 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1095 VectorType newResultType =
1096 VectorType::get(sgShape, resultType.getElementType());
1098 SmallVector<Value> newShapeCastOps;
1099 for (
auto src : adaptor.getSource()) {
1100 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1101 newResultType, src);
1102 newShapeCastOps.push_back(newShapeCast.getResult());
1105 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
/// Conversion pattern for `vector::MultiDimReductionOp`.
///
/// Requires a layout that `isForWorkgroup()` and is an `xegpu::SliceAttr`;
/// the sg layout and sg data are taken from the slice's parent layout.
///
/// Visible steps:
///  1. Each converted source value is reduced locally with a new
///     MultiDimReductionOp that starts from a neutral accumulator.
///  2. A reduction dimension needs cross-subgroup handling when its sg layout
///     entry is greater than 1 and its sg data entry is smaller than the
///     original source extent.
///  3. If no dimension needs it, each local result is combined with
///     `adaptor.getAcc()[0]` and the op is replaced.
///  4. Otherwise the local results are reshaped (broadcast for scalar
///     results, shape_cast for vector results) and stored into an `i8`
///     memref in memory space 3 created with `memref::AllocaOp`. A
///     `gpu::BarrierOp` follows. The data is loaded back with the reduction
///     dimensions widened to the buffer extent, reduced again from a neutral
///     accumulator, and combined with `adaptor.getAcc()[i]`.
1142struct WgToSgMultiDimReductionOp
1143 :
public OpConversionPattern<vector::MultiDimReductionOp> {
1144 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1147 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter)
const override {
1149 Location loc = op.getLoc();
1151 VectorType srcType = op.getSourceVectorType();
1152 Type resultTy = op.getResult().getType();
1153 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
// A non-vector result type is treated as a scalar result.
1154 bool isScalarResult = !dstVecType;
1156 auto originalSrcShape = srcType.getShape();
1157 Type elemTy = srcType.getElementType();
1159 xegpu::DistributeLayoutAttr layout =
1161 if (!layout || !layout.isForWorkgroup())
1164 auto reductionDims = llvm::to_vector(op.getReductionDims());
1167 SmallVector<int64_t> sgLayout;
1168 SmallVector<int64_t> sgData;
1169 xegpu::DistributeLayoutAttr parentLayout;
1170 if (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1171 parentLayout = sliceAttr.getParent();
1172 sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
1173 sgData = parentLayout.getEffectiveSgDataAsInt();
1175 return rewriter.notifyMatchFailure(
1176 op,
"Reduction should have SliceAttr layout");
1179 SmallVector<Value> localReductions;
1180 auto sgSrcs = adaptor.getSource();
1181 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1182 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1183 sgSrcType.getShape().end());
1190 auto originalDstShape = dstVecType.getShape();
1191 SmallVector<int64_t> sgDstShape =
1192 getSgShapeAndCount(originalDstShape, layout).first;
1193 sgDstType = VectorType::get(sgDstShape, elemTy);
// Step 1: local reduction of every converted source value.
1198 for (
auto sgSrc : sgSrcs) {
1201 rewriter, loc, sgDstType, op.getKind());
1203 auto localReduce = vector::MultiDimReductionOp::create(
1204 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1206 localReductions.push_back(localReduce.getResult());
// Step 2: collect the reduction dims that span more than one subgroup.
1210 SmallVector<int64_t> crossSgReductionDims;
1211 for (int64_t reductionDim : reductionDims) {
1212 bool needsCrossSubgroupReduction =
1213 (sgLayout[reductionDim] > 1) &&
1214 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1216 if (needsCrossSubgroupReduction) {
1217 crossSgReductionDims.push_back(reductionDim);
// Step 3: nothing to exchange between subgroups.
1222 if (crossSgReductionDims.empty()) {
1223 SmallVector<Value> results;
1224 for (
auto localResult : localReductions) {
1226 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1227 results.push_back(finalResult);
1229 rewriter.replaceOpWithMultiple(op, {results});
// Step 4: the stored data has extent 1 along every reduction dimension.
1234 auto slmStoreDataShape = sgSrcShape;
1235 for (int64_t dim : reductionDims)
1236 slmStoreDataShape[dim] = 1;
1237 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1238 SmallVector<Value> slmStoreData;
1239 for (
auto localResult : localReductions) {
1240 if (isScalarResult) {
1242 slmStoreData.push_back(vector::BroadcastOp::create(
1243 rewriter, loc, slmStoreDataType, localResult));
1245 slmStoreData.push_back(vector::ShapeCastOp::create(
1246 rewriter, loc, slmStoreDataType, localResult));
1250 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1251 originalSrcShape.end());
1252 SmallVector<int> slmSgData(sgData.begin(), sgData.end());
1253 SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
// Along reduction dims the buffer extent equals the sg layout entry.
1254 for (
int dim : reductionDims) {
1255 slmShape[dim] = sgLayout[dim];
1258 xegpu::LayoutAttr slmStoreLayout =
1259 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1263 auto bytesPerElement = bitWidth / 8;
// Byte-typed buffer (i8 elements) in memory space 3, sized by `slmSize`.
1265 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1266 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1268 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1271 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1274 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1275 rewriter.getIndexType(),
nullptr);
1277 auto slmStoreCoords =
1278 slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1279 if (
failed(slmStoreCoords))
1281 for (
auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
1282 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1283 xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
// Barrier between the stores above and the loads below.
1288 gpu::BarrierOp::create(rewriter, loc);
// For loading, the reduction dims cover the full buffer extent.
1291 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1292 for (int64_t dim : reductionDims) {
1293 slmLoadDataShape[dim] = slmShape[dim];
1294 slmSgData[dim] = slmShape[dim];
1296 xegpu::LayoutAttr slmLoadLayout =
1297 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1298 auto slmLoadCoords =
1299 slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1300 if (
failed(slmLoadCoords))
1303 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1304 SmallVector<Value> slmLoadData;
1305 for (
auto coord : *slmLoadCoords) {
1306 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1307 slmLoadData.push_back(xegpu::LoadMatrixOp::create(
1308 rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
1315 rewriter, loc, sgDstType, op.getKind());
// Final reduction of the loaded data, then combination with the accumulator.
1317 SmallVector<Value> finalResults;
1318 for (
size_t i = 0; i < slmLoadData.size(); ++i) {
1319 auto loaded = slmLoadData[i];
1320 auto finalReduce = vector::MultiDimReductionOp::create(
1321 rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
1324 rewriter, loc, op.getKind(), finalReduce.getResult(),
1325 adaptor.getAcc()[i]));
1327 rewriter.replaceOpWithMultiple(op, {finalResults});
/// Conversion pattern for `vector::TransposeOp`.
///
/// Requires both a result layout and a source layout that
/// `isForWorkgroup()`. The effective sg layouts of source and result must
/// have the same rank as the permutation, and the result layout must be a
/// transpose of the source layout under that permutation (checked with
/// `xegpu::LayoutKind::Subgroup`); otherwise a match failure is reported.
/// One new TransposeOp to the per-subgroup result type is created for each
/// converted source vector, and the original op is replaced by those
/// results.
1333struct WgToSgVectorTransposeOp
1334 :
public OpConversionPattern<vector::TransposeOp> {
1335 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1338 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1339 ConversionPatternRewriter &rewriter)
const override {
1340 VectorType resultType = op.getResultVectorType();
1342 ArrayRef<int64_t> wgShape = resultType.getShape();
1343 xegpu::DistributeLayoutAttr layout =
1345 if (!layout || !layout.isForWorkgroup())
1348 xegpu::DistributeLayoutAttr sourceLayout =
1350 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1353 SmallVector<int64_t> sourceSgLayout =
1354 sourceLayout.getEffectiveSgLayoutAsInt();
1355 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1357 ArrayRef<int64_t> permutation = op.getPermutation();
1358 size_t permutationSize = permutation.size();
1359 if (sourceSgLayout.size() != permutationSize ||
1360 resultSgLayout.size() != permutationSize) {
1361 return rewriter.notifyMatchFailure(
1362 op,
"Layouts and permutation must have the same rank");
1367 if (!layout.isTransposeOf(sourceLayout, permutation,
1368 xegpu::LayoutKind::Subgroup))
1369 return rewriter.notifyMatchFailure(
1370 op,
"Result layout is not a valid transpose of source layout "
1371 "according to permutation");
1373 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1374 VectorType newResultType =
1375 VectorType::get(sgShape, resultType.getElementType());
1377 SmallVector<Value> newTransposeOps;
1378 for (
auto src : adaptor.getVector()) {
1379 auto newTranspose = vector::TransposeOp::create(
1380 rewriter, op.getLoc(), newResultType, src, permutation);
1381 newTransposeOps.push_back(newTranspose.getResult());
1383 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
/// Conversion pattern shared by vector.constant_mask and vector.create_mask
/// (selected through MaskOpType). For every set of distributed coordinates it
/// emits a vector.create_mask of the subgroup shape whose per-dimension size
/// is the workgroup mask size minus that coordinate, clamped by a max/min
/// pair, and replaces the original op with the new masks.
1389template <
typename MaskOpType>
1390struct WgToSgVectorMaskOp :
public OpConversionPattern<MaskOpType> {
1391 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1393 LogicalResult matchAndRewrite(
1395 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter)
const override {
// The layout is checked for being non-null and workgroup-level.
1397 xegpu::DistributeLayoutAttr layout =
1399 if (!layout || !layout.isForWorkgroup())
1402 Location loc = op.getLoc();
1403 VectorType type = op.getResult().getType();
1404 auto wgShape = type.getShape();
// Gather the workgroup-level mask sizes as Values: constant_mask contributes
// one entry per static mask dimension size, create_mask contributes its
// operands directly.
1406 SmallVector<Value> wgMaskDimSizes;
1407 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1408 for (int64_t maskSize : op.getMaskDimSizes()) {
1409 wgMaskDimSizes.push_back(
1412 }
else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1413 wgMaskDimSizes = llvm::to_vector(op.getOperands());
// Distributed coordinates are computed from the subgroup id over the
// workgroup shape.
1417 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1419 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
// Per-subgroup mask type: subgroup shape with the original element type.
1423 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1424 VectorType resultType = VectorType::get(sgShape, type.getElementType());
// One create_mask per coordinate set.
1428 SmallVector<Value> newCreateMaskOps;
1429 for (
auto offsetSet : *sgOffsets) {
1430 SmallVector<Value> maskOperands;
1432 for (
auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
// Shift the workgroup mask size by this dimension's coordinate, then clamp:
// signed max against `zero`, followed by signed min against `dimSizeVal`.
1435 Value offset = offsetSet[i];
1436 Value adjustedMaskSize =
1437 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1440 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1442 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1443 maskOperands.push_back(sgMaskSize);
1446 auto newCreateMaskOp =
1447 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1448 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
// 1:N replacement: the single original mask maps to all new masks.
1451 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
// Concrete instantiations of WgToSgVectorMaskOp for the two mask-producing
// vector ops; these are the names used when registering patterns.
1456using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1457using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
/// Conversion pattern for vector.bitcast: emits one vector.bitcast per
/// converted source value, each producing a vector of the subgroup shape, and
/// replaces the original op with those results.
1460struct WgToSgVectorBitCastOp :
public OpConversionPattern<vector::BitCastOp> {
1461 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1464 matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
1465 ConversionPatternRewriter &rewriter)
const override {
1466 VectorType resultType = op.getResultVectorType();
// The result shape is treated as the workgroup-level shape to distribute.
1468 ArrayRef<int64_t> wgShape = resultType.getShape();
// The layout is checked for being non-null and workgroup-level.
1469 xegpu::DistributeLayoutAttr layout =
1471 if (!layout || !layout.isForWorkgroup())
// Per-subgroup result type: subgroup shape with the result element type.
1474 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1475 VectorType newResultType =
1476 VectorType::get(sgShape, resultType.getElementType());
// One new bitcast per converted source value.
1478 SmallVector<Value> newBitCastOps;
1479 for (
auto src : adaptor.getSource()) {
1481 vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
1482 newBitCastOps.push_back(newBitCast.getResult());
// 1:N replacement: the single original result maps to all new results.
1485 rewriter.replaceOpWithMultiple(op, {newBitCastOps});
/// Conversion pattern for vector.interleave: pairs the converted lhs and rhs
/// values, emits one vector.interleave per pair with a result of the subgroup
/// shape, and replaces the original op with those results.
1491struct WgToSgVectorInterleaveOp
1492 :
public OpConversionPattern<vector::InterleaveOp> {
1493 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1496 matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1497 ConversionPatternRewriter &rewriter)
const override {
1498 VectorType resultType = op.getResultVectorType();
// The result shape is treated as the workgroup-level shape to distribute.
1500 ArrayRef<int64_t> wgShape = resultType.getShape();
// The layout is checked for being non-null and workgroup-level.
1501 xegpu::DistributeLayoutAttr layout =
1503 if (!layout || !layout.isForWorkgroup())
// Per-subgroup result type: subgroup shape with the result element type.
1506 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1507 VectorType newResultType =
1508 VectorType::get(sgShape, resultType.getElementType());
1510 SmallVector<Value> newInterleaveOps;
// lhs and rhs values are consumed pairwise via llvm::zip.
1513 for (
auto [
lhs,
rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
1514 auto newInterleave = vector::InterleaveOp::create(
1515 rewriter, op.getLoc(), newResultType,
lhs,
rhs);
1516 newInterleaveOps.push_back(newInterleave.getResult());
// 1:N replacement: the single original result maps to all new results.
1519 rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
/// Conversion pattern for vector.deinterleave: emits one vector.deinterleave
/// per converted source value and replaces the two original results with the
/// collected first and second results respectively.
1525struct WgToSgVectorDeinterleaveOp
1526 :
public OpConversionPattern<vector::DeinterleaveOp> {
1527 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1530 matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1531 ConversionPatternRewriter &rewriter)
const override {
// Separate accumulators, one per original result of the deinterleave.
1532 SmallVector<Value> newRes1Ops;
1533 SmallVector<Value> newRes2Ops;
// No result type is passed to create() here, unlike the sibling patterns.
1535 for (
auto src : adaptor.getSource()) {
1536 auto newDeinterleave =
1537 vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
1538 newRes1Ops.push_back(newDeinterleave.getRes1());
1539 newRes2Ops.push_back(newDeinterleave.getRes2());
// Two replacement groups, in result order: all Res1 values, then all Res2.
1542 SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1543 rewriter.replaceOpWithMultiple(op, results);
// Registers the workgroup-to-subgroup conversion patterns in `patterns`;
// every listed pattern type is added in a single add<> call.
1553 patterns.
add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1554 WgToSgDpasMxOp, WgToSgPrefetchNdOp,
1555 UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
1556 WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1557 WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1558 WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1559 WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1560 WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1561 WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1562 WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
/// Pass type for the XeGPU workgroup-to-subgroup distribution; its
/// runOnOperation() override is defined out of line.
1569struct XeGPUWgToSgDistributePass
1571 void runOnOperation()
override;
/// Pass entry point. Sets up type converters and a ConversionTarget whose
/// legality callbacks treat an op as legal when its layout is null or not
/// workgroup-level, then runs a partial conversion and signals pass failure
/// if that conversion fails.
1575void XeGPUWgToSgDistributePass::runOnOperation() {
1577 Operation *op = getOperation();
1579 signalPassFailure();
// Remember the UnrealizedConversionCastOps that already exist; the legality
// callback for that op further down accepts exactly these.
1584 SmallVector<Operation *> existingCastOps;
1585 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1586 existingCastOps.push_back(castOp.getOperation());
// Type converter for RankedTensorType: every type maps to itself by default;
// a ranked tensor is expanded into `count` VectorTypes of the subgroup shape
// computed from its encoding.
1596 TypeConverter converter;
1597 converter.addConversion([&](Type type) -> Type {
return type; });
1598 converter.addConversion(
1599 [&](RankedTensorType type,
1600 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
// The encoding is read as a DistributeLayoutAttr (null if absent or of
// another kind).
1604 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1605 type.getEncoding());
1607 return std::nullopt;
1609 Type elemTy = type.getElementType();
1610 ArrayRef<int64_t> shape = type.getShape();
1613 SmallVector<int64_t> subShape;
1614 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1616 auto newTy = VectorType::get(subShape, elemTy);
1617 result.append(count, newTy);
// Pattern set, conversion target and a second type converter, this one for
// xegpu::TensorDescType.
1628 RewritePatternSet patterns(ctx);
1629 ConversionTarget
target(*ctx);
1630 TypeConverter converter;
1631 converter.addConversion([&](Type type) -> Type {
return type; });
1632 converter.addConversion(
1633 [&](xegpu::TensorDescType type,
1634 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1635 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
// Descriptors without a workgroup-level layout are not handled by this
// conversion callback.
1638 if (!layout || !layout.isForWorkgroup())
1639 return std::nullopt;
1641 Type elemTy = type.getElementType();
1642 ArrayRef<int64_t> shape = type.getShape();
1645 SmallVector<int64_t> subShape;
1646 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
// The converted descriptors carry the layout with its sg_layout/sg_data
// dropped, and keep the original encoding.
1648 layout = layout.dropSgLayoutAndData();
1650 auto newTy = xegpu::TensorDescType::get(
1651 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1652 result.append(count, newTy);
// Helper: fetch the tensor descriptor type of the nd-descriptor ops; yields a
// null TensorDescType for any other op.
1656 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1657 if (
auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1658 return createOp.getType();
1659 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1660 return loadOp.getTensorDescType();
1661 if (
auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1662 return storeOp.getTensorDescType();
1663 if (
auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1664 return prefetchOp.getTensorDescType();
1665 return xegpu::TensorDescType();
// Shared legality predicate: legal when there is no layout or the layout is
// not workgroup-level.
1668 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) ->
bool {
1669 return !layout || !layout.isForWorkgroup();
// Nd-descriptor ops: judged by the layout of their tensor descriptor type.
1672 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1673 xegpu::StoreNdOp, xegpu::PrefetchNdOp>(
1674 [=](Operation *op) ->
bool {
1675 auto tdescTy = getTensorDescType(op);
1677 dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1678 return isLegal(layout);
// DPAS ops: judged by the layout returned from getLayoutCdAttr().
1681 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) ->
bool {
1682 auto layout = op.getLayoutCdAttr();
1683 return isLegal(layout);
1686 target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
1687 [=](xegpu::DpasMxOp op) ->
bool {
1688 auto layout = op.getLayoutCdAttr();
1689 return isLegal(layout);
// Matrix load/store ops: judged by their own layout attribute.
1692 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1693 [=](xegpu::LoadMatrixOp op) ->
bool {
1694 return isLegal(op.getLayoutAttr());
1697 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1698 [=](xegpu::StoreMatrixOp op) ->
bool {
1699 return isLegal(op.getLayoutAttr());
// arith.constant: the callback inspects the op's type as a VectorType before
// applying the shared predicate.
1702 target.addDynamicallyLegalOp<arith::ConstantOp>(
1703 [=](arith::ConstantOp op) ->
bool {
1704 auto vecType = dyn_cast<VectorType>(op.getType());
1710 return isLegal(layout);
// Vector dialect ops covered by the patterns share one legality callback.
1713 target.addDynamicallyLegalOp<
1714 vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1715 vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
1716 vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp,
1717 vector::DeinterleaveOp>([=](Operation *op) ->
bool {
1721 return isLegal(layout);
// Gather/scatter ops: judged by their own layout attribute.
1724 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1725 [=](xegpu::LoadGatherOp op) ->
bool {
1726 auto layout = op.getLayoutAttr();
1727 return isLegal(layout);
1730 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1731 [=](xegpu::StoreScatterOp op) ->
bool {
1732 auto layout = op.getLayoutAttr();
1733 return isLegal(layout);
// convert_layout is legal only when both its input and target layouts pass
// the shared predicate.
1736 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1737 [=](xegpu::ConvertLayoutOp op) ->
bool {
1738 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
// math/arith dialect ops: the callback compares operand vector shapes with
// the result vector shape before applying the shared predicate.
1741 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1742 [=](Operation *op) -> std::optional<bool> {
1747 VectorType resultType =
1755 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1756 if (!operandType || operandType.getShape() != resultType.getShape()) {
1761 xegpu::DistributeLayoutAttr layout =
1763 return isLegal(layout);
// Only the casts recorded in existingCastOps are legal.
1766 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1767 [=](UnrealizedConversionCastOp op) {
1768 return llvm::is_contained(existingCastOps, op.getOperation());
// Every op without an explicit legality rule above is treated as legal.
1771 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
// A failed partial conversion fails the pass.
1777 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1778 return signalPassFailure();
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.