#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
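/// Returns the xegpu::RangeAttr attached as an "sg_id_range" attribute to an
/// enclosing operation of `op`, if any.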
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  // ... (walk the enclosing operations; `parent` is the op being inspected)
  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
          parent->getAttr("sg_id_range")))
    return attr;
  // ...
}
/// Returns the per-subgroup tile shape implied by `layout` for a value of
/// workgroup-level `shape`, together with the number of tiles each subgroup
/// owns (i.e. how many copies the 1:N conversion produces per subgroup).
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp the distribution unit to the original shape so that an
    // over-sized sg_layout does not index out of range.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}
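/// Computes, for each subgroup-local tile, the offsets at which the new
/// subgroup-level op accesses the underlying data: the layout's distributed
/// coordinates (derived from the subgroup id, optionally rebased by an
/// enclosing "sg_id_range") are added to the op's original offsets.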
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp, xegpu::LoadMatrixOp,
              xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  // Nothing to do for the offset-less form of the op.
  if (origOffsets.empty())
    return failure();
  // LoadMatrixOp/StoreMatrixOp carry the layout directly; the nd ops take it
  // from their tensor descriptor.
  xegpu::DistributeLayoutAttr layout;
  if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
                std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
    layout = op.getLayoutAttr();
  } else {
    layout = op.getDescLayoutAttr();
  }
  if (!layout || !layout.isForWorkgroup())
    return failure();

  // ...
  Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, nullptr);

  // Adjust the subgroup id if the op sits inside an "sg_id_range" region.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    // The subgroup count of the layout must match the range size.
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    // Rebase the subgroup id to the start of the range.
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  // ...
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Add the original op offsets to each set of distributed coordinates.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }
  return success();
}
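/// Rewrites a workgroup-level xegpu.create_nd_tdesc into one descriptor per
/// subgroup-local tile: the descriptor type is shrunk to the sg_data shape
/// (with sg_layout/sg_data dropped from the layout) and one op is emitted for
/// every offset set produced by genOffsetsList.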
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
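/// Same as WgToSgCreateNdOp but for the offset-less create_nd_tdesc form:
/// since no offsets are folded into the descriptor, it simply emits `count`
/// identical subgroup-sized descriptors.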
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only handles the offset-less form.
    if (!op.getMixedOffsets().empty())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps(count);
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};
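/// Distributes the offset-less xegpu.load_nd: one load per distributed tensor
/// descriptor, with the result vector shrunk to the descriptor's subgroup
/// shape.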
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, src,
          /*...remaining attributes carried over from the original op...*/);
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};
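/// Distributes the offset-less xegpu.store_nd: one store per (value,
/// tensor_desc) pair produced by earlier 1:N conversions.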
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
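/// The *WithOffset patterns below handle the forms of load_nd/store_nd/
/// prefetch_nd that carry offsets on the op itself: genOffsetsList produces
/// one offset vector per subgroup tile, and sg_layout/sg_data are dropped
/// from the layout forwarded to the new subgroup-level ops.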
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    SmallVector<Value> newOps;
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          nullptr, nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr(), layout);
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr(), layout);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};
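/// Distributes xegpu.dpas: every subgroup-local A tile is combined with every
/// subgroup-local B tile (plus the matching accumulator tile when one is
/// supplied), and the per-operand layouts are re-attached with
/// sg_layout/sg_data dropped.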
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto layoutCd = op.getLayoutCdAttr();
    auto layoutA = op.getLayoutAAttr();
    auto layoutB = op.getLayoutBAttr();
    if (!layoutCd || !layoutA || !layoutB)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        // Accumulator tile, if the op has one.
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
        newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
        newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
        newDpasOps.push_back(newDpasOp);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};
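/// Distributes the offset-less xegpu.prefetch_nd: one prefetch per
/// distributed tensor descriptor.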
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out if the op carries explicit offsets; those are handled by
    // WgToSgPrefetchNdOpWithOffset.
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), src,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr());
    rewriter.eraseOp(op);
    return success();
  }
};
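/// Distributes vector.broadcast when its result carries a workgroup-level
/// layout: the result shape is shrunk to the subgroup tile shape and one
/// broadcast is emitted per distributed operand.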
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    // Layout is looked up from the op's result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};
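/// Distributes elementwise math/arith ops on workgroup-shaped vectors: the
/// i-th subgroup-level op is rebuilt from the i-th distributed value of every
/// operand, with the original attributes preserved.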
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only elementwise-mappable ops with a single vector result are handled.
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return failure();
    // ...
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");

    ArrayRef<int64_t> wgShape = resultType.getShape();

    // Layout is looked up from the op's result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    // All operands must have been distributed into the same number of pieces.
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      state.addAttributes(op->getAttrs());
      Operation *newOp = rewriter.create(state);
      // ...
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};
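/// Distributes xegpu.convert_layout. If the target layout is compatible with
/// the input layout, the op is simply re-emitted (or forwarded) per subgroup
/// tile; otherwise the data is redistributed through shared local memory
/// (SLM): each subgroup stores its tile with xegpu.store_matrix, a barrier
/// synchronizes the workgroup, and each subgroup loads its new tile with
/// xegpu.load_matrix.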
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultType = op.getResult().getType();

    auto inputLayout = op.getInputLayout();
    auto targetLayout = op.getTargetLayout();
    if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
        !targetLayout.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    SmallVector<int64_t> inputSgLayout =
        inputLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> targetSgLayout =
        targetLayout.getEffectiveSgLayoutAsInt();
    // ...

    // When the layouts agree at the subgroup level, drop the subgroup fields
    // and, if a lane-level conversion remains, re-emit the op per tile.
    if (inputLayout.isCompatibleWith(targetLayout, /*...*/)) {
      inputLayout = inputLayout.dropSgLayoutAndData();
      targetLayout = targetLayout.dropSgLayoutAndData();

      SmallVector<Value> newOps(adaptor.getSource().begin(),
                                adaptor.getSource().end());
      if (inputLayout && targetLayout) {
        for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
          auto newOp = xegpu::ConvertLayoutOp::create(
              rewriter, loc, src.getType(), src, inputLayout, targetLayout);
          newOps[i] = newOp;
        }
      }
      rewriter.replaceOpWithMultiple(op, {newOps});
      return success();
    }

    // Otherwise redistribute through shared local memory (SLM).
    Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
    // ... (wgShape, slmShape and the SLM byte size are computed here)
    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
                                               elemTy, nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(), nullptr);

    // Each subgroup stores its tile at the coordinates implied by the input
    // layout.
    auto storeCoords = inputLayout.computeDistributedCoords(
        rewriter, loc, sgId.getResult(), wgShape);
    if (failed(storeCoords))
      return failure();

    for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
      SmallVector<OpFoldResult> storeMatrixOffsets;
      for (Value coord : coords)
        storeMatrixOffsets.push_back(coord);
      xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
                                   storeMatrixOffsets, nullptr);
    }
    gpu::BarrierOp::create(rewriter, loc);

    // Each subgroup then loads its new tile at the coordinates implied by the
    // target layout.
    auto loadCoords = targetLayout.computeDistributedCoords(
        rewriter, loc, sgId.getResult(), wgShape);
    // ...
    VectorType loadType = VectorType::get(targetSgData, elemTy);

    SmallVector<Value> finalResults;
    for (auto coords : *loadCoords) {
      SmallVector<OpFoldResult> loadMatrixOffsets;
      for (Value coord : coords)
        loadMatrixOffsets.push_back(coord);
      auto loadOp = xegpu::LoadMatrixOp::create(
          rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
          targetLayout.dropSgLayoutAndData());
      finalResults.push_back(loadOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {finalResults});
    return success();
  }
};
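/// Folds away the unrealized_conversion_cast ops that this pass's own 1:N
/// conversion materializes between distributed values and their original
/// workgroup-level types.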
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<
      mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();

    // Case 1: a "1:N" cast whose single workgroup operand has already been
    // remapped to values matching the result types; forward those values.
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      rewriter.replaceOp(op, inputs);
      return success();
    }

    // Case 2: an "N:1" cast producing a single workgroup-typed result;
    // forward the distributed inputs.
    if (op.getNumResults() == 1 /* && ... */) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }

    return mlir::failure();
  }
};
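/// Distributes arith.constant vectors that carry a workgroup-level layout.
/// Splat constants are re-created with the subgroup shape. Non-splat index
/// vectors are only supported when they form an affine sequence with constant
/// row/column strides: a base tile is materialized, and each subgroup's copy
/// is shifted by its distributed coordinates scaled by the strides.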
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    // Layout is looked up from the constant's result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    if (vecAttr.isSplat()) {
      // Splat: create a single constant of the subgroup shape.
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) {
      // Each subgroup holds the full constant; re-create it unchanged.
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      rewriter.replaceOp(op, newConstOp);
      return success();
    }

    // Non-splat constants are only supported for small index vectors that
    // form an affine sequence.
    if (!eltType.isIndex())
      return rewriter.notifyMatchFailure(
          op, "Unsupported element type for non-splat constant op.");

    if (wgShape.size() > 2)
      return rewriter.notifyMatchFailure(
          op, "Only 1D & 2D vector constant supported");

    SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
    int64_t rowStride = 0, colStride = 0;
    int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
    int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

    if (cols > 1)
      colStride = cast<IntegerAttr>(values[1]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();
    if (rows > 1)
      rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();

    // Verify that the strides are constant across the whole constant.
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        int64_t idx = r * cols + c;

        if (c > 0 && cols > 1) {
          int64_t prevIdx = r * cols + (c - 1);
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != colStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant column stride in constant op.");
        }

        if (r > 0 && rows > 1) {
          int64_t prevIdx = (r - 1) * cols + c;
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != rowStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant row stride in constant op.");
        }
      }
    }

    // Build the base tile from the top-left sgShape block of values.
    SmallVector<Attribute> baseTileValues;
    int baseTileCols = sgShape[sgShape.size() - 1];
    int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
    for (int64_t r = 0; r < baseTileRows; ++r) {
      for (int64_t c = 0; c < baseTileCols; ++c) {
        baseTileValues.push_back(values[r * cols + c]);
      }
    }

    auto tileAttr = DenseElementsAttr::get(newType, baseTileValues);
    auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

    Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    // Materialize the strides as index constants ([row, col] for 2-D,
    // [col] for 1-D).
    SmallVector<Value, 2> strideConsts;
    strideConsts.push_back(
        arith::ConstantIndexOp::create(rewriter, loc, colStride));
    if (rows > 1)
      strideConsts.insert(
          strideConsts.begin(),
          arith::ConstantIndexOp::create(rewriter, loc, rowStride));

    SmallVector<Value> newConstOps;
    for (auto offsets : *sgOffsets) {
      // Linear offset of this subgroup's tile: sum_i offsets[i] * strides[i].
      Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
      for (size_t i = 0; i < strideConsts.size(); ++i) {
        Value mul =
            arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                  offsets[i], strideConsts[i]);
        mulOffset = arith::AddIOp::create(
            rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
      }
      // Broadcast the scalar offset and add it to the base tile.
      auto bcastOffset = vector::BroadcastOp::create(
          rewriter, loc, baseConstVec.getType(), mulOffset);
      auto finalConst =
          arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
      newConstOps.push_back(finalConst);
    }
    rewriter.replaceOpWithMultiple(op, {newConstOps});
    return success();
  }
};
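/// Distributes xegpu gather-load and scatter-store with explicit offsets: the
/// offsets and mask operands are expected to have been distributed already,
/// so each subgroup-level op is rebuilt from matching (offsets, mask[, value])
/// tuples with sg_layout/sg_data dropped.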
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // The offsets and mask must already have been distributed to matching
    // per-subgroup shapes.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    SmallVector<Value> newLoadOps;
    auto chunkSizeAttr =
        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLayout = layout.dropSgLayoutAndData();
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          newLayout);
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    if (!valueType)
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // The offsets and mask must already have been distributed to matching
    // per-subgroup shapes.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
                                    mask, chunkSizeAttr, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr(),
                                    layout.dropSgLayoutAndData());
    }
    rewriter.eraseOp(op);
    return success();
  }
};
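/// Distributes xegpu.load_matrix / xegpu.store_matrix: genOffsetsList yields
/// one offset set per subgroup tile and the ops are re-emitted on the same
/// mem_desc with the subgroup shape and sg-stripped layout.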
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    ArrayRef<int64_t> wgShape = op.getDataShape();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");
    Type elemTy = valueTy.getElementType();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResTy = VectorType::get(sgShape, elemTy);
    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    rewriter.eraseOp(op);
    return success();
  }
};
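/// Distributes vector.step: a subgroup-sized vector.step is created once and
/// each subgroup's copy is shifted by the first distributed coordinate of its
/// tile.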
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Layout is looked up from the op's result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    // ...

    Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *sgOffsets) {
      // Shift the base step sequence by the subgroup's starting coordinate.
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      newOps.push_back(finalSteps);
    }

    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
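/// Distributes vector.shape_cast. The notable case is a cast that only
/// inserts unit dimensions and feeds vector.broadcast: the source layout must
/// then be a slice of the result layout, and the result layout's new unit
/// dims get unit sg_layout/sg_data before the cast is re-emitted per
/// distributed source tile.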
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    // Layout is looked up from the op's result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
    if (!srcType)
      return failure();
    ArrayRef<int64_t> srcShape = srcType.getShape();

    xegpu::DistributeLayoutAttr layoutToDistribute = layout;
    SmallVector<int64_t> expandedUnitDims;

    // Layout of the cast's source operand (helper call elided).
    xegpu::DistributeLayoutAttr sourceLayout = /* ... */;

    auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
      return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
        return isa<vector::BroadcastOp>(user);
      });
    };

    if (matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
      if (!usedByBroadcastOp(op))
        return rewriter.notifyMatchFailure(
            op, "ShapeCast ops that expand unit dimensions and are used by "
                "non-broadcast operations are not supported.");
      if (!sourceLayout.isSliceOf(layout))
        return rewriter.notifyMatchFailure(
            op, "The ShapeCast op only expands dimensions, the result layout "
                "must be a slice of the input layout, or vice versa.");
      layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
      layoutToDistribute =
          layoutToDistribute.setUnitDimLayout(expandedUnitDims);
    }

    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgShape, layoutToDistribute).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
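/// Returns a constant vector of `type` holding the neutral element of `kind`
/// (0 for add/xor/or, 1 for mul/and, int/float min-max extremes for the
/// min/max kinds), used as the identity accumulator for partial reductions.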
static Value createAccumulator(ConversionPatternRewriter &rewriter,
                               Location loc, VectorType type,
                               vector::CombiningKind kind) {
  Type elemTy = type.getElementType();
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
  case vector::CombiningKind::OR:
    return arith::ConstantOp::create(
        rewriter, loc, type,
        DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
  case vector::CombiningKind::MUL:
  case vector::CombiningKind::AND:
    return arith::ConstantOp::create(
        rewriter, loc, type,
        DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy)));
  case vector::CombiningKind::MINSI:
    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
      auto maxVal = APInt::getSignedMaxValue(intTy.getWidth());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getIntegerAttr(elemTy, maxVal)));
    }
    break;
  case vector::CombiningKind::MINUI:
    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
      auto maxVal = APInt::getMaxValue(intTy.getWidth());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getIntegerAttr(elemTy, maxVal)));
    }
    break;
  case vector::CombiningKind::MAXSI:
    if (auto intTy = dyn_cast<IntegerType>(elemTy)) {
      auto minVal = APInt::getSignedMinValue(intTy.getWidth());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type,
                                 rewriter.getIntegerAttr(elemTy, minVal)));
    }
    break;
  case vector::CombiningKind::MAXUI:
    return arith::ConstantOp::create(
        rewriter, loc, type,
        DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy)));
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
      auto posInf = APFloat::getInf(floatTy.getFloatSemantics());
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf)));
    }
    break;
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
    if (auto floatTy = dyn_cast<FloatType>(elemTy)) {
      auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true);
      return arith::ConstantOp::create(
          rewriter, loc, type,
          DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf)));
    }
    break;
  default:
    break;
  }
  // ... (fallback for unsupported kinds elided)
}
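/// Linearizes the delinearized subgroup indices `sgIds` over a chosen subset
/// of dimensions of `sgLayout`, producing a single index value used to
/// address rows/columns of the SLM staging buffer.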
static Value linearizeSubgroupIndices(ConversionPatternRewriter &rewriter,
                                      Location loc, ArrayRef<Value> sgIds,
                                      ArrayRef<int64_t> dims,
                                      ArrayRef<int64_t> sgLayout) {
  Value linearizedOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
  int64_t stride = 1;
  for (int64_t dim : dims) {
    Value dimVal = sgIds[dim];
    Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
    Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
    linearizedOffset =
        arith::AddIOp::create(rewriter, loc, linearizedOffset, term);
    stride *= sgLayout[dim];
  }
  return linearizedOffset;
}
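/// Distributes vector.multi_reduction. Each subgroup first reduces its own
/// tile against a neutral accumulator. If any reduction dimension is split
/// across subgroups, the partial results are staged in SLM (store_matrix,
/// barrier, load_matrix) so every subgroup can reduce across the
/// participating subgroups before folding in the original accumulator.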
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto originalSrcShape = srcType.getShape();
    // Layout is looked up from the reduction source (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    SmallVector<int64_t> sgLayout;
    SmallVector<int64_t> sgData;
    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
    } else {
      return rewriter.notifyMatchFailure(
          op, "Reduction should have SliceAttr layout");
    }
    Type elemTy = dstType.getElementType();

    // Step 1: every subgroup reduces its own tile against a neutral
    // accumulator.
    SmallVector<Value> localReductions;
    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(originalSrcShape, layout).first;
    VectorType newDstType = VectorType::get(sgShape, elemTy);
    for (auto sgSrc : adaptor.getSource()) {
      auto neutralLocalAcc =
          createAccumulator(rewriter, loc, newDstType, op.getKind());
      auto localReduce = vector::MultiDimReductionOp::create(
          rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc,
          op.getReductionDims());
      localReductions.push_back(localReduce.getResult());
    }

    // Step 2: determine which reduction dimensions are split across
    // subgroups and therefore need a cross-subgroup reduction.
    SmallVector<int64_t> crossSgReductionDims;
    for (int64_t reductionDim : reductionDims) {
      bool needsCrossSubgroupReduction =
          (sgLayout[reductionDim] > 1) &&
          (sgData[reductionDim] < originalSrcShape[reductionDim]);
      if (needsCrossSubgroupReduction) {
        crossSgReductionDims.push_back(reductionDim);
      }
    }

    if (crossSgReductionDims.empty()) {
      // No cross-subgroup work: just fold in the original accumulator.
      SmallVector<Value> results;
      for (auto localResult : localReductions) {
        Value finalResult = vector::makeArithReduction(
            rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
        results.push_back(finalResult);
      }
      rewriter.replaceOpWithMultiple(op, {results});
      return success();
    }

    // Step 3: stage the partial results in SLM as a
    // (totalReductionSubgroups x totalResultElements) matrix.
    // ... (localElements: element count of each partial result)
    SmallVector<int64_t> storeShape2D = {1, localElements};
    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
    auto storeShapeCast = vector::ShapeCastOp::create(
        rewriter, loc, storeType2D, localReductions[0]);
    Value storeData = storeShapeCast.getResult();

    int64_t totalReductionSubgroups = 1;
    for (int64_t dim : crossSgReductionDims) {
      totalReductionSubgroups *= sgLayout[dim];
    }

    int64_t totalResultElements =
        localElements * computeProduct(sgLayout) / totalReductionSubgroups;

    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
                                       totalResultElements};

    auto bitWidth = elemTy.getIntOrFloatBitWidth();
    auto bytesPerElement = bitWidth / 8;
    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
    auto slmSize = slmElements * bytesPerElement;
    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
                                               slmShape2D, elemTy, nullptr);
    auto memDesc =
        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

    // Step 4: compute each subgroup's row/column in the staging matrix.
    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
                                          rewriter.getIndexType(), nullptr);

    SmallVector<Value> sgLayoutValues;
    for (int64_t dim : sgLayout)
      sgLayoutValues.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, dim));

    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
                                                sgLayoutValues);
    // ...
    SmallVector<Value> sgIds = *sgIdsResult;

    // Row: position of this subgroup among the reducing subgroups.
    Value rowOffsetStore = linearizeSubgroupIndices(
        rewriter, loc, sgIds, crossSgReductionDims, sgLayout);

    // Column: position among the non-reduction dimensions, scaled by the
    // number of elements each subgroup contributes.
    SmallVector<int64_t> nonReductionDims;
    for (size_t i = 0; i < sgLayout.size(); ++i) {
      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i))) {
        nonReductionDims.push_back(static_cast<int64_t>(i));
      }
    }
    Value colOffset = linearizeSubgroupIndices(rewriter, loc, sgIds,
                                               nonReductionDims, sgLayout);
    Value localElementsVal =
        arith::ConstantIndexOp::create(rewriter, loc, localElements);
    colOffset =
        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);

    SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
                                 storeOffsets2D, nullptr);
    gpu::BarrierOp::create(rewriter, loc);

    // Step 5: each subgroup loads the full column of partial results and
    // reduces across the participating subgroups (dimension 0).
    SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
    VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
    // ... (rowOffsetLoad elided)
    SmallVector<OpFoldResult> loadOffsets2D = {rowOffsetLoad, colOffset};

    auto loadOp = xegpu::LoadMatrixOp::create(
        rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
        /*...*/);

    SmallVector<int64_t> finalReductionDims = {0};
    SmallVector<int64_t> finalResultShape = {localElements};
    VectorType finalResultType = VectorType::get(finalResultShape, elemTy);

    auto neutralFinalAcc =
        createAccumulator(rewriter, loc, finalResultType, op.getKind());

    auto finalReduce = vector::MultiDimReductionOp::create(
        rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
        neutralFinalAcc, finalReductionDims);

    // Step 6: fold in the original accumulator, reshaping it if its shape
    // differs from the cross-subgroup result.
    Value originalAcc = adaptor.getAcc()[0];
    Value accToAdd = originalAcc;

    if (originalAcc.getType() != finalReduce.getResult().getType()) {
      auto originalAccType = cast<VectorType>(originalAcc.getType());
      auto finalResultType =
          cast<VectorType>(finalReduce.getResult().getType());

      if (originalAccType.getNumElements() ==
          finalResultType.getNumElements()) {
        auto shapeCast = vector::ShapeCastOp::create(
            rewriter, loc, finalResultType, originalAcc);
        accToAdd = shapeCast.getResult();
      }
    }

    Value finalResult = vector::makeArithReduction(
        rewriter, loc, op.getKind(), finalReduce.getResult(), accToAdd);

    rewriter.replaceOp(op, finalResult);
    return success();
  }
};
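/// Distributes vector.transpose: requires the result layout to be the
/// transpose (under the op's permutation) of the source layout, then emits
/// one subgroup-level transpose per distributed source tile.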
struct WgToSgVectorTransposeOp
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    // Layout of the transpose result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // Layout of the transpose source (helper call elided).
    xegpu::DistributeLayoutAttr sourceLayout = /* ... */;
    if (!sourceLayout || !sourceLayout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sourceSgLayout =
        sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
    // ... (sourceOrder / resultOrder: the order attributes of the two layouts)
    if (!sourceOrder || !resultOrder) {
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");
    }

    ArrayRef<int64_t> permutation = op.getPermutation();
    size_t permutationSize = permutation.size();
    if (sourceSgLayout.size() != permutationSize ||
        resultSgLayout.size() != permutationSize) {
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");
    }

    // The result layout must be the source layout transposed by the op's
    // permutation.
    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    SmallVector<Value> newTransposeOps;
    for (auto src : adaptor.getVector()) {
      auto newTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), newResultType, src, permutation);
      newTransposeOps.push_back(newTranspose.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
    return success();
  }
};
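/// Distributes vector.constant_mask / vector.create_mask: the workgroup-level
/// mask sizes are turned into index values, and for each subgroup tile the
/// local mask size along dim i is clamp(wgMaskSize[i] - tileOffset[i], 0,
/// sgShape[i]), emitted as a vector.create_mask of the subgroup shape.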
template <typename MaskOpType>
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
  using OpConversionPattern<MaskOpType>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MaskOpType op,
      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Layout is looked up from the op's result (helper call elided).
    xegpu::DistributeLayoutAttr layout = /* ... */;
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    // Collect the workgroup-level mask sizes as index values.
    SmallVector<Value> wgMaskDimSizes;
    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
      for (int64_t maskSize : op.getMaskDimSizes()) {
        wgMaskDimSizes.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
      }
    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
      wgMaskDimSizes = llvm::to_vector(op.getOperands());
    }

    Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    SmallVector<Value> newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector<Value> maskOperands;

      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};

using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
          patterns.getContext());
}
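/// The pass below wires these patterns into a dialect conversion. As a rough
/// illustration (types and layout values are made up for the example), a
/// workgroup-level
///   %t = xegpu.create_nd_tdesc %A[0, 0] : memref<128x128xf16>
///        -> !xegpu.tensor_desc<128x128xf16,
///             #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 32]>>
/// is rewritten into one 32x32 descriptor per subgroup whose offsets are
/// derived from gpu.subgroup_id, with sg_layout/sg_data removed from the
/// result type.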
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};

void XeGPUWgToSgDistributePass::runOnOperation() {
  Operation *op = getOperation();
  // Pre-processing (elided); if it fails the pass is aborted.
  if (/* ... */)
    signalPassFailure();
  // Remember the unrealized_conversion_cast ops that already exist so that
  // only the casts introduced by this pass are folded away later.
  SmallVector<Operation *> existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  {
    // Step 1: SCF structural type conversion using a RankedTensorType
    // stand-in, so region-carrying ops (scf.for, scf.if, ...) get 1:N
    // converted types.
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
          Type elemTy = type.getElementType();
          ArrayRef<int64_t> shape = type.getShape();

          int count;
          SmallVector<int64_t> subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape,
              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });

    xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
  }
  // Step 2: the actual workgroup-to-subgroup conversion.
  MLIRContext *ctx = &getContext();
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        int count;
        SmallVector<int64_t> subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  // An op is legal (i.e. needs no rewriting) when it carries no
  // workgroup-level layout.
  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });
  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = op.getLayoutCdAttr();
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        if (!vecType)
          return true;
        // ... (layout is read off the constant's result)
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                               vector::TransposeOp, vector::BroadcastOp,
                               vector::MultiDimReductionOp,
                               vector::ConstantMaskOp, vector::CreateMaskOp>(
      [=](Operation *op) -> bool {
        // ... (layout is read off the op's vector result)
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        auto layout = op.getLayoutAttr();
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only elementwise ops with a vector result are considered.
        // ...
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // All operands must be vectors matching the result shape.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        // ... (layout is read off the op's result)
        xegpu::DistributeLayoutAttr layout = /* ... */;
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  RewritePatternSet patterns(ctx);
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    return signalPassFailure();
  // Finally, drop any temporary DistributeLayoutAttr left as discardable
  // attributes on region-branching ops (scf.for, scf.if, ...).
  getOperation()->walk([](Operation *op) {
    if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
      return;

    SmallVector<StringAttr> attrsToRemove;
    for (auto namedAttr : op->getDiscardableAttrs())
      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
        attrsToRemove.push_back(namedAttr.getName());
    for (auto attrName : attrsToRemove)
      op->removeDiscardableAttr(attrName);
  });
}