#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51 "resolve_simt_type_mismatch";
static constexpr unsigned regularPatternBenefit = 1;
static constexpr unsigned highPatternBenefit = 2;
/// Compute the distributed vector type for a value that is distributed
/// across subgroup lanes according to the given layout. Only the trailing
/// dimensions (matching the lane-layout rank) are divided by the lane
/// counts; any leading dimensions are preserved.
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();
  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
         "Expecting a valid layout.");
  SmallVector<int64_t> effectiveLaneLayout =
      layout.getEffectiveLaneLayoutAsInt();
  assert(static_cast<size_t>(originalType.getRank()) >=
             effectiveLaneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  // Distribute the innermost dimensions according to the lane layout.
  unsigned distributionStart =
      originalType.getRank() - effectiveLaneLayout.size();
  SmallVector<int64_t> distributedShape(originalType.getShape());
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // The dimension must be evenly divisible by the number of lanes in it.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
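
// Worked example (illustrative, added for exposition): with
// lane_layout = [1, 16] and originalType = vector<8x16xf32>,
// distributionStart = 0 and the distributed type is
// vector<(8/1)x(16/16)xf32> = vector<8x1xf32>. A leading (batch) dimension
// is kept intact: lane_layout = [16] on vector<4x32xf16> yields
// vector<4x2xf16>.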
/// Helper to reconcile a value whose distributed type (as produced by
/// gpu.warp_execute_on_lane_0) differs from the type expected by the xegpu
/// SIMT-level op.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If the types already match, nothing to do.
  if (orig.getType() == expected)
    return orig;
  // For vector types, reconcile via a shape cast.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // For tensor descriptor types, insert an unrealized conversion cast marked
  // with resolveSIMTTypeMismatch; it is cleaned up at the end of the pass.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
/// Check if a distributed load requires the "packed" attribute, i.e. the
/// lane data indicates that each lane owns multiple consecutive elements.
static bool requirePacked(const xegpu::LayoutAttr layout) {
  if (!layout)
    return false;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return false;
  return laneData[0] != 1;
}
/// Check if a distributed load requires a transpose effect on the given
/// target.
static bool requireTranspose(const xegpu::LayoutAttr layout,
                             const std::string &chipStr) {
  // Return false for unsupported targets.
  if (chipStr != "pvc" && chipStr != "bmg")
    return false;
  if (!layout)
    return false;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1;
}
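
// Illustrative examples (added for exposition): lane_data = [2, 1] means
// each lane owns two consecutive rows per register slot, so requirePacked
// returns true. A column-major lane_layout = [16, 1] (one full row per lane)
// with a 16-wide subgroup makes requireTranspose return true on pvc/bmg,
// since the 2D block load must then apply its transpose effect.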
/// This pattern moves the body of a gpu.func op into a new
/// gpu.warp_execute_on_lane_0 op, so that upstream vector distribution
/// patterns can be applied.
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the body was already moved inside a warp_execute_on_lane_0, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as
    // the original gpu func.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(
        rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
        /*upper_bound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the WarpExecuteOnLane0Op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
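
// Illustrative rewrite (a sketch of the intended transformation; syntax
// abbreviated):
//   gpu.func @foo(%arg0: memref<8x16xf32>) {
//     ... body ...
//     gpu.return
//   }
// becomes
//   gpu.func @foo(%arg0: memref<8x16xf32>) {
//     %lane = gpu.lane_id
//     gpu.warp_execute_on_lane_0(%lane)[16] args(%arg0 : memref<8x16xf32>) {
//       ... body ...
//       gpu.yield
//     }
//     gpu.return
//   }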
/// Distribute a create_nd_tdesc op feeding into a warp op yield.
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, descOp->getOperands(), descOp.getOperandTypes(),
        newRetIndices);
    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    // Create a new descriptor op outside the warp op; the distributed tensor
    // descriptor type carries no layout info.
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the type mismatch between the warp result and the new desc op.
    Value typeResolved =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
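
// Illustrative IR (abbreviated sketch): a descriptor created inside the warp
// op with type
//   !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], ...>>
// is recreated outside the warp op as plain
//   !xegpu.tensor_desc<8x16xf32>
// since per-lane code no longer carries layout information.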
/// Distribute a store_nd op at the end of the enclosing warp op body.
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    // Offsets on the op itself are not yet supported.
    int64_t offsetSize = static_cast<int64_t>(storeOp.getOffsets().size());
    if ((offsetSize != 0) || storeOp.getConstOffsetsAttr())
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp,
        /*new yielded values=*/
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        /*new yielded types=*/
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Create a new store op outside the warp op with distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The distributed tensor descriptor type carries no layout info.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    xegpu::removeLayoutAttrs(newStoreOp);
    rewriter.eraseOp(storeOp);
    return success();
  }
};
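
// Illustrative IR (abbreviated sketch): with lane_layout = [1, 16], the
// subgroup-level
//   xegpu.store_nd %v, %td
//     : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
// distributes to a per-lane store outside the warp op:
//   xegpu.store_nd %v_lane, %td'
//     : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>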
/// Distribute a load_nd op feeding into a warp op yield.
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the load op is the last operation in the warp op body, so
      // that it is not sunk past any barrier synchronizations.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // Chip information is required to decide if the load needs a transpose.
    auto chipStr = xegpu::getChipStr(loadOp);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          loadOp,
          "xegpu::LoadNdOp require chip information to determine transpose "
          "requirement");
    // Offsets on the op itself are not yet supported.
    int64_t offsetSize = static_cast<int64_t>(loadOp.getOffsets().size());
    if ((offsetSize != 0) || loadOp.getConstOffsetsAttr())
      return failure();

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp,
        /*new yielded values=*/loadOp.getTensorDesc(),
        /*new yielded types=*/tensorDescTy, newRetIndices);

    // Create a new load op outside the warp op with distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    // The distributed tensor descriptor type carries no layout info.
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter),
        loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (requireTranspose(layout, chipStr.value()))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the type mismatch between the warp result and the new load.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
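
// Illustrative note (added for exposition): for a B-operand layout such as
// lane_layout = [1, 16] with lane_data = [2, 1] on 16-bit elements, the new
// load gets the packed unit attribute; a column-major lane_layout = [16, 1]
// on pvc/bmg instead sets transpose = [1, 0], so the hardware transposes the
// tile while loading.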
/// Distribute a dpas op feeding into a warp op yield.
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(layoutOut, dpasOp.getResultType());
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Resolve the distributed types against the expected SIMT types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
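
// Worked shape example (illustrative): an 8x16x16 f16 dpas distributed over
// a 16-lane subgroup leaves each lane with 8 A elements (8x16 / 16 lanes),
// 16 B elements and 8 f32 accumulator/result elements, so the SIMD
//   vector<8x16xf16> x vector<16x16xf16> -> vector<8x16xf32>
// becomes a per-lane dpas on vector<8xf16>, vector<16xf16> and vector<8xf32>.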
/// Distribute an update_nd_offset op feeding into a warp op yield.
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // The distributed tensor descriptor type carries no layout info.
    xegpu::TensorDescType distributedTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newUpdateOperands =
        llvm::map_to_vector(newRetIndices, [&](size_t i) {
          // The layout is dropped from the tensor descriptor operand after
          // distribution, so its type needs to be resolved.
          if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
            return resolveDistributedTy(newWarpOp.getResult(i),
                                        distributedTensorDescTy, rewriter);
          }
          return newWarpOp.getResult(i);
        });
    // Create a new update op outside the warp op.
    auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
        newUpdateOperands, updateOp->getAttrs());
    xegpu::removeLayoutAttrs(newUpdateOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the type mismatch between the warp result and the new op.
    Value typeResolved = resolveDistributedTy(
        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
/// Distribute a prefetch_nd op at the end of the enclosing warp op body.
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    // Offsets on the op itself are not yet supported.
    int64_t offsetSize = static_cast<int64_t>(prefetchOp.getOffsets().size());
    if ((offsetSize != 0) || prefetchOp.getConstOffsetsAttr())
      return failure();

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op with the layout-free
    // distributed tensor descriptor type.
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                newPrefetchOperands, prefetchOp->getAttrs());
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Sink a gpu.barrier op at the end of the enclosing warp op body.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Re-create the barrier outside the warp op and erase the original.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
/// Distribute a scattered store op at the end of the enclosing warp op body.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(storeScatterOp,
                                         "Expected 1D offsets and mask vector");
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
    if (storeVecTy.getRank() > 2)
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Expected at most 2D result at SG level");

    std::string layoutPayloadName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
    std::string layoutOffsetsName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
    std::string layoutMaskName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));

    xegpu::LayoutAttr layoutPayload =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
    xegpu::LayoutAttr layoutOffsets =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    // The distributed store payload is flattened to a 1D vector.
    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
    VectorType expectedPayloadTy = VectorType::get(
        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = llvm::to_vector(storeScatterOp->getOperands());
    SmallVector<Type> operandTypesToYield = {
        expectedPayloadTy, operands[1].getType(),
        distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
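
// Illustrative IR (abbreviated sketch): with a 16-lane subgroup,
// lane_layout = [16] on offsets/mask and [1, 16] on the payload, the
// subgroup-level scattered store of vector<8x16xf32> with vector<16xindex>
// offsets and vector<16xi1> mask distributes to a per-lane store of
// vector<8xf32> (the 8x1 slice flattened to 1D) with vector<1xindex> offsets
// and a vector<1xi1> mask.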
/// Distribute a scattered load op feeding into a warp op yield.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // The candidate must be a LoadGatherOp that is also the last op in the
      // warp op body.
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");

    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have a vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(loadGatherOp,
                                         "Expected 1D offsets and mask vector");

    std::string layoutOffsetsName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
    std::string layoutMaskName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));

    xegpu::LayoutAttr layoutOffsets =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<Value> operands = llvm::to_vector(loadGatherOp->getOperands());
    SmallVector<Type> operandTypesToYield = {
        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    unsigned operandIdx = producedByLastLoad->getOperandNumber();
    // The warp op result for this value already has the distributed type.
    VectorType loadVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
    return success();
  }
};
/// Helper to rewrite a 2D vector.multi_reduction into a sequence of 1D
/// vector.reduction ops, one per slice along the non-reduced dimension.
static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                     TypedValue<VectorType> acc,
                                     vector::CombiningKind kind,
                                     int64_t reductionDim, Location loc,
                                     PatternRewriter &rewriter) {
  // Expecting a 2D source vector.
  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
  VectorType sourceType = src.getType();
  int64_t sourceH = sourceType.getShape()[0];
  int64_t sourceW = sourceType.getShape()[1];
  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
  // Create a zero-initialized vector to hold the reduction results.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value reductionResult = arith::ConstantOp::create(
      rewriter, loc, acc.getType(),
      DenseElementsAttr::get(acc.getType(), zeroAttr));
  // For each slice of the source, extract the slice vector, reduce it, and
  // insert the reduced value back into the result vector.
  for (int i = 0; i < nSlices; ++i) {
    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
    if (reductionDim == 1) {
      sliceOffsets = {i, 0};
      sliceSizes = {1, sourceW};
    } else {
      sliceOffsets = {0, i};
      sliceSizes = {sourceH, 1};
    }
    vector::ExtractStridedSliceOp extractOp =
        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                              sliceSizes, {1, 1});
    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
    Value slice = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get({nSliceElements}, sourceType.getElementType()),
        extractOp.getResult());
    // Extract and reduction produce scalars, so no result type is needed.
    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
    Value reduction =
        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
    reductionResult =
        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
  }
  return reductionResult;
}
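
// Worked example (illustrative): reducing vector<2x3xf32> along dim 1 with
// kind = <add> gives nSlices = 2. Iteration i extracts the 1x3 row, shape
// casts it to vector<3xf32>, reduces it against acc[i] with
//   %r = vector.reduction <add>, %row, %acc_i : vector<3xf32> into f32
// and inserts the scalar into position i of the vector<2xf32> result.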
/// Distribute a vector.multi_reduction op feeding into a warp op yield.
struct VectorMultiReductionDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D vectors are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    // Only a single reduction dimension is supported.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(reductionOp.getSource());

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only single-dimension distribution is supported.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // If the reduction happens along the dimension each lane fully owns, the
    // reduction is lane-local and the result stays distributed; otherwise the
    // partial results must be combined across lanes and the result is
    // broadcast.
    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");
    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Lane-local case: yield the source and acc from the warp op and do the
    // reduction on each lane's slice outside of it.
    if (isReductionLaneLocal) {
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value result = lowerToVectorReductions(
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
      return success();
    }
    // Non-lane-local case: rewrite the multi_reduction into 1D reductions
    // inside the warp op; cross-lane combination is handled later by the
    // vector reduction distribution patterns.
    rewriter.setInsertionPointAfter(reductionOp);
    Value result = lowerToVectorReductions(
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    rewriter.replaceOp(reductionOp, result);
    return success();
  }
};
/// Distribute a vector.shape_cast op feeding into a warp op yield.
struct VectorShapeCastDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yieldOperand)
      return failure();
    auto shapeCastOp =
        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");

    // For rank-changing shape_casts, the lower-rank side's layout must be a
    // slice of the higher-rank side's layout.
    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank increasing but source layout is not a "
                  "slice of result layout");
    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank reducing but result layout is not a "
                  "slice of source layout");

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();

    SmallVector<size_t> newRetIndices;
    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new shape_cast outside the warp op with distributed types.
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
/// Distribute a memref.extract_aligned_pointer_as_index op feeding into a
/// warp op yield.
struct MemrefExtractAlignedPointerAsIndexDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
    return success();
  }
};
/// Distribute a vector.bitcast op feeding into a warp op yield.
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedSourceType =
        getDistVecTypeBasedOnLaneLayout(
            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
/// Distribute a vector.transpose op feeding into a warp op yield.
struct VectorTransposeDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    // Only 2D transposes with 2D layouts are supported for now.
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    // The result layout must be the transpose of the source layout.
    if (!resultLayout.isTransposeOf(sourceLayout, perm))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
           GpuBarrierDistribution, VectorMultiReductionDistribution,
           LoadDistribution, StoreDistribution, VectorTransposeDistribution,
           VectorBitcastDistribution,
           MemrefExtractAlignedPointerAsIndexDistribution>(
          patterns.getContext(), regularPatternBenefit);
  // Shape casts are distributed with a higher benefit so that rank-changing
  // ops are resolved before the other patterns run.
  patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
                                            highPatternBenefit);
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Check that all vector operands have a distribution layout.
  getOperation()->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Layouts are needed for vector type operands only.
      if (!isa<VectorType>(operand.get().getType()))
        continue;
      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
      if (!layout) {
        op->emitError("Could not find layout attribute for operand ")
            << operand.getOperandNumber() << " of operation " << op->getName();
        signalPassFailure();
        return;
      }
    }
  });
  // Step 2: Move the body of each GPU function inside a
  // gpu.warp_execute_on_lane_0 op.
  {
    RewritePatternSet patterns(&getContext());
    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
  // Move scalar uniform code (GPU index ops, constants, etc.) outside of the
  // warp op to simplify the later lowering.
  getOperation()->walk([&](Operation *op) {
    if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
      vector::moveScalarUniformCode(warpOp);
  });
  // Step 3: Apply subgroup-to-workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn is used by upstream vector distribution patterns to
  // determine the distributed vector type for a given vector value. Here it
  // is derived from the lane layout.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    // If no layout is specified, assume the innermost dimension is
    // distributed.
    if (!layout)
      return AffineMap::getMultiDimMapWithTargets(
          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
    // Otherwise, distribute the dimensions whose lane-layout entry is > 1.
    SmallVector<unsigned int> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // TODO: shuffleFn is currently unused.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };

  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
                          vector::CombiningKind kind, uint32_t size) {
    // First reduce on a single thread to get the per-lane reduction value.
    Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
    // Parallel reduction using butterfly shuffles.
    for (uint64_t i = 1; i < size; i <<= 1) {
      Value shuffled = builder
                           .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                   /*width=*/size,
                                                   /*mode=*/gpu::ShuffleMode::XOR)
                           .getShuffleResult();
      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
    }
    return laneVal;
  };

  if (enableSGReductions)
    vector::populateDistributeReduction(patterns, warpReduction,
                                        regularPatternBenefit);

  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn, regularPatternBenefit);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
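
  // Note on the warpReduction lambda above (worked example, added for
  // exposition): with size = 16 and kind = <add>, the loop runs
  // log2(16) = 4 rounds of XOR shuffles (i = 1, 2, 4, 8). In each round lane
  // L combines its value with lane L ^ i, so after the last round every lane
  // holds the full subgroup sum (a broadcast result, not just on lane 0).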
  // If any warp ops (other than trivially dead ones) survive distribution,
  // the pass has failed.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    if (isOpTriviallyDead(warpOp))
      return WalkResult::advance();
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp) {
    getOperation()->emitError("Subgroup distribution failed.");
    signalPassFailure();
    return;
  }
  // Finally, resolve the unrealized conversion casts that were inserted to
  // reconcile SIMT type mismatches.
  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // Only consider casts inserted by this pass.
    if (!op->hasAttr(resolveSIMTTypeMismatch))
      return WalkResult::skip();
    Value input = op.getOperand(0);
    Value output = op.getResult(0);
    // Both sides must be tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");
    // tensor_desc<..., layout> -> tensor_desc<...> casts: propagate the
    // layout-free type through block arguments (e.g. loop iter_args).
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }
    // tensor_desc<...> -> tensor_desc<..., layout> casts simply forward the
    // input.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);
    if (op->use_empty())
      op->erase();
    return WalkResult::advance();
  });
}