#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51 "resolve_simt_type_mismatch";
static constexpr unsigned regularPatternBenefit = 1;
static constexpr unsigned highPatternBenefit = 2;
/// Returns the distributed vector type for the given original vector type
/// according to the lane layout.
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();
  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
         "Expecting a valid layout.");
  SmallVector<int64_t> effectiveLaneLayout =
      layout.getEffectiveLaneLayoutAsInt();
  assert(static_cast<size_t>(originalType.getRank()) >=
             effectiveLaneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  // Only the trailing dimensions covered by the lane layout are distributed;
  // leading dimensions are kept as-is.
  unsigned distributionStart =
      originalType.getRank() - effectiveLaneLayout.size();
  SmallVector<int64_t> distributedShape =
      llvm::to_vector(originalType.getShape());
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // The dimension must be evenly divisible by its lane layout entry.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
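// For example, assuming a 16-lane subgroup: lane_layout [1, 16] distributes
// vector<16x16xf32> to vector<16x1xf32>, i.e. each lane keeps all 16 rows but
// owns a single column. A dimension that is not divisible by its lane_layout
// entry (say 20 % 16) makes the distribution fail.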
/// Helper to reconcile a value whose distributed type differs from the type
/// expected by its SIMT-level user.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If the types are the same, return the original value.
  if (orig.getType() == expected)
    return orig;
  // If the types are vector types, reconcile them with a shape cast.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If the types are tensor descriptor types, reconcile them with an
  // unrealized conversion cast that is resolved at the end of the pass.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
}
/// Helper to check if the layout requires the packed flag on loads.
static bool requirePacked(const xegpu::LayoutAttr layout) {
  if (!layout)
    return false;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return false;
  return laneData[0] != 1;
}
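// Example (assuming the usual 2D layouts): lane_data [2, 1] gives each lane
// two consecutive elements along the outer dimension, so the lane-level load
// must be packed; the default lane_data [1, 1] needs no packing.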
/// Helper to check if the layout requires a transpose effect on loads.
static bool requireTranspose(const xegpu::LayoutAttr layout,
                             const std::string &chipStr) {
  // Return false for unsupported targets (only pvc and bmg are handled).
  if (chipStr != "pvc" && chipStr != "bmg")
    return false;
  if (!layout)
    return false;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] == xegpu::targetinfo::subgroupSize &&
         laneLayout[1] == 1;
}
/// Pattern that moves the body of a gpu.func op into a new
/// gpu.warp_execute_on_lane_0 op, as a preparation step for distribution.
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function already contains a warp op, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as
    // the original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(
        rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
        /*upperBound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
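// Resulting structure (a sketch, with the subgroup size shown as 16):
//   gpu.func @foo(%arg0: ...) -> ... {
//     %lane_id = gpu.lane_id
//     %r = gpu.warp_execute_on_lane_0(%lane_id)[16] args(%arg0 : ...) {
//       ... original function body ...
//       gpu.yield %result
//     }
//     gpu.return %r
//   }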
/// Distribute a create_nd_tdesc op feeding a warp op result.
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, descOp->getOperands(), descOp.getOperandTypes(),
        newRetIndices);
    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    // The distributed tensor descriptor type carries no layout info.
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type against the type expected by the warp op.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter));
    return success();
  }
};
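// Sketch of the rewrite: the descriptor's operands are routed through new
// warp op results and the op is recreated after the warp op with its layout
// dropped, e.g.
//   !xegpu.tensor_desc<4x8xf32, #layout> becomes !xegpu.tensor_desc<4x8xf32>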
/// Distribute a store_nd op that is the last op in the warp region.
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    // The store op must have offsets.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new store op outside the warp op with the distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;
    // The value operand is distributed by the warp op, but the SIMT-level
    // store may expect a different (1D) type; resolve it.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor drops its layout in SIMT mode.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect the offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));

    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    xegpu::removeLayoutAttrs(newStoreOp);
    rewriter.eraseOp(storeOp);
    return success();
  }
};
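// Sketch of the rewrite (assuming 16 lanes): the warp op additionally yields
// the distributed payload, the tensor descriptor, and the scalar offsets;
// after it, each lane issues xegpu.store_nd on its own fragment, e.g. a
// vector<16x16xf16> payload is stored as a vector<16xf16> fragment per lane.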
/// Distribute a load_nd op feeding a warp op result.
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // The load op must be the last op in the warp region, so that it is
      // not sunk across any barrier synchronization.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");
    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // Chip information is required to decide the packed and transpose
    // attributes of the lane-level load.
    std::optional<std::string> chipStr = xegpu::getChipStr(loadOp);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          loadOp,
          "xegpu::LoadNdOp requires chip information to determine transpose "
          "and packed attributes");
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    // The load op must have offsets.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (requireTranspose(layout, chipStr.value()))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // The type distributed by the warp op may differ from the SIMT load
    // result type (e.g. rank difference); resolve with a shape cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
/// Distribute a dpas op feeding a warp op result.
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::Dpas op");
    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    SmallVector<Value> newYieldValues = {dpasOp.getLhs(), dpasOp.getRhs()};
    SmallVector<Type> newYieldTypes = {distLhsTypeByWarpOpOrFailure.value(),
                                       distRhsTypeByWarpOpOrFailure.value()};
    // The dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");

    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Resolve the distributed types against the expected SIMT-level types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
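// Sketch of the resulting IR (assuming 16 lanes and C += A x B):
//   %r:3 = gpu.warp_execute_on_lane_0(%laneid)[16] ->
//       (vector<8x1xf16>, vector<16x1xf16>, vector<8x1xf32>) {
//     gpu.yield %a, %b, %acc : ...
//   }
//   %a0 = vector.shape_cast %r#0 : vector<8x1xf16> to vector<8xf16>
//   %b0 = vector.shape_cast %r#1 : vector<16x1xf16> to vector<16xf16>
//   %c0 = vector.shape_cast %r#2 : vector<8x1xf32> to vector<8xf32>
//   %d = xegpu.dpas %a0, %b0, %c0 : ... -> vector<8xf32>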
/// Distribute a prefetch_nd op that is the last op in the warp region.
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    // The prefetch op must have offsets.
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                newPrefetchOperands, prefetchOp->getAttrs());
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Hoist a gpu.barrier that is the last op in the warp region.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    // The last node must be a gpu::BarrierOp.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Simply move the barrier op outside of the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
/// Distribute a scattered store op that is the last op in the warp region.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Expected 1D offsets and mask vector");
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
    if (storeVecTy.getRank() > 2)
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Expected at most 2D result at SG level");

    std::string layoutPayloadName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
    std::string layoutOffsetsName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
    std::string layoutMaskName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));

    xegpu::LayoutAttr layoutPayload =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
    xegpu::LayoutAttr layoutOffsets =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    // Distributed store payload type according to the lane layout.
    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
    // The expected distributed payload type is always flattened to 1D.
    VectorType expectedPayloadTy = VectorType::get(
        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands =
        llvm::to_vector(storeScatterOp->getOperands());
    SmallVector<Type> operandTypesToYield = {
        expectedPayloadTy, operands[1].getType(),
        distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
/// Distribute a scattered load op feeding a warp op result.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // The load op must be the last op in the warp region, so that it is
      // not sunk across any barrier synchronization.
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");

    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          loadGatherOp, "Expected 1D offsets and mask vector");

    std::string layoutOffsetsName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
    std::string layoutMaskName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));

    xegpu::LayoutAttr layoutOffsets =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<Value> operands = llvm::to_vector(loadGatherOp->getOperands());
    SmallVector<Type> operandTypesToYield = {
        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType loadVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
    return success();
  }
};
/// Lower a 2D multi-reduction to a sequence of 1D vector.reduction ops over
/// slices of the source.
static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                     TypedValue<VectorType> acc,
                                     vector::CombiningKind kind,
                                     int64_t reductionDim, Location loc,
                                     PatternRewriter &rewriter) {
  // Expecting a 2D source vector.
  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
  VectorType sourceType = src.getType();
  int64_t sourceH = sourceType.getShape()[0];
  int64_t sourceW = sourceType.getShape()[1];
  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
  // Create a zero-initialized vector to hold the reduction results.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value reductionResult = arith::ConstantOp::create(
      rewriter, loc, acc.getType(),
      DenseElementsAttr::get(acc.getType(), zeroAttr));
  // For each slice of the source, extract the slice vector, reduce it, and
  // insert the reduced value back into the result vector.
  for (int i = 0; i < nSlices; ++i) {
    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
    if (reductionDim == 1) {
      sliceOffsets = {i, 0};
      sliceSizes = {1, sourceW};
    } else {
      sliceOffsets = {0, i};
      sliceSizes = {sourceH, 1};
    }
    vector::ExtractStridedSliceOp extractOp =
        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                              sliceSizes, {1, 1});
    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
    Value slice = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get({nSliceElements}, sourceType.getElementType()),
        extractOp.getResult());
    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
    Value reduction =
        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
    reductionResult =
        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
  }
  return reductionResult;
}
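// Worked example: reducing vector<2x4xf32> along dim 1 with a vector<2xf32>
// accumulator extracts two 1x4 slices, shape-casts each to vector<4xf32>,
// reduces it together with the matching accumulator element, and inserts the
// resulting scalar into elements 0 and 1 of the vector<2xf32> result.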
/// Distribute a vector.multi_reduction op feeding a warp op result.
struct VectorMultiReductionDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D vectors are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    // Only a single reduction dimension is supported.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(reductionOp.getSource());

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // The source must be distributed in exactly one dimension.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // If the reduction happens along the non-distributed dimension, each lane
    // owns all the data it needs (lane-local reduction) and the result is
    // distributed. Otherwise the reduction crosses lanes and the result is
    // broadcasted to all lanes.
    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");
    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Lane-local case: do the reduction outside the warp op on the
    // lane-owned fragment.
    if (isReductionLaneLocal) {
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value result = lowerToVectorReductions(
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
      return success();
    }
    // Non-lane-local case: lower the reduction inside the warp op; the
    // cross-lane combine is handled later by the vector distribution patterns.
    rewriter.setInsertionPoint(reductionOp);
    Value result = lowerToVectorReductions(
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    rewriter.replaceOp(reductionOp, result);
    return success();
  }
};
/// Distribute a vector.shape_cast op feeding a warp op result.
struct VectorShapeCastDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yieldOperand)
      return failure();
    auto shapeCastOp =
        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");

    // For rank-changing shape casts, the layout of the lower-rank side must
    // be a slice of the higher-rank side's layout.
    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank increasing but source layout is not a "
                  "slice of result layout");
    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank reducing but result layout is not a "
                  "slice of source layout");

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Yield the shape cast source out of the warp op.
    SmallVector<size_t> newRetIndices;
    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new shape cast op outside the warp op.
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
/// Distribute a memref.extract_aligned_pointer_as_index op feeding a warp op
/// result.
struct MemrefExtractAlignedPointerAsIndexDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
    return success();
  }
};
/// Distribute a vector.bitcast op feeding a warp op result.
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedSourceType =
        getDistVecTypeBasedOnLaneLayout(
            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
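// Example (a sketch): with lane_layout [1, 16], a vector.bitcast from
// vector<16x64xi8> to vector<16x16xi32> distributes to a per-lane bitcast
// from vector<16x4xi8> to vector<16x1xi32> (4 x i8 == 1 x i32 per lane).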
/// Distribute a vector.transpose op feeding a warp op result.
struct VectorTransposeDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    if (!resultLayout.isTransposeOf(sourceLayout, perm))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
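// Example (a sketch): with source lane_layout [1, 16] and result lane_layout
// [16, 1], transposing vector<8x16xf32> to vector<16x8xf32> distributes to a
// per-lane transpose of vector<8x1xf32> to vector<1x8xf32>.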
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               GpuBarrierDistribution, VectorMultiReductionDistribution,
               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
               VectorBitcastDistribution,
               MemrefExtractAlignedPointerAsIndexDistribution>(
      patterns.getContext(), regularPatternBenefit);
  // Register shape_cast distribution at a higher benefit so it fires before
  // the generic vector distribution patterns.
  patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
                                            highPatternBenefit);
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Attach layouts to op operands (vector-typed operands only).
  getOperation()->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      if (!isa<VectorType>(operand.get().getType()))
        continue;
      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
      if (!layout) {
        op->emitError("Could not find layout attribute for operand ")
            << operand.getOperandNumber() << " of operation " << op->getName();
        signalPassFailure();
        return;
      }
      xegpu::setDistributeLayoutAttr(operand, layout);
    }
  });
  // Step 2: Move the function body inside gpu.warp_execute_on_lane_0.
  {
    RewritePatternSet patterns(&getContext());
    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }
  // Hoist any scalar uniform code outside of the warp op.
  getOperation()->walk([&](Operation *op) {
    if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
      vector::moveScalarUniformCode(warpOp);
  });
  // Step 3: Apply subgroup-to-workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn maps a vector value to the dims distributed across lanes.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    // Without a layout, assume the innermost dimension is distributed.
    if (!layout)
      return AffineMap::getMultiDimMapWithTargets(
          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
    SmallVector<unsigned int> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // TODO: shuffleFn is currently unused.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                      Value srcIdx, int64_t warpSz) { return Value(); };
  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
                          vector::CombiningKind kind, uint32_t size) {
    // Reduce the per-lane vector to a scalar first.
    Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
    // Then combine scalars across lanes with butterfly (XOR) shuffles.
    for (uint64_t i = 1; i < size; i <<= 1) {
      Value shuffled = builder
                           .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                   /*width=*/size,
                                                   gpu::ShuffleMode::XOR)
                           .getShuffleResult();
      laneVal =
          vector::makeArithReduction(builder, loc, kind, laneVal, shuffled);
    }
    return laneVal;
  };

  if (enableSGReductions)
    vector::populateDistributeReduction(patterns, warpReduction,
                                        regularPatternBenefit);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn, regularPatternBenefit);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  // Any non-trivially-dead warp op surviving to this point means the
  // distribution did not fully resolve; treat it as a pass failure.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    if (isOpTriviallyDead(warpOp))
      return WalkResult::advance();
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp) {
    signalPassFailure();
    return;
  }
  // Resolve the unrealized conversion casts introduced by
  // resolveDistributedTy for tensor descriptor type mismatches.
  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // Only consider casts tagged by this pass.
    if (!op->hasAttr(resolveSIMTTypeMismatch))
      return;
    Value input = op.getOperand(0);
    Value output = op.getResult(0);
    // Both sides of the cast must be tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");

    // tensor_desc<..., layout> to tensor_desc<...>: the input must be a block
    // argument; propagate the layout-less type to it and to the tied loop
    // result if the argument belongs to a loop-like op.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }
    // tensor_desc<...> to tensor_desc<..., layout>: forward the input.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);

    if (op->use_empty())
      op->erase();
  });
}