29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
39static xegpu::RangeAttr getRangeSpecAttr(
Operation *op) {
42 if (
auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->
getAttr(
"sg_id_range")))
50static std::pair<SmallVector<int64_t>,
int>
52 xegpu::DistributeLayoutAttr layout) {
55 auto distributedShape = layout.computeDistributedShape(
57 if (
failed(distributedShape))
58 return std::make_pair(sgShape, count);
59 auto sgData = layout.getEffectiveSgDataAsInt();
61 return std::make_pair(sgData, count);
68template <
typename OpType,
69 typename = std::enable_if_t<llvm::is_one_of<
70 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
71 xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
73genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
78 if (origOffsets.empty())
82 xegpu::DistributeLayoutAttr layout;
83 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
84 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
85 layout = op.getLayoutAttr();
87 layout = op.getDescLayoutAttr();
91 if (!layout || !layout.isForWorkgroup())
95 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
98 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
100 int64_t startOfRange = sgIdRange.getStart().getInt();
101 int64_t endOfRange = sgIdRange.getEnd().getInt();
103 if (layout.getNumSubgroups() != endOfRange - startOfRange)
104 return rewriter.notifyMatchFailure(
105 op,
"sg_layout size must match the sg_id_range");
107 if (startOfRange > 0) {
108 Value startOfRangeVal =
110 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
117 auto maybeDescOffsets =
118 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
119 if (
failed(maybeDescOffsets))
124 for (
const auto &sgOffsets : *maybeDescOffsets) {
127 offsetsList.push_back(std::move(newOffsets));
181struct WgToSgCreateNdOp :
public OpConversionPattern<xegpu::CreateNdDescOp> {
182 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
185 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter)
const override {
188 Location loc = op.getLoc();
190 xegpu::TensorDescType tdescTy = op.getType();
191 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
192 if (!layout || !layout.isForWorkgroup())
195 Type elemTy = tdescTy.getElementType();
196 ArrayRef<int64_t> wgShape = tdescTy.getShape();
198 SmallVector<int64_t> sgShape;
200 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
201 xegpu::TensorDescType newTdescTy =
202 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
203 layout.dropSgLayoutAndData());
205 SmallVector<Value> newCreateNdOps(count);
206 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
207 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
208 op.getSource(), op.getMixedSizes(),
209 op.getMixedStrides());
212 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
218struct WgToSgLoadNdOp :
public OpConversionPattern<xegpu::LoadNdOp> {
219 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
221 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter)
const override {
224 SmallVector<SmallVector<OpFoldResult>> offsetsList;
225 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
228 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
230 layout = layout.dropSgLayoutAndData();
231 SmallVector<Value> newOps;
232 for (
auto [tdesc, offsets] :
233 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
234 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
235 VectorType newResTy =
236 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
237 auto newOp = xegpu::LoadNdOp::create(
238 rewriter, op.getLoc(), newResTy, tdesc, offsets,
239 nullptr,
nullptr, op.getL1HintAttr(),
240 op.getL2HintAttr(), op.getL3HintAttr(), layout);
241 newOps.push_back(newOp);
243 rewriter.replaceOpWithMultiple(op, {newOps});
250struct WgToSgStoreNdOp :
public OpConversionPattern<xegpu::StoreNdOp> {
251 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
253 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
255 SmallVector<SmallVector<OpFoldResult>> offsetsList;
256 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
259 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
261 layout = layout.dropSgLayoutAndData();
262 for (
auto [v, tdesc, offsets] :
263 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
264 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
265 op.getL1HintAttr(), op.getL2HintAttr(),
266 op.getL3HintAttr(), layout);
268 rewriter.eraseOp(op);
275struct WgToSgPrefetchNdOp :
public OpConversionPattern<xegpu::PrefetchNdOp> {
276 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
278 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter)
const override {
280 SmallVector<SmallVector<OpFoldResult>> offsetsList;
281 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
284 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
286 layout = layout.dropSgLayoutAndData();
287 for (
auto [tdesc, offsets] :
288 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
289 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
290 op.getL1HintAttr(), op.getL2HintAttr(),
291 op.getL3HintAttr(), layout);
293 rewriter.eraseOp(op);
300struct WgToSgDpasOp :
public OpConversionPattern<xegpu::DpasOp> {
301 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
303 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter)
const override {
305 Location loc = op.getLoc();
306 VectorType resultTy = op.getResult().getType();
307 if (resultTy.getRank() != 2)
310 auto layoutCd = op.getLayoutCdAttr();
311 auto layoutA = op.getLayoutAAttr();
312 auto layoutB = op.getLayoutBAttr();
313 if (!layoutCd || !layoutA || !layoutB)
316 SmallVector<Value> newDpasOps;
317 for (
auto aVec : adaptor.getLhs()) {
318 for (
auto bVec : adaptor.getRhs()) {
320 llvm::SmallVector<Value> operands({aVec, bVec});
323 tmpC = adaptor.getAcc()[i++];
324 operands.push_back(tmpC);
327 ArrayRef<int64_t> aVecShape =
328 cast<VectorType>(aVec.getType()).getShape();
329 ArrayRef<int64_t> bVecShape =
330 cast<VectorType>(bVec.getType()).getShape();
331 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
332 resultTy.getElementType());
333 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
334 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
335 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
336 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
338 newDpasOps.push_back(newDpasOp);
341 rewriter.replaceOpWithMultiple(op, {newDpasOps});
347struct WgToSgDpasMxOp :
public OpConversionPattern<xegpu::DpasMxOp> {
348 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
350 matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
351 ConversionPatternRewriter &rewriter)
const override {
353 Location loc = op.getLoc();
354 VectorType resultTy = op.getResult().getType();
356 if (resultTy.getRank() != 2)
359 auto layoutCd = op.getLayoutCdAttr();
360 auto layoutA = op.getLayoutAAttr();
361 auto layoutB = op.getLayoutBAttr();
362 auto layoutAScale = op.getLayoutAScaleAttr();
363 auto layoutBScale = op.getLayoutBScaleAttr();
365 if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
369 SmallVector<Value> newDpasMxOps;
370 for (
auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
371 for (
auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
372 Value accVal = (op.getAcc()) ? adaptor.getAcc()[index_c++] : Value();
374 (op.getScaleA()) ? adaptor.getScaleA()[index_a] : Value();
376 (op.getScaleB()) ? adaptor.getScaleB()[index_b] : Value();
378 ArrayRef<int64_t> aVecShape =
379 cast<VectorType>(aVec.getType()).getShape();
380 ArrayRef<int64_t> bVecShape =
381 cast<VectorType>(bVec.getType()).getShape();
382 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
383 resultTy.getElementType());
384 auto newDpasMxOp = xegpu::DpasMxOp::create(
385 rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
386 layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
387 layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
388 layoutBScale.dropSgLayoutAndData());
390 newDpasMxOps.push_back(newDpasMxOp);
393 rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
399struct WgToSgVectorBroadcastOp
400 :
public OpConversionPattern<vector::BroadcastOp> {
401 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
404 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter)
const override {
407 VectorType resultType = op.getResult().getType();
408 ArrayRef<int64_t> wgShape = resultType.getShape();
410 xegpu::DistributeLayoutAttr layout =
412 if (!layout || !layout.isForWorkgroup())
415 SmallVector<int64_t> sgShape;
417 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
418 VectorType newResultType =
419 VectorType::get(sgShape, resultType.getElementType());
421 SmallVector<Value> newBroadcastOps;
422 auto distSource = adaptor.getOperands().front();
423 int numDistributions = count / distSource.size();
424 for (
int i = 0; i < numDistributions; ++i) {
425 for (
auto operand : distSource) {
426 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
427 newResultType, operand);
429 newBroadcastOps.push_back(newBroadcast.getResult());
432 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
439 WgToSgElementwiseOp(MLIRContext *ctx)
440 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
443 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
444 ConversionPatternRewriter &rewriter)
const override {
450 assert(resultType &&
"Expected result to be a VectorType");
454 xegpu::DistributeLayoutAttr layout =
456 if (!layout || !layout.isForWorkgroup())
461 size_t numVariants = operands.empty() ? 0 : operands.front().size();
463 if (llvm::any_of(operands, [&](
const ValueRange &operandVec) {
464 return operandVec.size() != numVariants;
469 VectorType newResultType =
470 VectorType::get(sgShape, resultType.getElementType());
472 for (
size_t i = 0; i < numVariants; ++i) {
474 for (
auto &operandVec : operands)
475 opOperands.push_back(operandVec[i]);
478 state.addOperands(opOperands);
479 state.addTypes(newResultType);
480 state.addAttributes(op->
getAttrs());
483 newResults.push_back(newOp->
getResult(0));
486 rewriter.replaceOpWithMultiple(op, {newResults});
517struct WgToSgConvertLayoutOp
518 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
519 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
523 ConversionPatternRewriter &rewriter)
const override {
525 auto inputLayout = op.getInputLayout();
526 auto targetLayout = op.getTargetLayout();
528 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
529 !targetLayout.isForWorkgroup())
530 return rewriter.notifyMatchFailure(
531 op,
"Input and target layouts must have subgroup layout");
533 Type resultType = op.getResult().getType();
535 rewriter.replaceOp(op, op.getSource());
536 assert(!inputLayout.dropSgLayoutAndData() &&
537 !targetLayout.dropSgLayoutAndData() &&
538 "unexpected layout attributes for scalar type");
544 inputLayout.getEffectiveSgLayoutAsInt();
547 targetLayout.getEffectiveSgLayoutAsInt();
552 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
554 inputLayout = inputLayout.dropSgLayoutAndData();
555 targetLayout = targetLayout.dropSgLayoutAndData();
558 if (inputLayout && targetLayout) {
559 for (
auto [i, src] : llvm::enumerate(adaptor.getSource())) {
560 auto newOp = xegpu::ConvertLayoutOp::create(
561 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
565 rewriter.replaceOpWithMultiple(op, {newOps});
570 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
576 auto bytesPerElement = bitWidth / 8;
580 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
581 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
583 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
586 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
588 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
589 rewriter.getIndexType(),
nullptr);
592 auto storeCoords = inputLayout.computeDistributedCoords(
593 rewriter, loc, sgId.getResult(), wgShape);
594 if (failed(storeCoords))
598 for (
auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
600 for (
Value coord : coords) {
601 storeMatrixOffsets.push_back(coord);
603 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
604 storeMatrixOffsets,
nullptr );
607 gpu::BarrierOp::create(rewriter, loc);
610 auto loadCoords = targetLayout.computeDistributedCoords(
611 rewriter, loc, sgId.getResult(), wgShape);
612 if (failed(loadCoords))
615 VectorType loadType = VectorType::get(targetSgData, elemTy);
619 for (
auto coords : *loadCoords) {
621 for (
Value coord : coords) {
622 loadMatrixOffsets.push_back(coord);
624 auto loadOp = xegpu::LoadMatrixOp::create(
625 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
626 targetLayout.dropSgLayoutAndData());
628 finalResults.push_back(loadOp.getResult());
631 rewriter.replaceOpWithMultiple(op, {finalResults});
667struct UnrealizedConversionCastOpPattern
668 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
669 using OpConversionPattern<
670 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
673 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
OneToNOpAdaptor adaptor,
674 ConversionPatternRewriter &rewriter)
const override {
677 auto inputTy = dyn_cast<VectorType>(inputs[0].
getType());
678 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
680 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
681 !llvm::all_equal(
ValueRange(inputs).getTypes()))
689 if (op.getNumOperands() == 1 &&
690 llvm::equal(
ValueRange(inputs).getTypes(), op->getResultTypes())) {
691 rewriter.replaceOp(op, inputs);
702 if (op.getNumResults() == 1 &&
704 rewriter.replaceOpWithMultiple(op, {inputs});
708 return mlir::failure();
713struct WgToSgArithConstantOp :
public OpConversionPattern<arith::ConstantOp> {
714 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
717 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
718 ConversionPatternRewriter &rewriter)
const override {
719 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
720 auto vecType = dyn_cast<VectorType>(op.getType());
721 if (!vecAttr || !vecType)
724 xegpu::DistributeLayoutAttr layout =
726 if (!layout || !layout.isForWorkgroup())
729 ArrayRef<int64_t> wgShape = vecType.getShape();
730 SmallVector<int64_t> sgShape;
732 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
734 auto newType = VectorType::get(sgShape, vecType.getElementType());
735 Location loc = op.getLoc();
736 auto eltType = vecType.getElementType();
738 if (vecAttr.isSplat()) {
740 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
742 SmallVector<Value> newConstOps;
743 for (
int i = 0; i < count; ++i) {
744 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
745 newConstOps.push_back(cstOp);
747 rewriter.replaceOpWithMultiple(op, {newConstOps});
749 }
else if (sgShape == wgShape) {
752 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
753 rewriter.replaceOp(op, newConstOp);
759 if (!eltType.isIndex())
760 return rewriter.notifyMatchFailure(
761 op,
"Unsupported element type for non-splat constant op.");
763 if (wgShape.size() > 2)
764 return rewriter.notifyMatchFailure(
765 op,
"Only 1D & 2D vector constant supported");
767 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
768 int64_t rowStride = 0, colStride = 0;
769 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
770 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
774 colStride = cast<IntegerAttr>(values[1]).getInt() -
775 cast<IntegerAttr>(values[0]).getInt();
778 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
779 cast<IntegerAttr>(values[0]).getInt();
782 for (int64_t r = 0; r < rows; ++r) {
783 for (int64_t c = 0; c < cols; ++c) {
784 int64_t idx = r * cols + c;
786 if (c > 0 && cols > 1) {
787 int64_t prevIdx = r * cols + (c - 1);
788 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
789 cast<IntegerAttr>(values[prevIdx]).getInt();
790 if (diff != colStride)
791 return rewriter.notifyMatchFailure(
792 op,
"Non-constant column stride in constant op.");
795 if (r > 0 && rows > 1) {
796 int64_t prevIdx = (r - 1) * cols + c;
797 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
798 cast<IntegerAttr>(values[prevIdx]).getInt();
799 if (diff != rowStride)
800 return rewriter.notifyMatchFailure(
801 op,
"Non-constant row stride in constant op.");
809 SmallVector<Attribute> baseTileValues;
810 int baseTileCols = sgShape[sgShape.size() - 1];
811 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
812 for (int64_t r = 0; r < baseTileRows; ++r) {
813 for (int64_t c = 0; c < baseTileCols; ++c) {
814 baseTileValues.push_back(values[r * cols + c]);
820 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
824 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
826 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
830 SmallVector<Value, 2> strideConsts;
831 strideConsts.push_back(
835 strideConsts.begin(),
838 SmallVector<Value> newConstOps;
839 for (
auto offsets : *sgOffsets) {
842 for (
size_t i = 0; i < strideConsts.size(); ++i) {
844 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
845 offsets[i], strideConsts[i]);
846 mulOffset = arith::AddIOp::create(
847 rewriter, loc, rewriter.getIndexType(), mulOffset,
mul);
850 auto bcastOffset = vector::BroadcastOp::create(
851 rewriter, loc, baseConstVec.getType(), mulOffset);
853 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
854 newConstOps.push_back(finalConst);
856 rewriter.replaceOpWithMultiple(op, {newConstOps});
864struct WgToSgLoadGatherOp :
public OpConversionPattern<xegpu::LoadGatherOp> {
865 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
867 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
868 ConversionPatternRewriter &rewriter)
const override {
870 Location loc = op.getLoc();
871 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
874 ArrayRef<int64_t> wgShape = resultType.getShape();
876 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
878 if (!layout || !layout.isForWorkgroup())
881 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
884 auto offsetsVecType =
885 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
887 dyn_cast<VectorType>(adaptor.getMask().front().getType());
888 if (!offsetsVecType || !maskVecType ||
889 offsetsVecType.getShape() != maskVecType.getShape()) {
890 return rewriter.notifyMatchFailure(op,
891 "offsets have not been distributed");
894 SmallVector<Value> newLoadOps;
896 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
897 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
898 for (
auto [offsets, mask] :
899 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
900 auto newLayout = layout.dropSgLayoutAndData();
901 auto newLoadOp = xegpu::LoadGatherOp::create(
902 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
903 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
905 newLoadOps.push_back(newLoadOp);
907 rewriter.replaceOpWithMultiple(op, {newLoadOps});
914struct WgToSgStoreScatterOp
915 :
public OpConversionPattern<xegpu::StoreScatterOp> {
916 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
918 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter)
const override {
921 Location loc = op.getLoc();
922 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
926 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
928 if (!layout || !layout.isForWorkgroup())
932 auto offsetsVecType =
933 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
935 dyn_cast<VectorType>(adaptor.getMask().front().getType());
936 if (!offsetsVecType || !maskVecType ||
937 offsetsVecType.getShape() != maskVecType.getShape()) {
938 return rewriter.notifyMatchFailure(op,
939 "offsets have not been distributed");
942 auto chunkSizeOpt = op.getChunkSize();
943 int64_t chunkSize = chunkSizeOpt ?
static_cast<int64_t
>(*chunkSizeOpt) : 1;
944 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
945 for (
auto [val, offs, mask] : llvm::zip(
946 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
947 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
948 mask, chunkSizeAttr, op.getL1HintAttr(),
949 op.getL2HintAttr(), op.getL3HintAttr(),
950 layout.dropSgLayoutAndData());
952 rewriter.eraseOp(op);
957struct WgToSgLoadMatrixOp :
public OpConversionPattern<xegpu::LoadMatrixOp> {
958 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
960 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter)
const override {
963 SmallVector<SmallVector<OpFoldResult>> offsetsList;
964 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
967 ArrayRef<int64_t> wgShape = op.getDataShape();
968 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
969 assert(valueTy &&
"the value type must be vector type!");
970 Type elemTy = valueTy.getElementType();
972 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
973 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
974 VectorType newResTy = VectorType::get(sgShape, elemTy);
975 SmallVector<Value> newOps;
976 for (
auto offsets : offsetsList) {
977 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
978 op.getMemDesc(), offsets,
979 layout.dropSgLayoutAndData());
980 newOps.push_back(newOp);
982 rewriter.replaceOpWithMultiple(op, {newOps});
988struct WgToSgStoreMatrixOp :
public OpConversionPattern<xegpu::StoreMatrixOp> {
989 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
991 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter)
const override {
994 SmallVector<SmallVector<OpFoldResult>> offsetsList;
995 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
998 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
999 for (
auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1000 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1001 offsets, layout.dropSgLayoutAndData());
1002 rewriter.eraseOp(op);
1008struct WgToSgVectorStepOp :
public OpConversionPattern<vector::StepOp> {
1009 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1011 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1012 ConversionPatternRewriter &rewriter)
const override {
1013 xegpu::DistributeLayoutAttr layout =
1015 if (!layout || !layout.isForWorkgroup())
1018 Location loc = op.getLoc();
1019 VectorType type = op.getResult().getType();
1020 auto wgShape = type.getShape();
1021 std::optional<SmallVector<int64_t>> sgShape =
1022 getSgShapeAndCount(wgShape, layout).first;
1027 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1029 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1033 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1034 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1035 SmallVector<Value> newOps;
1036 for (
auto offsets : *sgOffsets) {
1039 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1041 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1042 newOps.push_back(finalSteps);
1045 rewriter.replaceOpWithMultiple(op, {newOps});
1051struct WgToSgVectorShapeCastOp
1052 :
public OpConversionPattern<vector::ShapeCastOp> {
1053 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1056 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter)
const override {
1059 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1063 ArrayRef<int64_t> wgShape = resultType.getShape();
1064 xegpu::DistributeLayoutAttr layout =
1066 if (!layout || !layout.isForWorkgroup())
1071 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1075 ArrayRef<int64_t> srcShape = srcType.getShape();
1077 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1078 SmallVector<int64_t> expandedUnitDims;
1080 xegpu::DistributeLayoutAttr sourceLayout =
1083 if (!sourceLayout.isSliceOf(layout))
1084 return rewriter.notifyMatchFailure(
1085 op,
"The ShapeCast op only expands dimensions, the input layout "
1086 "must be a slice of the result layout.");
1088 assert(layoutToDistribute.isEqualTo(
1089 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1090 "The sg_data for unit dimensions should be set as 1");
1093 SmallVector<int64_t> sgShape =
1094 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1095 VectorType newResultType =
1096 VectorType::get(sgShape, resultType.getElementType());
1098 SmallVector<Value> newShapeCastOps;
1099 for (
auto src : adaptor.getSource()) {
1100 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1101 newResultType, src);
1102 newShapeCastOps.push_back(newShapeCast.getResult());
1105 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1142struct WgToSgMultiDimReductionOp
1143 :
public OpConversionPattern<vector::MultiDimReductionOp> {
1144 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1147 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter)
const override {
1149 Location loc = op.getLoc();
1151 VectorType srcType = op.getSourceVectorType();
1152 Type resultTy = op.getResult().getType();
1153 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
1154 bool isScalarResult = !dstVecType;
1156 auto originalSrcShape = srcType.getShape();
1157 Type elemTy = srcType.getElementType();
1159 xegpu::DistributeLayoutAttr layout =
1161 if (!layout || !layout.isForWorkgroup())
1164 auto reductionDims = llvm::to_vector(op.getReductionDims());
1167 SmallVector<int64_t> sgLayout;
1168 SmallVector<int64_t> sgData;
1169 xegpu::DistributeLayoutAttr parentLayout;
1170 if (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1171 parentLayout = sliceAttr.getParent();
1172 sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
1173 sgData = parentLayout.getEffectiveSgDataAsInt();
1175 return rewriter.notifyMatchFailure(
1176 op,
"Reduction should have SliceAttr layout");
1179 SmallVector<Value> localReductions;
1180 auto sgSrcs = adaptor.getSource();
1181 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1182 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1183 sgSrcType.getShape().end());
1190 auto originalDstShape = dstVecType.getShape();
1191 SmallVector<int64_t> sgDstShape =
1192 getSgShapeAndCount(originalDstShape, layout).first;
1193 sgDstType = VectorType::get(sgDstShape, elemTy);
1198 for (
auto sgSrc : sgSrcs) {
1201 rewriter, loc, sgDstType, op.getKind());
1203 auto localReduce = vector::MultiDimReductionOp::create(
1204 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1206 localReductions.push_back(localReduce.getResult());
1210 SmallVector<int64_t> crossSgReductionDims;
1211 for (int64_t reductionDim : reductionDims) {
1212 bool needsCrossSubgroupReduction =
1213 (sgLayout[reductionDim] > 1) &&
1214 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1216 if (needsCrossSubgroupReduction) {
1217 crossSgReductionDims.push_back(reductionDim);
1222 if (crossSgReductionDims.empty()) {
1223 SmallVector<Value> results;
1224 for (
auto localResult : localReductions) {
1226 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1227 results.push_back(finalResult);
1229 rewriter.replaceOpWithMultiple(op, {results});
1234 auto slmStoreDataShape = sgSrcShape;
1235 for (int64_t dim : reductionDims)
1236 slmStoreDataShape[dim] = 1;
1237 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1238 SmallVector<Value> slmStoreData;
1239 for (
auto localResult : localReductions) {
1240 if (isScalarResult) {
1242 slmStoreData.push_back(vector::BroadcastOp::create(
1243 rewriter, loc, slmStoreDataType, localResult));
1245 slmStoreData.push_back(vector::ShapeCastOp::create(
1246 rewriter, loc, slmStoreDataType, localResult));
1250 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1251 originalSrcShape.end());
1252 SmallVector<int> slmSgData(sgData.begin(), sgData.end());
1253 SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
1254 for (
int dim : reductionDims) {
1255 slmShape[dim] = sgLayout[dim];
1258 xegpu::LayoutAttr slmStoreLayout =
1259 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1263 auto bytesPerElement = bitWidth / 8;
1265 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1266 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1268 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1271 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1274 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1275 rewriter.getIndexType(),
nullptr);
1277 auto slmStoreCoords =
1278 slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1279 if (
failed(slmStoreCoords))
1281 for (
auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
1282 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1283 xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
1288 gpu::BarrierOp::create(rewriter, loc);
1291 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1292 for (int64_t dim : reductionDims) {
1293 slmLoadDataShape[dim] = slmShape[dim];
1294 slmSgData[dim] = slmShape[dim];
1296 xegpu::LayoutAttr slmLoadLayout =
1297 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1298 auto slmLoadCoords =
1299 slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1300 if (
failed(slmLoadCoords))
1303 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1304 SmallVector<Value> slmLoadData;
1305 for (
auto coord : *slmLoadCoords) {
1306 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1307 slmLoadData.push_back(xegpu::LoadMatrixOp::create(
1308 rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
1315 rewriter, loc, sgDstType, op.getKind());
1317 SmallVector<Value> finalResults;
1318 for (
size_t i = 0; i < slmLoadData.size(); ++i) {
1319 auto loaded = slmLoadData[i];
1320 auto finalReduce = vector::MultiDimReductionOp::create(
1321 rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
1324 rewriter, loc, op.getKind(), finalReduce.getResult(),
1325 adaptor.getAcc()[i]));
1327 rewriter.replaceOpWithMultiple(op, {finalResults});
/// Conversion pattern for vector::TransposeOp. For every value in
/// adaptor.getVector() it creates a new vector::TransposeOp whose result type
/// uses the shape returned by getSgShapeAndCount(wgShape, layout).first, and
/// then replaces the original op with all of the new results.
// NOTE(review): several statements in this block are only partially visible
// (the initializers of `layout` and `sourceLayout`, and the statements that
// follow the two layout checks); confirm against the complete definition.
1333struct WgToSgVectorTransposeOp
1334 :
public OpConversionPattern<vector::TransposeOp> {
1335 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1338 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1339 ConversionPatternRewriter &rewriter)
const override {
1340 VectorType resultType = op.getResultVectorType();
// The workgroup-level shape is the shape of the op's result vector type.
1342 ArrayRef<int64_t> wgShape = resultType.getShape();
// Result layout: must be non-null and a workgroup-level layout.
1343 xegpu::DistributeLayoutAttr layout =
1345 if (!layout || !layout.isForWorkgroup())
// Source layout: same requirement as for the result layout.
1347 xegpu::DistributeLayoutAttr sourceLayout =
1349 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1352 SmallVector<int64_t> sourceSgLayout =
1353 sourceLayout.getEffectiveSgLayoutAsInt();
1354 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
// Both subgroup layouts must have exactly as many entries as the permutation.
1356 ArrayRef<int64_t> permutation = op.getPermutation();
1357 size_t permutationSize = permutation.size();
1358 if (sourceSgLayout.size() != permutationSize ||
1359 resultSgLayout.size() != permutationSize) {
1360 return rewriter.notifyMatchFailure(
1361 op,
"Layouts and permutation must have the same rank");
// Reject the match unless the result layout is the transpose of the source
// layout under `permutation`, checked with xegpu::LayoutKind::Subgroup.
1366 if (!layout.isTransposeOf(sourceLayout, permutation,
1367 xegpu::LayoutKind::Subgroup))
1368 return rewriter.notifyMatchFailure(
1369 op,
"Result layout is not a valid transpose of source layout "
1370 "according to permutation");
// New result type: shape from getSgShapeAndCount, element type unchanged.
1372 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1373 VectorType newResultType =
1374 VectorType::get(sgShape, resultType.getElementType());
// Emit one transpose per converted source value, reusing the permutation.
1376 SmallVector<Value> newTransposeOps;
1377 for (
auto src : adaptor.getVector()) {
1378 auto newTranspose = vector::TransposeOp::create(
1379 rewriter, op.getLoc(), newResultType, src, permutation);
1380 newTransposeOps.push_back(newTranspose.getResult());
// 1:N replacement: the single original result maps to all new results.
1382 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
/// Conversion pattern shared by the mask-creating vector ops, parameterized on
/// MaskOpType. For each coordinate set produced by
/// layout.computeDistributedCoords it builds a vector::CreateMaskOp whose
/// per-dimension size is derived from the original mask size and that
/// coordinate, then replaces the original op with all created masks.
// NOTE(review): the declarations of `sgId`, `sgOffsets`, `zero` and
// `dimSizeVal` are not visible in this block; confirm their definitions
// against the complete source.
1388template <
typename MaskOpType>
1389struct WgToSgVectorMaskOp :
public OpConversionPattern<MaskOpType> {
1390 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1392 LogicalResult matchAndRewrite(
1394 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1395 ConversionPatternRewriter &rewriter)
const override {
// The layout must be non-null and a workgroup-level layout.
1396 xegpu::DistributeLayoutAttr layout =
1398 if (!layout || !layout.isForWorkgroup())
1401 Location loc = op.getLoc();
1402 VectorType type = op.getResult().getType();
1403 auto wgShape = type.getShape();
// Collect the mask sizes as Values: ConstantMaskOp iterates over
// op.getMaskDimSizes(), CreateMaskOp takes its operands directly.
1405 SmallVector<Value> wgMaskDimSizes;
1406 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1407 for (int64_t maskSize : op.getMaskDimSizes()) {
1408 wgMaskDimSizes.push_back(
1411 }
else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1412 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1416 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1418 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
// Result type of each new mask: shape from getSgShapeAndCount, same element
// type as the original mask vector.
1422 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1423 VectorType resultType = VectorType::get(sgShape, type.getElementType());
// One CreateMaskOp is built per coordinate set in *sgOffsets.
1427 SmallVector<Value> newCreateMaskOps;
1428 for (
auto offsetSet : *sgOffsets) {
1429 SmallVector<Value> maskOperands;
1431 for (
auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
// Per dimension: subtract this coordinate from the original mask size, take
// the signed max with `zero`, then the signed min with `dimSizeVal`.
1434 Value offset = offsetSet[i];
1435 Value adjustedMaskSize =
1436 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1439 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1441 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1442 maskOperands.push_back(sgMaskSize);
// Both mask op kinds are lowered to vector::CreateMaskOp here.
1445 auto newCreateMaskOp =
1446 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1447 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
// 1:N replacement of the original mask op with all created masks.
1450 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
// Named instantiations of WgToSgVectorMaskOp, one for vector::ConstantMaskOp
// and one for vector::CreateMaskOp.
1455using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1456using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
/// Conversion pattern for vector::BitCastOp. It creates one vector::BitCastOp
/// per value in adaptor.getSource(), each typed with the shape returned by
/// getSgShapeAndCount(wgShape, layout).first and the original result element
/// type, and replaces the original op with the new results.
// NOTE(review): the initializer of `layout` and the declaration of
// `newBitCast` are only partially visible in this block.
1459struct WgToSgVectorBitCastOp :
public OpConversionPattern<vector::BitCastOp> {
1460 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1463 matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
1464 ConversionPatternRewriter &rewriter)
const override {
1465 VectorType resultType = op.getResultVectorType();
// The workgroup-level shape is the shape of the op's result vector type.
1467 ArrayRef<int64_t> wgShape = resultType.getShape();
// The layout must be non-null and a workgroup-level layout.
1468 xegpu::DistributeLayoutAttr layout =
1470 if (!layout || !layout.isForWorkgroup())
// New result type: shape from getSgShapeAndCount, element type unchanged.
1473 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1474 VectorType newResultType =
1475 VectorType::get(sgShape, resultType.getElementType());
// Emit one bitcast per converted source value.
1477 SmallVector<Value> newBitCastOps;
1478 for (
auto src : adaptor.getSource()) {
1480 vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
1481 newBitCastOps.push_back(newBitCast.getResult());
// 1:N replacement of the original op with all new bitcast results.
1484 rewriter.replaceOpWithMultiple(op, {newBitCastOps});
/// Conversion pattern for vector::InterleaveOp. It walks adaptor.getLhs() and
/// adaptor.getRhs() in lockstep, creates one vector::InterleaveOp per pair with
/// a result type built from getSgShapeAndCount(wgShape, layout).first, and
/// replaces the original op with the new results.
// NOTE(review): the initializer of `layout` is not visible in this block.
1490struct WgToSgVectorInterleaveOp
1491 :
public OpConversionPattern<vector::InterleaveOp> {
1492 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1495 matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1496 ConversionPatternRewriter &rewriter)
const override {
1497 VectorType resultType = op.getResultVectorType();
// The workgroup-level shape is the shape of the op's result vector type.
1499 ArrayRef<int64_t> wgShape = resultType.getShape();
// The layout must be non-null and a workgroup-level layout.
1500 xegpu::DistributeLayoutAttr layout =
1502 if (!layout || !layout.isForWorkgroup())
// New result type: shape from getSgShapeAndCount, element type unchanged.
1505 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1506 VectorType newResultType =
1507 VectorType::get(sgShape, resultType.getElementType());
// Pair up converted lhs/rhs values; llvm::zip iterates both ranges together.
1509 SmallVector<Value> newInterleaveOps;
1512 for (
auto [
lhs,
rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
1513 auto newInterleave = vector::InterleaveOp::create(
1514 rewriter, op.getLoc(), newResultType,
lhs,
rhs);
1515 newInterleaveOps.push_back(newInterleave.getResult());
// 1:N replacement of the original op with all new interleave results.
1518 rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
/// Conversion pattern for vector::DeinterleaveOp. It creates one
/// vector::DeinterleaveOp per value in adaptor.getSource() and collects the
/// getRes1() and getRes2() values into two separate lists, which replace the
/// two results of the original op.
// Unlike the neighbouring patterns, no layout check or explicit result type is
// visible in this block; the new op is built from the source value alone.
1524struct WgToSgVectorDeinterleaveOp
1525 :
public OpConversionPattern<vector::DeinterleaveOp> {
1526 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1529 matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1530 ConversionPatternRewriter &rewriter)
const override {
1531 SmallVector<Value> newRes1Ops;
1532 SmallVector<Value> newRes2Ops;
// Emit one deinterleave per converted source value and split its two results.
1534 for (
auto src : adaptor.getSource()) {
1535 auto newDeinterleave =
1536 vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
1537 newRes1Ops.push_back(newDeinterleave.getRes1());
1538 newRes2Ops.push_back(newDeinterleave.getRes2());
// One replacement list per original result: {res1 values, res2 values}.
1541 SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1542 rewriter.replaceOpWithMultiple(op, results);
// Registers every conversion pattern type listed below on `patterns` in a
// single add<...>() call.
// NOTE(review): the enclosing function header and the arguments passed to
// add<...>() are not visible in this block; confirm against the full source.
1552 patterns.
add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1553 WgToSgDpasMxOp, WgToSgPrefetchNdOp,
1554 UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
1555 WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1556 WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1557 WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1558 WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1559 WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1560 WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1561 WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
// Pass type for the workgroup-to-subgroup distribution; it overrides
// runOnOperation(), which is defined out of line.
// NOTE(review): the base class list of this struct is not visible here.
1568struct XeGPUWgToSgDistributePass
1570 void runOnOperation()
override;
// Pass entry point. The visible steps are:
//   1. record the UnrealizedConversionCastOps that already exist,
//   2. set up a type converter for RankedTensorType values carrying a
//      DistributeLayoutAttr encoding,
//   3. set up a second converter for xegpu::TensorDescType, a
//      ConversionTarget with per-op dynamic legality, and run
//      applyPartialConversion, signalling pass failure if it fails.
// NOTE(review): many statements in this function are only partially visible
// (guard conditions, scope braces, several declarations); the comments below
// describe only what the visible lines establish.
1574void XeGPUWgToSgDistributePass::runOnOperation() {
1576 Operation *op = getOperation();
// NOTE(review): the condition guarding this signalPassFailure() is not visible.
1578 signalPassFailure();
// Remember the cast ops present before conversion; they are the only
// UnrealizedConversionCastOps treated as legal further down.
1583 SmallVector<Operation *> existingCastOps;
1584 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1585 existingCastOps.push_back(castOp.getOperation());
// First converter: an identity conversion for any Type, plus a 1:N conversion
// for RankedTensorType.
1595 TypeConverter converter;
1596 converter.addConversion([&](Type type) -> Type {
return type; });
1597 converter.addConversion(
1598 [&](RankedTensorType type,
1599 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
// The tensor encoding is read as a DistributeLayoutAttr; std::nullopt is
// returned on a path whose condition is not visible here.
1603 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1604 type.getEncoding());
1606 return std::nullopt;
1608 Type elemTy = type.getElementType();
1609 ArrayRef<int64_t> shape = type.getShape();
// The tensor type becomes `count` copies of a VectorType with the shape from
// getSgShapeAndCount and the same element type.
1612 SmallVector<int64_t> subShape;
1613 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1615 auto newTy = VectorType::get(subShape, elemTy);
1616 result.append(count, newTy);
// Pattern set, conversion target and a second converter used for the
// partial conversion at the end of this function.
1627 RewritePatternSet patterns(ctx);
1628 ConversionTarget
target(*ctx);
1629 TypeConverter converter;
1630 converter.addConversion([&](Type type) -> Type {
return type; });
1631 converter.addConversion(
1632 [&](xegpu::TensorDescType type,
1633 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1634 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
// Descriptors without a workgroup-level layout are not handled by this
// conversion callback.
1637 if (!layout || !layout.isForWorkgroup())
1638 return std::nullopt;
1640 Type elemTy = type.getElementType();
1641 ArrayRef<int64_t> shape = type.getShape();
1644 SmallVector<int64_t> subShape;
1645 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
// The new descriptor type uses the layout returned by dropSgLayoutAndData()
// and keeps the original encoding; it is appended `count` times.
1647 layout = layout.dropSgLayoutAndData();
1649 auto newTy = xegpu::TensorDescType::get(
1650 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1651 result.append(count, newTy);
// Helper: returns the TensorDescType associated with the four nd-descriptor
// ops, or a default-constructed TensorDescType for any other op.
1655 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1656 if (
auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1657 return createOp.getType();
1658 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1659 return loadOp.getTensorDescType();
1660 if (
auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1661 return storeOp.getTensorDescType();
1662 if (
auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1663 return prefetchOp.getTensorDescType();
1664 return xegpu::TensorDescType();
// Shared legality rule: legal when there is no layout or the layout is not a
// workgroup-level layout.
1667 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) ->
bool {
1668 return !layout || !layout.isForWorkgroup();
// Nd-descriptor ops: legality is decided from the descriptor type's layout.
1671 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1672 xegpu::StoreNdOp, xegpu::PrefetchNdOp>(
1673 [=](Operation *op) ->
bool {
1674 auto tdescTy = getTensorDescType(op);
1676 dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1677 return isLegal(layout);
// DpasOp and DpasMxOp: legality is decided from getLayoutCdAttr().
1680 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) ->
bool {
1681 auto layout = op.getLayoutCdAttr();
1682 return isLegal(layout);
1685 target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
1686 [=](xegpu::DpasMxOp op) ->
bool {
1687 auto layout = op.getLayoutCdAttr();
1688 return isLegal(layout);
// LoadMatrixOp and StoreMatrixOp: legality is decided from getLayoutAttr().
1691 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1692 [=](xegpu::LoadMatrixOp op) ->
bool {
1693 return isLegal(op.getLayoutAttr());
1696 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1697 [=](xegpu::StoreMatrixOp op) ->
bool {
1698 return isLegal(op.getLayoutAttr());
// arith::ConstantOp: the callback inspects the op's type as a VectorType; how
// `layout` is obtained is not visible here.
1701 target.addDynamicallyLegalOp<arith::ConstantOp>(
1702 [=](arith::ConstantOp op) ->
bool {
1703 auto vecType = dyn_cast<VectorType>(op.getType());
1709 return isLegal(layout);
// Vector dialect ops covered by the patterns: all share one legality callback.
1712 target.addDynamicallyLegalOp<
1713 vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1714 vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
1715 vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp,
1716 vector::DeinterleaveOp>([=](Operation *op) ->
bool {
1720 return isLegal(layout);
// LoadGatherOp and StoreScatterOp: legality is decided from getLayoutAttr().
1723 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1724 [=](xegpu::LoadGatherOp op) ->
bool {
1725 auto layout = op.getLayoutAttr();
1726 return isLegal(layout);
1729 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1730 [=](xegpu::StoreScatterOp op) ->
bool {
1731 auto layout = op.getLayoutAttr();
1732 return isLegal(layout);
// ConvertLayoutOp: both the input and the target layout must pass isLegal.
1735 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1736 [=](xegpu::ConvertLayoutOp op) ->
bool {
1737 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
// Math and arith dialect ops: the callback returns std::optional<bool>. The
// visible lines compare each operand's vector shape with the result vector
// shape before the final isLegal(layout) check.
1740 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1741 [=](Operation *op) -> std::optional<bool> {
1746 VectorType resultType =
1754 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1755 if (!operandType || operandType.getShape() != resultType.getShape()) {
1760 xegpu::DistributeLayoutAttr layout =
1762 return isLegal(layout);
// Cast ops are legal only if they were recorded in existingCastOps above.
1765 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1766 [=](UnrealizedConversionCastOp op) {
1767 return llvm::is_contained(existingCastOps, op.getOperation());
// Every op without an explicit legality rule is treated as legal.
1770 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
// Run the partial conversion; the visible failure path signals pass failure.
1776 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1777 return signalPassFailure();
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
int64_t computeProduct(ArrayRef< int64_t > basis)
Return the product of all elements of basis.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
This represents an operation in an abstracted form, suitable for use with the builder APIs.