37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Debug machinery for LLVM_DEBUG output from this pass. DEBUG_TYPE tags the
// output category (enable with -debug-only=xegpu-subgroup-distribute) and
// DBGS() prefixes each debug line with "[xegpu-subgroup-distribute]: ".
// (Defect fixed: extraction artifact had the original file line numbers
// fused onto these directives, making them invalid preprocessor lines.)
#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 "resolve_simt_type_mismatch";
/// Relative tiers for the distribution patterns registered by this pass:
/// patterns tagged `AboveRegular` carry a higher value than `Regular` ones.
/// NOTE(review): presumably these values feed pattern benefits so that
/// `AboveRegular` patterns are attempted first — confirm at the pattern
/// registration site (not visible in this chunk).
/// (Defect fixed: extraction artifact had the original file line number
/// fused onto the `enum` keyword, making the declaration invalid C++.)
enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
85static Value resolveDistributedTy(
Value orig, T expected,
91 if (isa<VectorType>(orig.
getType())) {
93 vector::ShapeCastOp::create(rewriter, orig.
getLoc(), expected, orig);
94 return castOp.getResult();
98 if (isa<xegpu::TensorDescType>(orig.
getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.
getLoc(),
102 return castOp.getResult(0);
104 llvm_unreachable(
"Unsupported type for reconciliation");
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
115 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
120 return distributedDims;
153 gpuFuncOp,
"Subgroup distribution requires target attribute attached "
154 "to set the warp size");
156 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
157 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
161 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
162 return isa<gpu::WarpExecuteOnLane0Op>(op);
167 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
170 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
172 auto newGpuFunc = gpu::GPUFuncOp::create(
173 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
175 privateAttributionsTypes);
176 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
180 auto laneId = gpu::LaneIdOp::create(
182 mlir::IntegerAttr());
183 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
184 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
185 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
187 newGpuFunc.getArgumentTypes());
188 Block &warpBodyBlock = warpOp.getBodyRegion().
front();
191 cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
193 gpu::YieldOp::create(rewriter, origRetunOp.getLoc(),
194 origRetunOp.getOperands());
198 warpOp.getBodyRegion().begin());
202 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
203 rewriter.
replaceOp(gpuFuncOp, newGpuFunc);
241 using gpu::WarpDistributionPattern::WarpDistributionPattern;
242 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
245 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
248 warpOp,
"warp result is not a xegpu::CreateNdDesc op");
252 xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
255 descOp,
"the tensor descriptor lacks layout attribute");
257 if (descOp.getMixedOffsets().size())
259 descOp,
"xegpu::CreateNdDescOp must not have offsets");
263 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
264 rewriter, warpOp, descOp->getOperands(),
265 descOp.getOperandTypes(), newRetIndices);
268 newRetIndices, [&](
size_t i) {
return newWarpOp.getResult(i); });
270 xegpu::TensorDescType distributedTensorDescTy =
271 descOp.getType().dropLayouts();
273 Value newDescOp = xegpu::CreateNdDescOp::create(
274 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
277 Value distributedVal = newWarpOp.getResult(operandIdx);
280 resolveDistributedTy(newDescOp, distributedVal.
getType(), rewriter);
319 using gpu::WarpDistributionPattern::WarpDistributionPattern;
320 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
322 gpu::YieldOp yield = warpOp.getTerminator();
323 Operation *lastNode = yield->getPrevNode();
324 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
332 "the store op must have offsets");
337 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
338 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
341 storeOp,
"the source tensor descriptor lacks layout attribute");
343 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
345 if (failed(distributedTypeByWarpOpOrFailure))
347 "Failed to distribute the type");
348 VectorType distributedTypeByWarpOp =
349 distributedTypeByWarpOpOrFailure.value();
353 storeOp.getTensorDesc()};
355 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
356 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
357 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
358 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
368 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
370 if (failed(storeNdDistributedValueTyOrFailure))
372 storeOp,
"Failed to get distributed vector type for the store op");
373 newStoreOperands.push_back(resolveDistributedTy(
374 newWarpOp.getResult(newRetIndices[0]),
375 storeNdDistributedValueTyOrFailure.value(), rewriter));
378 xegpu::TensorDescType distributedTensorDescTy =
379 storeOp.getTensorDescType().dropLayouts();
380 newStoreOperands.push_back(
381 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
382 distributedTensorDescTy, rewriter));
384 for (
size_t i = 2; i < newRetIndices.size(); ++i)
385 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
388 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(),
TypeRange{},
389 newStoreOperands, storeOp->getAttrs());
433 using gpu::WarpDistributionPattern::WarpDistributionPattern;
434 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
437 if (!isa<xegpu::LoadNdOp>(op))
442 gpu::YieldOp yield = warpOp.getTerminator();
443 return yield->getPrevNode() == op;
448 warpOp,
"warp result is not a xegpu::LoadNd op");
454 loadOp,
"xegpu::LoadNdOp require target attribute attached to "
455 "determine transpose "
463 "the load op must have offsets");
469 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
470 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
473 loadOp,
"the source tensor descriptor lacks layout attribute");
476 VectorType distributedTypeByWarpOp =
477 cast<VectorType>(warpOp.getResult(operandIdx).getType());
482 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
483 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
484 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
485 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
490 FailureOr<VectorType> loadNdDistValueTyOrFailure =
492 if (failed(loadNdDistValueTyOrFailure))
494 loadOp,
"Failed to get distributed vector type for the load op");
495 xegpu::TensorDescType distributedTensorDescTy =
496 loadOp.getTensorDescType().dropLayouts();
500 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
501 distributedTensorDescTy, rewriter)};
503 for (
size_t i = 1; i < newRetIndices.size(); ++i)
504 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
505 auto newLoadOp = xegpu::LoadNdOp::create(
506 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
507 newLoadOperands, loadOp->getAttrs());
513 newLoadOp.setTranspose(
515 Value distributedVal = newWarpOp.getResult(operandIdx);
519 Value tyResolvedVal = resolveDistributedTy(
520 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
561 using gpu::WarpDistributionPattern::WarpDistributionPattern;
562 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
564 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
567 "warp result is not a xegpu::Dpas op");
572 xegpu::LayoutAttr layoutA =
573 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
574 xegpu::LayoutAttr layoutB =
575 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
576 xegpu::LayoutAttr layoutOut =
577 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
579 if (!layoutA || !layoutB || !layoutOut)
582 "the xegpu::Dpas op lacks layout attribute for A, B or output");
584 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
585 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
586 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
587 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
588 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
589 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
591 if (failed(distLhsTypeByWarpOpOrFailure) ||
592 failed(distRhsTypeByWarpOpOrFailure) ||
593 failed(distResultTypeByWarpOpOrFailure))
596 "Failed to distribute the A, B or output types in xegpu::Dpas op");
601 distLhsTypeByWarpOpOrFailure.value(),
602 distRhsTypeByWarpOpOrFailure.value()};
604 if (dpasOp.getAcc()) {
605 newYieldValues.push_back(dpasOp.getAcc());
606 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
609 SmallVector<size_t> newRetIndices;
610 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
611 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
613 FailureOr<VectorType> expectedDistLhsTyOrFailure =
615 FailureOr<VectorType> expectedDistRhsTyOrFailure =
617 FailureOr<VectorType> expectedDistResultTyOrFailure =
620 if (
failed(expectedDistLhsTyOrFailure) ||
621 failed(expectedDistRhsTyOrFailure) ||
622 failed(expectedDistResultTyOrFailure))
625 "Failed to get distributed vector type for the dpas operands.");
628 SmallVector<Value> newDpasOperands;
629 SmallVector<VectorType> newDpasOperandExpectedTypes;
632 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
633 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
634 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
636 newDpasOperandExpectedTypes.push_back(distributedResultTy);
638 for (
unsigned i = 0; i < newRetIndices.size(); i++) {
639 newDpasOperands.push_back(
640 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
641 newDpasOperandExpectedTypes[i], rewriter));
643 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
644 distributedResultTy, newDpasOperands,
647 Value distributedVal = newWarpOp.getResult(operandIdx);
650 resolveDistributedTy(newDpasOp.getResult(),
651 distResultTypeByWarpOpOrFailure.value(), rewriter);
686 using gpu::WarpDistributionPattern::WarpDistributionPattern;
687 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
688 PatternRewriter &rewriter)
const override {
689 gpu::YieldOp yield = warpOp.getTerminator();
690 Operation *lastNode = yield->getPrevNode();
691 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
695 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
699 "the prefetch op must have offsets");
700 SmallVector<Value> offsetsAsValues =
702 SmallVector<Type> offsetTypes = llvm::map_to_vector(
703 offsetsAsValues, [](Value v) {
return v.
getType(); });
705 xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
708 prefetchOp,
"the source tensor descriptor lacks layout attribute");
710 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
711 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
712 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
713 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
714 SmallVector<size_t> newRetIndices;
715 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
716 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
719 xegpu::TensorDescType newTensorDescTy =
720 prefetchOp.getTensorDescType().dropLayouts();
722 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
723 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
725 for (
size_t i = 1; i < newRetIndices.size(); ++i)
726 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
727 Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
728 rewriter, newWarpOp.getLoc(),
TypeRange{}, newPrefetchOperands,
729 prefetchOp->getAttrs());
739 using gpu::WarpDistributionPattern::WarpDistributionPattern;
740 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
741 PatternRewriter &rewriter)
const override {
742 gpu::YieldOp yield = warpOp.getTerminator();
743 Operation *lastNode = yield->getPrevNode();
745 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
750 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
751 barrierOp->getResultTypes(),
752 barrierOp->getOperands(), barrierOp->getAttrs());
792 using gpu::WarpDistributionPattern::WarpDistributionPattern;
793 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
794 PatternRewriter &rewriter)
const override {
795 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
796 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
799 auto offsets = storeScatterOp.getOffsets();
800 if (!offsets || !isa<VectorType>(offsets.getType()))
802 storeScatterOp,
"Store op must have a vector of offsets argument");
803 VectorType offsetsTy = cast<VectorType>(offsets.getType());
804 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
805 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
808 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
809 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
812 for (
int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
813 if (storeVecTy.getShape()[i] != 1) {
815 storeScatterOp,
"Only unit dimensions allowed for the leading "
816 "dimensions of the store vector!");
827 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
829 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
831 FailureOr<VectorType> distMaskByWarpOpOrFailure =
833 if (
failed(distStoreVecByWarpOpOrFailure) ||
834 failed(distOffsetsByWarpOpOrFailure) ||
835 failed(distMaskByWarpOpOrFailure)) {
838 "Some vector operands have no layouts, using defaults instead.");
841 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
842 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
843 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
845 SmallVector<size_t> newRetIndices;
846 SmallVector<Value> operands = storeScatterOp->getOperands();
847 SmallVector<Type> operandTypesToYield = {
848 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
850 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
851 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
856 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
857 distPayloadTy.getElementType());
859 VectorType distOffsetsTy1D = VectorType::get(
860 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
861 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
862 distMaskTy.getElementType());
865 Value distPayloadVal = resolveDistributedTy(
866 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
867 Value distOffsetVal = resolveDistributedTy(
868 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
869 Value distMaskVal = resolveDistributedTy(
870 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
872 SmallVector<Value> newStoreScatterOpOperands = {
873 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
876 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
877 rewriter, newWarpOp.getLoc(),
TypeRange{}, newStoreScatterOpOperands,
878 storeScatterOp->getAttrs());
880 rewriter.
eraseOp(storeScatterOp);
890 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
893 assert(maybeCoords.value().size() == 1 &&
894 "Expected one set of distributed offsets");
898 newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
904 using gpu::WarpDistributionPattern::WarpDistributionPattern;
905 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
906 PatternRewriter &rewriter)
const override {
907 gpu::YieldOp yield = warpOp.getTerminator();
908 Operation *lastNode = yield->getPrevNode();
909 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
913 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
914 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
916 if (!producedByLastLoad)
918 warpOp,
"The last op is not xegpu::LoadMatrixOp");
921 VectorType sgPayloadTy =
922 dyn_cast<VectorType>(matrixOp.getResult().getType());
923 VectorType warpResultTy =
924 cast<VectorType>(warpOp.getResult(operandIdx).getType());
927 matrixOp,
"the matrix op payload must be a vector type");
929 auto loc = matrixOp.getLoc();
930 auto offsets = matrixOp.getMixedOffsets();
933 "the load op must have offsets");
934 SmallVector<Value> offsetsAsValues =
937 auto layout = matrixOp.getLayoutAttr();
940 matrixOp,
"the matrix operation lacks layout attribute");
942 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
944 if (
failed(distPayloadByWarpOpOrFailure))
946 matrixOp,
"Failed to distribute matrix op payload based on layout.");
948 SmallVector<Value> operands = {matrixOp.getMemDesc()};
949 const unsigned offsetsStartIdx = operands.size();
950 operands.append(offsetsAsValues);
952 SmallVector<Type> operandTypes =
953 llvm::map_to_vector(operands, [](Value v) {
return v.
getType(); });
955 SmallVector<size_t> newRetIndices;
956 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
957 rewriter, warpOp, operands, operandTypes, newRetIndices);
958 SmallVector<Value> newOperands = llvm::map_to_vector(
959 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
961 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
962 ShapedType::kDynamic);
966 ValueRange(newOperands).drop_front(offsetsStartIdx);
968 SmallVector<Value> newCoords = currentOffsets;
971 if (!matrixOp.getSubgroupBlockIoAttr()) {
972 newCoords = computeDistributedCoordinatesForMatrixOp(
973 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
976 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
977 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
978 newOperands[0],
ValueRange(newCoords), newConstOffsetsAttr,
979 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
982 newWarpOp.getResult(operandIdx),
983 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
990 using gpu::WarpDistributionPattern::WarpDistributionPattern;
991 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
992 PatternRewriter &rewriter)
const override {
993 gpu::YieldOp yield = warpOp.getTerminator();
994 Operation *lastNode = yield->getPrevNode();
995 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
999 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1002 matrixOp,
"the matrix op payload must be a vector type");
1004 auto loc = matrixOp.getLoc();
1005 auto offsets = matrixOp.getMixedOffsets();
1006 if (offsets.empty())
1008 "the store op must have offsets");
1009 SmallVector<Value> offsetsAsValues =
1012 auto layout = matrixOp.getLayoutAttr();
1015 matrixOp,
"the matrix operation lacks layout attribute");
1017 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1019 if (
failed(distPayloadByWarpOpOrFailure))
1021 matrixOp,
"Failed to distribute matrix op payload based on layout.");
1023 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1024 const unsigned offsetsStartIdx = operands.size();
1025 operands.append(offsetsAsValues);
1027 SmallVector<Type> operandTypes =
1028 llvm::map_to_vector(operands, [](Value v) {
return v.
getType(); });
1029 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1031 SmallVector<size_t> newRetIndices;
1032 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1033 rewriter, warpOp, operands, operandTypes, newRetIndices);
1034 SmallVector<Value> newOperands = llvm::map_to_vector(
1035 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
1037 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1038 ShapedType::kDynamic);
1042 ValueRange(newOperands).drop_front(offsetsStartIdx);
1044 SmallVector<Value> newCoords = currentOffsets;
1047 if (!matrixOp.getSubgroupBlockIoAttr()) {
1048 newCoords = computeDistributedCoordinatesForMatrixOp(
1049 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1053 xegpu::StoreMatrixOp::create(
1054 rewriter, loc,
TypeRange{}, newOperands[0], newOperands[1],
1056 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1091 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1092 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1093 PatternRewriter &rewriter)
const override {
1094 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1097 return isa<xegpu::LoadGatherOp>(op) &&
1098 warpOp.getTerminator()->getPrevNode() == op;
1100 if (!producedByLastLoad)
1102 warpOp,
"The last op is not xegpu::LoadGatherOp");
1106 auto offsets = loadGatherOp.getOffsets();
1107 if (!offsets || !isa<VectorType>(offsets.getType()) ||
1108 !isa<VectorType>(loadGatherOp.getMask().getType()))
1111 "Load op must have a vector arguments for offsets and mask");
1112 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1113 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1114 VectorType resultVecTy =
1115 cast<VectorType>(loadGatherOp.getResult().getType());
1117 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1118 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1119 for (
int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1120 if (resultVecTy.getShape()[i] != 1) {
1122 loadGatherOp,
"Only unit dimensions allowed for the leading "
1123 "dimensions of the load vector!");
1127 auto layoutOffsets =
1131 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1133 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1135 if (
failed(distOffsetsByWarpOpOrFailure) ||
1136 failed(distMaskByWarpOpOrFailure)) {
1139 "Some vector operands have no layouts, using defaults instead.");
1142 SmallVector<size_t> newRetIndices;
1143 SmallVector<Value> operands = loadGatherOp->getOperands();
1146 VectorType distResultTy =
1147 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1148 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1149 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1151 SmallVector<Type> operandTypesToYield = {operands[0].getType(),
1152 distOffsetsTy, distMaskTy};
1154 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1155 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1160 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1161 distResultTy.getElementType());
1163 VectorType distOffsetsTy1D =
1164 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1166 VectorType distMaskTy1D =
1167 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1170 Value distOffsetVal = resolveDistributedTy(
1171 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1172 Value distmaskVal = resolveDistributedTy(
1173 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1175 SmallVector<Value> newLoadGatherOperands = {
1176 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1178 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1179 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1180 loadGatherOp->getAttrs());
1182 Value distributedVal = newWarpOp.getResult(operandIdx);
1186 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1198 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1199 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1200 PatternRewriter &rewriter)
const override {
1202 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1205 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->
getNumRegions())
1207 int operandIdx = -1;
1209 OpOperand *operand = getWarpResult(
1210 warpOp, [&](Operation *op) {
return warpRegionPreYieldOp == op; });
1215 warpOp.getResult(operandIdx).getType())
1217 "The op result is not uniform.");
1221 bool uniformValuesOnly =
1222 llvm::all_of(warpRegionPreYieldOp->
getResults(), [](Value v) {
1223 return !xegpu::getDistributeLayoutAttr(v);
1225 uniformValuesOnly &=
1226 llvm::all_of(warpRegionPreYieldOp->
getOpOperands(), [](OpOperand &opr) {
1227 return !xegpu::getDistributeLayoutAttr(opr);
1229 if (!uniformValuesOnly)
1231 "Some values are not uniform.");
1232 SmallVector<size_t> newRetIndices;
1233 SmallVector<Value> operands =
1234 llvm::to_vector_of<Value>(warpRegionPreYieldOp->
getOperands());
1235 SmallVector<Type> operandTypes =
1237 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1238 rewriter, warpOp, operands, operandTypes, newRetIndices);
1241 IRMapping operandMapper;
1242 for (
auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1243 operandMapper.
map(warpRegionPreYieldOp->
getOperand(oldOperandIdx),
1244 newWarpOp->getResult(newOperandIdx));
1245 Operation *clonedOp = rewriter.
clone(*warpRegionPreYieldOp, operandMapper);
1247 rewriter.
eraseOp(warpRegionPreYieldOp);
1249 assert(operandIdx != -1 &&
"Expected a warp result for the operation");
1313 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1314 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1315 PatternRewriter &rewriter)
const override {
1316 OpOperand *yieldOperand =
1317 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1323 VectorType sourceType = reductionOp.getSourceVectorType();
1325 if (sourceType.getRank() != 2)
1327 "Only 2D reductions are supported.");
1328 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1331 if (reductionDims.size() != 1)
1333 warpOp,
"Only 1 reduction dimension is supported.");
1334 int64_t reductionDim = reductionDims[0];
1335 VectorType distributedResultType =
1336 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1337 VectorType resultType = cast<VectorType>(reductionOp.getType());
1338 xegpu::DistributeLayoutAttr sourceLayout =
1341 FailureOr<VectorType> sourceDistTypeOrFailure =
1343 if (
failed(sourceDistTypeOrFailure))
1345 warpOp,
"Failed to distribute the source vector type.");
1346 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1348 bool dim0Distributed =
1349 sourceDistType.getShape()[0] != sourceType.getShape()[0];
1350 bool dim1Distributed =
1351 sourceDistType.getShape()[1] != sourceType.getShape()[1];
1352 if (dim0Distributed && dim1Distributed)
1354 warpOp,
"Expecting source to be distributed in a single dimension.");
1355 int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1356 if (sourceDistDim == -1)
1358 warpOp,
"Expecting a distributed source vector.");
1359 bool resultDistributed =
1360 distributedResultType.getNumElements() < resultType.getNumElements();
1374 bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1375 (sourceDistDim == 1 && reductionDim == 0);
1376 if (isReductionLaneLocal && !resultDistributed)
1378 warpOp,
"Expecting a distributed result for lane-local reduction.");
1380 if (!isReductionLaneLocal && resultDistributed)
1383 "Expecting a broadcasted result for non-lane-local reduction.");
1387 if (isReductionLaneLocal) {
1389 SmallVector<size_t> newRetIndices;
1390 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1391 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1392 {sourceDistType, distributedResultType}, newRetIndices);
1397 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1409 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1485 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1487 PatternRewriter &rewriter)
const override {
1488 OpOperand *yieldOperand =
1496 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1497 VectorType destType =
1498 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1500 xegpu::DistributeLayoutAttr sourceLayout =
1502 xegpu::DistributeLayoutAttr resultLayout =
1505 FailureOr<VectorType> sourceDistType;
1506 Type sourceElemOrDistType;
1510 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1513 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1515 broadcastOp.emitWarning()
1516 <<
"Broadcast input layout must be a slice of result layout.";
1519 if (rankDiff == 0) {
1520 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1521 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1522 broadcastUnitDimsSet.end());
1523 bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
1526 warpOp,
"For same-rank broadcast, source must be identical to "
1527 "adjusted result layouts with unit dims.");
1528 resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
1529 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1534 if (
failed(sourceDistType)) {
1536 warpOp,
"Failed to distribute the source vector type.");
1538 sourceElemOrDistType = sourceDistType.value();
1544 warpOp,
"Broadcast from scalar must not have a layout attribute.");
1546 sourceElemOrDistType = broadcastOp.getSourceType();
1548 FailureOr<VectorType> destDistType =
1550 if (
failed(destDistType)) {
1552 warpOp,
"Failed to distribute the dest vector type.");
1555 SmallVector<size_t> newRetIndices;
1557 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1560 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1562 Value newBroadcast = distributedSource;
1564 if (sourceElemOrDistType != destDistType.value()) {
1567 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1568 destDistType.value(), distributedSource);
1579 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1581 PatternRewriter &rewriter)
const override {
1582 OpOperand *yieldOperand =
1590 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1591 xegpu::DistributeLayoutAttr sourceLayout =
1593 xegpu::DistributeLayoutAttr resultLayout =
1595 if (!sourceLayout || !resultLayout)
1598 "the source or result of shape_cast op lacks distribution layout");
1600 FailureOr<VectorType> sourceDistTypeOrFailure =
1602 shapeCastOp.getSourceVectorType());
1603 if (
failed(sourceDistTypeOrFailure))
1605 warpOp,
"failed to get distributed vector type for source");
1606 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1608 SmallVector<size_t> newRetIndices;
1610 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1613 Value source = newWarpOp.getResult(newRetIndices[0]);
1615 Value newShapeCast = vector::ShapeCastOp::create(
1616 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
// Distributes a vector.extract_strided_slice that is yielded from a
// gpu.warp_execute_on_lane_0 region. (Extraction artifact: embedded
// integers are original source line numbers; gaps = dropped lines.
// Code is byte-identical to the extracted fragment.)
//
// Visible contract: at most ONE dimension of the slice may be
// distributed across lanes; the source layout must have unit lane data;
// both the source size and the slice offset along the distributed dim
// must be multiples of the subgroup size, so per-lane offsets/sizes can
// be rescaled and a lane-local extract_strided_slice emitted outside.
1627struct VectorExtractStridedSliceDistribution
1629 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1631 PatternRewriter &rewriter)
const override {
1632 OpOperand *operand =
1633 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1639 auto distributedType =
1640 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1642 auto extractResultType = cast<VectorType>(operand->
get().
getType());
// Dims where distributed and sequential shapes differ (see
// getDistributedDims in HEAD): these are the lane-distributed dims.
1643 auto distributedDims =
1644 getDistributedDims(extractResultType, distributedType);
1648 VectorType updatedSourceType = extractOp.getSourceVectorType();
1649 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1650 extractOp.getSizes(), [](Attribute attr) { return attr; });
1651 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1652 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1653 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1654 extractOp.getStrides(), [](Attribute attr) { return attr; });
// Pad implicit trailing dims (offsets/sizes/strides may be shorter
// than the source rank) with full-size / default entries.
1658 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1659 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1661 extractOp.getSourceVectorType().getDimSize(i)));
1663 updatedStrides.push_back(
1669 if (distributedDims.size() > 0) {
1670 if (distributedDims.size() != 1)
1672 warpOp,
"Source can not be distributed in multiple dimensions.");
1673 int64_t distributedDim = distributedDims[0];
1674 int sourceDistrDimSize =
1675 extractOp.getSourceVectorType().getShape()[distributedDim];
1677 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1679 warpOp,
"the source of extract_strided_slice op lacks distribution "
1681 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
// Lane-layout entry along the distributed dim is the subgroup size
// for the divisibility checks below.
1684 int subgroupSize = sourceLaneLayout[distributedDim];
1687 if (sourceDistrDimSize % subgroupSize != 0)
1690 "Source size along distributed dimension is not a multiple of "
1692 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1694 if (!llvm::all_of(sourceLaneData, [](int64_t v) {
return v == 1; }))
1696 warpOp,
"Expecting unit lane data in source layout");
1699 int64_t distrDimOffset =
1700 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1701 if (distrDimOffset % subgroupSize != 0)
1703 warpOp,
"Offset along distributed dimension "
1704 "is not a multiple of subgroup size.");
// (Dropped lines here presumably rescale sizes/offsets to the
// per-lane view — TODO confirm against upstream source.)
1706 sourceLayout, extractOp.getSourceVectorType())
1710 distributedType.getDimSize(distributedDim));
1713 updatedOffsets[distributedDim] =
// Re-yield the (distributed) slice source and rebuild the warp op.
1718 SmallVector<size_t> newRetIndices;
1720 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1723 Value source = newWarpOp.getResult(newRetIndices[0]);
1725 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1726 rewriter, extractOp.getLoc(), distributedType, source,
1727 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1728 ArrayAttr::get(rewriter.
getContext(), updatedSizes),
1729 ArrayAttr::get(rewriter.
getContext(), updatedStrides));
// Distributes a vector.insert_strided_slice yielded from a
// gpu.warp_execute_on_lane_0 region. (Extraction artifact: embedded
// integers are original source line numbers; gaps = dropped lines.
// Code is byte-identical to the extracted fragment.)
//
// Visible contract: the op must be the last op before the warp
// terminator; the dest may be distributed in exactly one dimension,
// that dim must map inside the source's (lower) rank, both layouts need
// unit lane data, and the source size and dest offset along the
// distributed dim must be multiples of the subgroup size.
1739struct VectorInsertStridedSliceDistribution
1741 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1743 PatternRewriter &rewriter)
const override {
// Only match when the insert is immediately before the terminator —
// guards against sinking it past later users inside the region.
1744 OpOperand *operand =
getWarpResult(warpOp, [&](Operation *op) {
1746 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1747 warpOp.getTerminator()->getPrevNode() == op;
1754 auto distributedType =
1755 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1757 auto insertResultType = cast<VectorType>(operand->
get().
getType());
1758 auto destDistributedDims =
1759 getDistributedDims(insertResultType, distributedType);
1763 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1764 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1765 VectorType updatedSourceType = insertOp.getSourceVectorType();
1766 VectorType updatedDestType = insertOp.getDestVectorType();
1767 if (destDistributedDims.size() > 0) {
1769 if (destDistributedDims.size() != 1)
1772 "Expecting source to be distributed in a single dimension.");
1773 int64_t destDistributedDim = destDistributedDims[0];
1775 VectorType srcType = insertOp.getSourceVectorType();
1776 VectorType destType = insertOp.getDestVectorType();
// Dest dims align with the TRAILING source dims, so translate the
// dest's distributed dim into source-dim space by the rank delta.
1780 int64_t sourceDistributedDim =
1781 destDistributedDim - (destType.getRank() - srcType.getRank());
1782 if (sourceDistributedDim < 0)
1785 "distributed dimension must be in the last k (i.e. source "
1786 "rank) dims of dest vector");
1787 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1791 if (!destLayout || !sourceLayout ||
1792 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1793 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1795 warpOp,
"the source or dest of insert_strided_slice op lacks "
1796 "distribution layout");
1800 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1803 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1804 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1805 if (!llvm::all_of(destLaneData, [](int64_t v) {
return v == 1; }) ||
1806 !llvm::all_of(sourceLaneData, [](int64_t v) {
return v == 1; }))
1808 warpOp,
"Expecting unit lane data in source and dest layouts");
1810 if (srcDistrDimSize % subgroupSize != 0)
1812 warpOp,
"Distributed dimension size in source is not a multiple of "
1816 int64_t destDistrDimOffset =
1817 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1818 if (destDistrDimOffset % subgroupSize != 0)
1821 "Offset along distributed dimension in dest is not a multiple of "
// (Dropped lines here presumably compute distributed source/dest
// types and rescale the offset — TODO confirm against upstream.)
1825 sourceLayout, insertOp.getSourceVectorType())
1828 destLayout, insertOp.getDestVectorType())
1832 updatedOffsets[destDistributedDim] =
// Re-yield both the value-to-store and the dest, then recreate the
// insert on the per-lane types outside the warp region.
1837 SmallVector<size_t> newRetIndices;
1839 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1840 {updatedSourceType, updatedDestType}, newRetIndices);
1843 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1844 Value dest = newWarpOp.getResult(newRetIndices[1]);
1846 Value newInsertOp = vector::InsertStridedSliceOp::create(
1847 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1848 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1849 insertOp.getStrides());
// Sinks a memref.extract_aligned_pointer_as_index out of a warp region.
// The memref source is uniform, so it is yielded unchanged (same type)
// from the rebuilt warp op and the extract is recreated outside.
// (Extraction artifact: embedded integers are original source line
// numbers; gaps = dropped lines. Code left byte-identical.)
1859struct MemrefExtractAlignedPointerAsIndexDistribution final
1861 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1862 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1863 PatternRewriter &rewriter)
const override {
1864 OpOperand *operand = getWarpResult(
1865 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1869 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1873 SmallVector<size_t> newRetIndices;
1874 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1875 rewriter, warpOp, extractOp.getSource(),
1876 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1878 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1879 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1880 newWarpOp.getResult(newRetIndices[0]));
1881 Value resultVal = newWarpOp.getResult(operandIdx);
// Fragment (struct header dropped by extraction; from the pattern list
// below this is presumably VectorBitcastDistribution — TODO confirm).
// Sinks a vector.bitcast out of a warp region: the source type is
// distributed per its layout, re-yielded from the rebuilt warp op, and
// a lane-local bitcast to the already-distributed result type is
// created outside. Embedded integers are original source line numbers.
1893 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1894 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1895 PatternRewriter &rewriter)
const override {
1896 OpOperand *operand =
1897 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1900 warpOp,
"warp result is not a vector::BitCast op");
// value_or(VectorType()) turns a distribution failure into a null
// type, checked immediately below.
1903 VectorType distributedSourceType =
1906 bitcastOp.getSourceVectorType())
1907 .value_or(VectorType());
1908 if (!distributedSourceType)
1910 bitcastOp,
"Failed to distribute the source vector type in "
1911 "vector::BitCast op");
1912 VectorType distributedResultType =
1913 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1914 SmallVector<size_t> newRetIndices;
1915 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1916 rewriter, warpOp, bitcastOp.getSource(),
1917 TypeRange{distributedSourceType}, newRetIndices);
1919 auto newBitcastOp = vector::BitCastOp::create(
1920 rewriter, newWarpOp.getLoc(), distributedResultType,
1921 newWarpOp.getResult(newRetIndices[0]));
1922 Value distributedVal = newWarpOp.getResult(operandIdx);
// Fragment (struct header dropped; presumably VectorTransposeDistribution
// per the pattern list below — TODO confirm). Sinks a vector.transpose
// out of a warp region. Visible contract: both source and result must
// carry layouts, both vectors must be rank-2, and the result layout must
// be the transpose of the source layout under the op's permutation.
// Embedded integers are original source line numbers; gaps = dropped
// lines. Code left byte-identical.
1937 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1938 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1939 PatternRewriter &rewriter)
const override {
1940 OpOperand *operand =
1941 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1944 warpOp,
"warp result is not a vector::Transpose op");
1947 xegpu::DistributeLayoutAttr sourceLayout =
1949 xegpu::DistributeLayoutAttr resultLayout =
1951 if (!sourceLayout || !resultLayout)
1954 "the source or result vector of the transpose op lacks layout "
1956 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1957 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1960 if (sourceRank != 2 || resultRank != 2)
1962 transposeOp,
"the source or result vector of the transpose op "
1963 "does not have 2D layout");
// The layouts must swap consistently with the data permutation, else
// a lane would not own the same elements after the transpose.
1964 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1966 if (!resultLayout.isTransposeOf(sourceLayout, perm))
1969 "the source or result vector layouts must be 2D transposes of each "
1971 FailureOr<VectorType> distributedSourceTypeOrFailure =
1973 transposeOp.getSourceVectorType());
1974 if (
failed(distributedSourceTypeOrFailure))
1976 transposeOp,
"Failed to distribute the source vector type in "
1977 "vector::Transpose op");
1978 SmallVector<size_t> newRetIndices;
1979 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1980 rewriter, warpOp, transposeOp.getVector(),
1981 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1983 auto newTransposeOp = vector::TransposeOp::create(
1984 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1986 Value distributedVal = newWarpOp.getResult(operandIdx);
// Fragment (struct header dropped; presumably VectorStepSliceDistribution
// per the pattern list below — TODO confirm). Distributes vector.step:
// the result layout must be a 1-D effective SliceAttr; per-lane element
// indices are materialized by computing each lane's data-block start
// coordinates, then adding offsets 1..laneDataBlockLength-1 within each
// block, and packing all values with vector.from_elements.
// Embedded integers are original source line numbers; gaps = dropped
// lines. Code left byte-identical.
1997 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1998 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1999 PatternRewriter &rewriter)
const override {
2000 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2003 warpOp,
"warp result is not a vector::StepOp op");
2006 xegpu::DistributeLayoutAttr resultLayout =
2010 stepOp,
"the result vector of the step op lacks layout "
2012 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2015 stepOp,
"the result layout must be a slice layout");
2016 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2018 stepOp,
"expecting 1 dim in the effective result layout");
2021 auto loc = stepOp.getLoc();
2022 auto stepResultVecTy = stepOp.getResult().getType();
2023 Value distributedVal = warpOp.getResult(operandIdx);
2024 VectorType newVecTy = cast<VectorType>(distributedVal.
getType());
// Per-lane coordinates of each contiguous lane-data block owned by
// this lane, derived from the layout and the sequential shape.
2026 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2027 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2028 if (
failed(laneDataBlockCoords))
2030 stepOp,
"failed to compute lane data block coordinates");
2032 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2033 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2034 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
2035 newVecTy.getNumElements() / laneDataBlockLength);
2036 SmallVector<Value> stepVals;
// For each block: emit the start index, then start + i for the rest
// of the block (offset presumably a constant i — dropped line).
2044 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2045 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2046 stepVals.push_back(laneDataBlockStartCoord);
2047 for (
int i = 1; i < laneDataBlockLength; ++i) {
2049 stepVals.push_back(arith::AddIOp::create(
2050 rewriter, loc, laneDataBlockStartCoord, offset));
2053 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
2054 "Expecting the number of step values to match the number of "
2055 "elements in the vector");
2057 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
// Fragment of the pattern handling xegpu.convert_layout during
// distribution. Visible contract: both input and target layouts must be
// present, and they must be lane-compatible (isCompatibleWith with
// LayoutKind::Lane) — incompatible conversions are rejected as not yet
// supported. (Extraction artifact: embedded integers are original
// source line numbers; gaps = dropped lines. Code left byte-identical.)
2063struct ConvertLayoutDistribution
2068 PatternRewriter &rewriter)
const override {
2069 auto inputLayout = op.getInputLayoutAttr();
2070 auto targetLayout = op.getTargetLayoutAttr();
2072 if (!inputLayout || !targetLayout)
2075 if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
2077 op,
"lowering incompatible convert_layout not yet supported");
// Two fragments fused by the extraction (embedded integers are original
// source line numbers; gaps = dropped lines; code left byte-identical):
//  1. The pass class declaration (lines 2087-2090).
//  2. The body of what is presumably populateXeGPUSubgroupDistributePatterns
//     (declared in the doxygen index below) — it registers the op-specific
//     distribution patterns in two tiers: PatternHierarchy::Regular for the
//     anchor-op patterns, and AboveRegular for the vector shape/slice/
//     broadcast/step sink patterns so they are tried with higher priority.
2087struct XeGPUSubgroupDistributePass final
2089 XeGPUSubgroupDistributePass> {
2090 void runOnOperation()
override;
2096 patterns.
add<CreateNdDescDistribution, StoreNdDistribution,
2097 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2098 GpuBarrierDistribution, VectorMultiReductionDistribution,
2099 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2100 VectorBitcastDistribution, LoadMatrixDistribution,
2101 StoreMatrixDistribution, ConvertLayoutDistribution,
2102 MemrefExtractAlignedPointerAsIndexDistribution>(
2104 PatternHierarchy::Regular);
2108 .
add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2109 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2110 VectorStepSliceDistribution, SinkUniformOps>(
2112 PatternHierarchy::AboveRegular);
// Heavily truncated fragment of the pass driver (embedded integers are
// original source line numbers; many lines dropped; tail of the function
// is missing entirely — code left byte-identical). Visible stages:
//  - move scalar uniform code out of each warp op;
//  - build distributionFn (maps a value's layout to a distribution
//    AffineMap over dims where lane layout > 1 and divides the shape)
//    and a no-op shuffleFn, then run the upstream warp-distribution
//    pattern sets;
//  - post-pass cleanup of the unrealized_conversion_casts inserted for
//    tensor-descriptor SIMT type mismatches (see resolveDistributedTy
//    in HEAD), updating block-argument types through loop-like ops.
2120void XeGPUSubgroupDistributePass::runOnOperation() {
2127 signalPassFailure();
2138 signalPassFailure();
2145 getOperation()->walk([&](Operation *op) {
2146 if (
auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2147 vector::moveScalarUniformCode(warpOp);
2156 auto distributionFn = [](Value val) {
2157 VectorType vecType = dyn_cast<VectorType>(val.getType());
2158 int64_t vecRank = vecType ? vecType.getRank() : 0;
2167 assert(layout.getRank() == vecRank &&
2168 "Expecting vector and layout rank to match");
// A dim is distributed only when >1 lanes map to it AND the shape is
// evenly divisible by the lane count along that dim.
2172 SmallVector<unsigned int> distributedDims;
2173 for (
auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2174 if (v > 1 && vecType.getShape()[i] % v == 0)
2175 distributedDims.push_back(i);
// shuffleFn returns a null Value — lane shuffles are not supported by
// this pass's distribution (upstream patterns must not require them).
2181 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2182 int64_t warpSz) {
return Value(); };
2184 vector::populateDistributeReduction(
2186 PatternHierarchy::Regular);
2188 vector::populatePropagateWarpVectorDistributionPatterns(
2189 patterns, distributionFn, shuffleFn,
2190 PatternHierarchy::Regular);
2192 signalPassFailure();
2202 bool foundWarpOp =
false;
2203 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2213 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2219 Value input = op.getOperand(0);
2220 Value output = op.getResult(0);
2223 xegpu::TensorDescType inputDescType =
2224 mlir::dyn_cast<xegpu::TensorDescType>(input.
getType());
2225 xegpu::TensorDescType outputDescType =
2226 mlir::dyn_cast<xegpu::TensorDescType>(output.
getType());
2227 assert(inputDescType && outputDescType &&
2228 "Unrealized conversion cast must have tensor descriptor types");
// Layout on the input side: the cast strips a layout from a block
// argument — retype the argument and propagate through loop results.
2233 if (inputDescType.getLayout()) {
2234 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2236 argument.setType(output.
getType());
2238 if (
auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2239 argument.getOwner()->getParentOp())) {
2240 auto result = loopOp.getTiedLoopResult(argument);
2249 if (outputDescType.getLayout())
2252 if (op->use_empty())
static Type getElementType(Type type)
Determine the element type of type.
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
virtual int getSubgroupSize() const =0