#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51 "resolve_simt_type_mismatch";
static constexpr unsigned regularPatternBenefit = 1;
static constexpr unsigned highPatternBenefit = 2;
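/// Helper to compute the distributed vector type for a given vector type
/// according to the effective lane layout: only the trailing dimensions
/// covered by the lane layout are divided across lanes; leading dimensions
/// are kept intact.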
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();
  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
         "Expecting a valid layout.");
  SmallVector<int64_t> effectiveLaneLayout =
      layout.getEffectiveLaneLayoutAsInt();
  assert(static_cast<size_t>(originalType.getRank()) >=
             effectiveLaneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  // Only the trailing dimensions covered by the lane layout are distributed.
  SmallVector<int64_t> distributedShape =
      llvm::to_vector(originalType.getShape());
  unsigned distributionStart =
      originalType.getRank() - effectiveLaneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // Check if the dimension can be distributed evenly.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
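/// Helper to reconcile the type mismatch between the value yielded by the
/// warp op and the type expected by its users: vectors are adjusted with a
/// vector.shape_cast, tensor descriptors with an annotated
/// unrealized_conversion_cast that is resolved in a later cleanup.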
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If the types are the same, no resolution is needed.
  if (orig.getType() == expected)
    return orig;
  // Vectors are reconciled with a shape cast.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // Tensor descriptors are reconciled with an unrealized conversion cast that
  // is tagged for removal in the pass cleanup step.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
static bool requirePacked(const xegpu::LayoutAttr layout) {
  if (!layout)
    return false;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return false;
  return laneData[0] != 1;
}
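/// Helper to check if a load requires a transpose effect: true when the 2D
/// lane layout assigns one column per lane (subgroup-size rows by 1 column).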
static bool requireTranspose(const xegpu::LayoutAttr layout,
                             const uArch *uArch) {
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
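/// Pattern that moves the body of a gpu.func into a
/// gpu.warp_execute_on_lane_0 op spanning the whole subgroup, so that the
/// distribution patterns below can peel ops out of the warp region one at a
/// time. (The struct name is reconstructed from the
/// populateXeGPUMoveFuncBodyToWarpOpPatterns entry point.)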
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    auto chipStr = xegpu::getChipStr(gpuFuncOp);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          gpuFuncOp,
          "Subgroup distribution requires target attribute attached "
          "to set the warp size");
    const int subgroupSize = getUArch(chipStr.value())->getSubgroupSize();
    // If the function only contains a single void return, skip it.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function body is already wrapped in a warp op, skip it.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as
    // the original gpu.func.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(rewriter, newGpuFunc.getLoc(),
                                        /*upperBound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the gpu.return of the original function with a gpu.yield.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body inside the warp op.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new gpu.return after the warp op, yielding its results.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
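/// Distribute a create_nd_tdesc op feeding into the yield of the warp op.
/// The descriptor operands pass through the new warp op unchanged, and the
/// distributed descriptor simply drops its layout attribute.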
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, descOp->getOperands(), descOp.getOperandTypes(),
        newRetIndices);
    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    // The distributed tensor descriptor type drops the layout.
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type with the original type.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
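/// Distribute a store_nd op that ends the warp op body. The payload is
/// yielded with its warp-distributed type, the descriptor drops its layout,
/// and the explicit offsets are forwarded unchanged.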
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");
    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
    // Create a new store op outside the warp op with the distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;
    // The value operand may need a type adjustment: the type yielded by the
    // warp op can differ from the SIMT type expected by the store.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor is distributed by dropping its layout.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect the offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    xegpu::removeLayoutAttrs(newStoreOp);
    rewriter.eraseOp(storeOp);
    return success();
  }
};
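/// Distribute a load_nd op that produces the last result of the warp op.
/// The distributed load may additionally need the packed attribute or a
/// transpose effect, depending on the lane layout and the target uArch.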
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the load op is the last operation in the warp op body, so
      // that it is not sunk past any barrier synchronization.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");
    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    auto chipStr = xegpu::getChipStr(loadOp);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          loadOp, "xegpu::LoadNdOp requires target attribute attached to "
                  "determine transpose requirement");
    const uArch *uArch = getUArch(chipStr.value());
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (requireTranspose(layout, uArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type with the original type.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
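/// Distribute a dpas op. A, B (and the optional accumulator) are yielded
/// with their warp-distributed types and then reconciled with the SIMT types
/// expected by the distributed dpas.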
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");
    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The dpas accumulator operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
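/// Sink a prefetch_nd op that ends the warp op body out of the region; the
/// descriptor drops its layout and the explicit offsets are forwarded.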
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op; the distributed tensor
    // descriptor type drops the layout.
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                newPrefetchOperands, prefetchOp->getAttrs());
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
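/// Sink a gpu.barrier that ends the warp op body out of the region; the
/// barrier is uniform across all lanes.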
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(yield->getPrevNode());
    if (!barrierOp)
      return failure();
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
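/// Distribute a scattered store op. Payload, offsets and mask are each
/// distributed according to their own layout attributes; the distributed
/// payload is always flattened to 1D.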
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(storeScatterOp,
                                         "Expected 1D offsets and mask vector");
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
    if (storeVecTy.getRank() > 2)
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Expected at most 2D result at SG level");

    std::string layoutPayloadName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
    std::string layoutOffsetsName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
    std::string layoutMaskName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));

    xegpu::LayoutAttr layoutPayload =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
    xegpu::LayoutAttr layoutOffsets =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    // Distributed store payload type according to the lane layout.
    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
    // The expected distributed payload type is always flattened to 1D.
    VectorType expectedPayloadTy =
        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
                        distPayloadTyByWarpOp.getElementType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = {storeScatterOp.getValue(),
                                   storeScatterOp.getDest(),
                                   storeScatterOp.getOffsets(),
                                   storeScatterOp.getMask()};
    SmallVector<Type> operandTypesToYield = {
        distPayloadTyByWarpOp, operands[1].getType(),
        distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    // The payload operand may need a type adjustment due to the mismatch
    // between the warp-distributed type and the expected SIMT type.
    newStoreScatterOpOperands[0] = resolveDistributedTy(
        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
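/// Distribute a gathered load op that produces the last result of the warp
/// op. Offsets and mask are distributed by their layouts, and the distributed
/// load always produces a 1D vector.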
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // The producer of the operand must be the last op in the warp body.
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");
    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have a vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(loadGatherOp,
                                         "Expected 1D offsets and mask vector");
    // Assume the offset and mask producers will be distributed as well.
    std::string layoutOffsetsName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
    std::string layoutMaskName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));

    xegpu::LayoutAttr layoutOffsets =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<Value> operands = {loadGatherOp.getSource(),
                                   loadGatherOp.getOffsets(),
                                   loadGatherOp.getMask()};
    SmallVector<Type> operandTypesToYield = {
        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType distResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // The distributed load op always produces a 1D vector.
    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
                                           distResultTy.getElementType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
    return success();
  }
};
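/// Helper to rewrite a 2D vector.multi_reduction into a sequence of 1D
/// vector.reduction ops, one per row or column slice of the source.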
static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                     TypedValue<VectorType> acc,
                                     vector::CombiningKind kind,
                                     int64_t reductionDim, Location loc,
                                     PatternRewriter &rewriter) {
  // Expecting a 2D source vector.
  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
  VectorType sourceType = src.getType();
  int64_t sourceH = sourceType.getShape()[0];
  int64_t sourceW = sourceType.getShape()[1];
  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
  // Create a zero-initialized vector to hold the result of the reduction.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value reductionResult = arith::ConstantOp::create(
      rewriter, loc, acc.getType(),
      DenseElementsAttr::get(acc.getType(), zeroAttr));
  // For each slice of the source, extract the slice vector, reduce it, and
  // insert the reduced value back into the result vector.
  for (int i = 0; i < nSlices; ++i) {
    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
    if (reductionDim == 1) {
      sliceOffsets = {i, 0};
      sliceSizes = {1, sourceW};
    } else {
      sliceOffsets = {0, i};
      sliceSizes = {sourceH, 1};
    }
    vector::ExtractStridedSliceOp extractOp =
        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                              sliceSizes, {1, 1});
    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get({nSliceElements}, sourceType.getElementType()),
        extractOp.getResult());
    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
    Value reduction = vector::ReductionOp::create(
        rewriter, loc, kind, slice.getResult(), accExtract);
    reductionResult =
        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
  }
  return reductionResult;
}
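/// Distribute a vector.multi_reduction. Two cases are handled: lane-local
/// reductions, where each lane fully owns the reduction dimension, and
/// cross-lane reductions, which are lowered inside the warp body and later
/// distributed through the vector.reduction patterns.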
struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D vectors are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    auto reductionDims = reductionOp.getReductionDims();
    // Only a single reduction dimension is supported.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(reductionOp.getSource());

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only single-dimension distribution is supported.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // The reduction is lane-local when each lane fully owns the reduction
    // dimension (i.e. the distributed dimension and the reduction dimension
    // differ); in that case the result must also be distributed. Otherwise
    // the reduction crosses lanes and the result must be broadcast.
    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");
    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Lane-local case: each lane reduces its own slices, no cross-lane
    // communication is needed.
    if (isReductionLaneLocal) {
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value result = lowerToVectorReductions(
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
      return success();
    }
    // Non-lane-local case: rewrite the multi_reduction in terms of 1D
    // vector.reduction ops inside the warp body; distribution is finished by
    // the vector.reduction distribution patterns.
    rewriter.setInsertionPointAfter(reductionOp);
    Value result = lowerToVectorReductions(
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
    return success();
  }
};
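/// Distribute a vector.shape_cast whose source and result layouts relate by
/// a slice; the cast is recreated outside the warp op on the distributed
/// source.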
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();
    auto shapeCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
    unsigned operandNumber = operand->getOperandNumber();
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");

    // For rank-changing shape casts, the lower-rank side's layout must be a
    // slice of the higher-rank side's layout.
    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank increasing but source layout is not a "
                  "slice of result layout");
    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank reducing but result layout is not a "
                  "slice of source layout");

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Create a new warp op that yields the source of the shape cast.
    SmallVector<size_t> newRetIndices;
    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new shape cast op outside the warp op.
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
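/// Distribute a memref.extract_aligned_pointer_as_index op: the pointer is
/// uniform across lanes, so the op is simply recreated outside the warp op.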
struct MemrefExtractAlignedPointerAsIndexDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
    return success();
  }
};
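/// Distribute a vector.bitcast by yielding the distributed source and
/// recreating the bitcast outside the warp op with the distributed result
/// type.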
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedSourceType =
        getDistVecTypeBasedOnLaneLayout(
            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
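/// Distribute a vector.transpose whose result layout is the transpose of its
/// source layout; the transpose is recreated outside the warp op on the
/// distributed source.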
struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    // Only 2D transposes are supported.
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    if (!resultLayout.isTransposeOf(sourceLayout, perm))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
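/// Pass that applies the SIMT distribution: it wraps gpu.func bodies in warp
/// ops and then runs the distribution patterns above.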
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               GpuBarrierDistribution, VectorMultiReductionDistribution,
               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
               VectorBitcastDistribution,
               MemrefExtractAlignedPointerAsIndexDistribution>(
      patterns.getContext(), regularPatternBenefit);
  // Distribute shape casts with a higher benefit so they are matched first.
  patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
                                            highPatternBenefit);
}
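/// runOnOperation proceeds in steps: verify that every vector operand has a
/// distribution layout, wrap each gpu.func body in a
/// gpu.warp_execute_on_lane_0 op, greedily apply the distribution patterns,
/// and finally clean up the unrealized casts inserted for SIMT type
/// resolution.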
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Make sure every vector operand has a distribution layout.
  auto walkResult = getOperation()->walk([&](Operation *op) -> WalkResult {
    for (OpOperand &operand : op->getOpOperands()) {
      // Layouts are needed for vector types only.
      if (!isa<VectorType>(operand.get().getType()))
        continue;
      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
      if (!layout) {
        op->emitError("Could not find layout attribute for operand ")
            << operand.getOperandNumber() << " of operation " << op->getName();
        signalPassFailure();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return;
  // Step 2: Move all operations of a GPU function inside the
  // gpu.warp_execute_on_lane_0 operation.
  {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
    // Move any scalar uniform code (e.g. GPU index ops, scalar constants)
    // outside of the warp op to simplify the later lowering.
    getOperation()->walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
        vector::moveScalarUniformCode(warpOp);
    });
  }
  // Step 3: Apply the subgroup-to-workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn maps a vector value to the affine map describing which of
  // its dimensions are distributed across lanes, based on the lane layout.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    assert(layout && "Expecting vector operand to have a layout");
    assert(layout.getRank() == vecRank &&
           "Expecting vector and layout rank to match");
    // A dimension is distributed if the lane layout assigns more than one
    // lane to it and it divides the vector shape evenly.
    SmallVector<unsigned> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1 && vecType.getShape()[i] % v == 0)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // shuffleFn is not used by these patterns.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                      Value srcIdx, int64_t warpSz) { return Value(); };
  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
                          vector::CombiningKind kind, uint32_t size) {
    // First reduce on a single lane to get the per-lane reduction value.
    Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
    // Parallel reduction using butterfly shuffles.
    for (uint64_t i = 1; i < size; i <<= 1) {
      Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
                                              /*width=*/size,
                                              gpu::ShuffleMode::XOR)
                           .getShuffleResult();
      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
    }
    return laneVal;
  };
  vector::populateDistributeReduction(patterns, warpReduction,
                                      regularPatternBenefit);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn, regularPatternBenefit);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  // Step 4: Distribution must eliminate all warp ops; if any non-dead warp
  // op remains, the distribution failed.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    if (isOpTriviallyDead(warpOp))
      return WalkResult::advance();
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp) {
    signalPassFailure();
    return;
  }
  // Step 5: Clean up the unrealized casts added for SIMT type resolution.
  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    if (!op->getAttr(resolveSIMTTypeMismatch))
      return WalkResult::skip();
    Value input = op.getOperand(0);
    Value output = op.getResult(0);
    // Both sides of the cast must be tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");
    // Case 1: tensor_desc<shape, layout> -> tensor_desc<shape>. This occurs
    // inside the warp op body, applied to a block argument.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        // Also update the type of the tied loop result, if any.
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }
    // Case 2: tensor_desc<shape> -> tensor_desc<shape, layout>. This occurs
    // at the yield of the warp op body, applied to the warp op result.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);
    if (op->use_empty())
      op->erase();
    return WalkResult::advance();
  });
}