37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Debug identifier for this pass; enables LLVM_DEBUG output via
// `-debug-only=xegpu-subgroup-distribute`.
#define DEBUG_TYPE "xegpu-subgroup-distribute"
// Convenience debug-stream helper that tags every message with the pass name.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 "resolve_simt_type_mismatch";
// Relative priority tiers for the distribution rewrite patterns in this pass.
// NOTE(review): presumably consumed as PatternBenefit values so that
// `AboveRegular` patterns are attempted before `Regular` ones — confirm at the
// pattern-registration site (not visible in this chunk).
enum PatternHierarchy : unsigned {
  Regular = 1,      // default tier for ordinary distribution patterns
  AboveRegular = 2, // tier for patterns that must win over Regular ones
};
85static Value resolveDistributedTy(
Value orig, T expected,
91 if (isa<VectorType>(orig.
getType())) {
93 vector::ShapeCastOp::create(rewriter, orig.
getLoc(), expected, orig);
94 return castOp.getResult();
98 if (isa<xegpu::TensorDescType>(orig.
getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.
getLoc(),
102 return castOp.getResult(0);
104 llvm_unreachable(
"Unsupported type for reconciliation");
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
115 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
120 return distributedDims;
153 gpuFuncOp,
"Subgroup distribution requires target attribute attached "
154 "to set the warp size");
155 if (!gpuFuncOp.getBody().hasOneBlock())
157 gpuFuncOp,
"expected gpu.func to have a single block");
160 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
161 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
165 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
166 return isa<gpu::WarpExecuteOnLane0Op>(op);
169 gpu::ReturnOp origReturnOp = dyn_cast_if_present<gpu::ReturnOp>(
170 gpuFuncOp.getBlocks().back().getTerminator());
173 gpuFuncOp,
"expected gpu.func terminator to be gpu.return");
176 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
179 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
181 auto newGpuFunc = gpu::GPUFuncOp::create(
182 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
184 privateAttributionsTypes);
185 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
189 auto laneId = gpu::LaneIdOp::create(
191 mlir::IntegerAttr());
192 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
193 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
194 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
196 newGpuFunc.getArgumentTypes());
197 Block &warpBodyBlock = warpOp.getBodyRegion().
front();
200 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
201 origReturnOp.getOperands());
202 rewriter.
eraseOp(origReturnOp);
205 warpOp.getBodyRegion().begin());
209 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
210 rewriter.
replaceOp(gpuFuncOp, newGpuFunc);
248 using gpu::WarpDistributionPattern::WarpDistributionPattern;
249 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
252 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
255 warpOp,
"warp result is not a xegpu::CreateNdDesc op");
259 xegpu::DistributeLayoutAttr layout = descOp.getType().getLayoutAttr();
262 descOp,
"the tensor descriptor lacks layout attribute");
264 if (descOp.getMixedOffsets().size())
266 descOp,
"xegpu::CreateNdDescOp must not have offsets");
270 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
271 rewriter, warpOp, descOp->getOperands(),
272 descOp.getOperandTypes(), newRetIndices);
275 newRetIndices, [&](
size_t i) {
return newWarpOp.getResult(i); });
277 xegpu::TensorDescType distributedTensorDescTy =
278 descOp.getType().dropLayouts();
280 Value newDescOp = xegpu::CreateNdDescOp::create(
281 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
284 Value distributedVal = newWarpOp.getResult(operandIdx);
287 resolveDistributedTy(newDescOp, distributedVal.
getType(), rewriter);
326 using gpu::WarpDistributionPattern::WarpDistributionPattern;
327 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
329 gpu::YieldOp yield = warpOp.getTerminator();
330 Operation *lastNode = yield->getPrevNode();
331 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
339 "the store op must have offsets");
344 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
345 xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
348 storeOp,
"the source tensor descriptor lacks layout attribute");
350 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
352 if (failed(distributedTypeByWarpOpOrFailure))
354 "Failed to distribute the type");
355 VectorType distributedTypeByWarpOp =
356 distributedTypeByWarpOpOrFailure.value();
360 storeOp.getTensorDesc()};
362 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
363 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
364 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
365 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
375 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
377 if (failed(storeNdDistributedValueTyOrFailure))
379 storeOp,
"Failed to get distributed vector type for the store op");
380 newStoreOperands.push_back(resolveDistributedTy(
381 newWarpOp.getResult(newRetIndices[0]),
382 storeNdDistributedValueTyOrFailure.value(), rewriter));
385 xegpu::TensorDescType distributedTensorDescTy =
386 storeOp.getTensorDescType().dropLayouts();
387 newStoreOperands.push_back(
388 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
389 distributedTensorDescTy, rewriter));
391 for (
size_t i = 2; i < newRetIndices.size(); ++i)
392 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
395 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(),
TypeRange{},
396 newStoreOperands, storeOp->getAttrs());
440 using gpu::WarpDistributionPattern::WarpDistributionPattern;
441 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
444 if (!isa<xegpu::LoadNdOp>(op))
449 gpu::YieldOp yield = warpOp.getTerminator();
450 return yield->getPrevNode() == op;
455 warpOp,
"warp result is not a xegpu::LoadNd op");
461 loadOp,
"xegpu::LoadNdOp require target attribute attached to "
462 "determine transpose "
470 "the load op must have offsets");
476 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
477 xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
480 loadOp,
"the source tensor descriptor lacks layout attribute");
483 VectorType distributedTypeByWarpOp =
484 cast<VectorType>(warpOp.getResult(operandIdx).getType());
489 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
490 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
491 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
492 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
497 FailureOr<VectorType> loadNdDistValueTyOrFailure =
499 if (failed(loadNdDistValueTyOrFailure))
501 loadOp,
"Failed to get distributed vector type for the load op");
502 xegpu::TensorDescType distributedTensorDescTy =
503 loadOp.getTensorDescType().dropLayouts();
507 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
508 distributedTensorDescTy, rewriter)};
510 for (
size_t i = 1; i < newRetIndices.size(); ++i)
511 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
512 auto newLoadOp = xegpu::LoadNdOp::create(
513 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
514 newLoadOperands, loadOp->getAttrs());
520 newLoadOp.setTranspose(
522 Value distributedVal = newWarpOp.getResult(operandIdx);
526 Value tyResolvedVal = resolveDistributedTy(
527 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
568 using gpu::WarpDistributionPattern::WarpDistributionPattern;
569 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
571 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
574 "warp result is not a xegpu::Dpas op");
579 xegpu::LayoutAttr layoutA =
580 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
581 xegpu::LayoutAttr layoutB =
582 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
583 xegpu::LayoutAttr layoutOut =
584 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
586 if (!layoutA || !layoutB || !layoutOut)
589 "the xegpu::Dpas op lacks layout attribute for A, B or output");
591 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
592 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
593 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
594 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
595 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
596 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
598 if (failed(distLhsTypeByWarpOpOrFailure) ||
599 failed(distRhsTypeByWarpOpOrFailure) ||
600 failed(distResultTypeByWarpOpOrFailure))
603 "Failed to distribute the A, B or output types in xegpu::Dpas op");
608 distLhsTypeByWarpOpOrFailure.value(),
609 distRhsTypeByWarpOpOrFailure.value()};
611 if (dpasOp.getAcc()) {
612 newYieldValues.push_back(dpasOp.getAcc());
613 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
617 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
618 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
620 FailureOr<VectorType> expectedDistLhsTyOrFailure =
622 FailureOr<VectorType> expectedDistRhsTyOrFailure =
624 FailureOr<VectorType> expectedDistResultTyOrFailure =
627 if (failed(expectedDistLhsTyOrFailure) ||
628 failed(expectedDistRhsTyOrFailure) ||
629 failed(expectedDistResultTyOrFailure))
632 "Failed to get distributed vector type for the dpas operands.");
639 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
640 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
641 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
643 newDpasOperandExpectedTypes.push_back(distributedResultTy);
645 for (
unsigned i = 0; i < newRetIndices.size(); i++) {
646 newDpasOperands.push_back(
647 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
648 newDpasOperandExpectedTypes[i], rewriter));
650 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
651 distributedResultTy, newDpasOperands,
654 Value distributedVal = newWarpOp.getResult(operandIdx);
657 resolveDistributedTy(newDpasOp.getResult(),
658 distResultTypeByWarpOpOrFailure.value(), rewriter);
693 using gpu::WarpDistributionPattern::WarpDistributionPattern;
694 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
696 gpu::YieldOp yield = warpOp.getTerminator();
697 Operation *lastNode = yield->getPrevNode();
698 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
706 "the prefetch op must have offsets");
712 xegpu::DistributeLayoutAttr layout =
713 prefetchOp.getTensorDescType().getLayoutAttr();
716 prefetchOp,
"the source tensor descriptor lacks layout attribute");
720 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
721 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
723 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
724 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
727 xegpu::TensorDescType newTensorDescTy =
728 prefetchOp.getTensorDescType().dropLayouts();
731 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
733 for (
size_t i = 1; i < newRetIndices.size(); ++i)
734 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
735 Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
736 rewriter, newWarpOp.getLoc(),
TypeRange{}, newPrefetchOperands,
737 prefetchOp->getAttrs());
747 using gpu::WarpDistributionPattern::WarpDistributionPattern;
748 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
750 gpu::YieldOp yield = warpOp.getTerminator();
751 Operation *lastNode = yield->getPrevNode();
753 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
758 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
759 barrierOp->getResultTypes(),
760 barrierOp->getOperands(), barrierOp->getAttrs());
800 using gpu::WarpDistributionPattern::WarpDistributionPattern;
801 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
803 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
804 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
807 Value offsets = storeScatterOp.getOffsets();
808 if (!isa<VectorType>(offsets.
getType()))
810 storeScatterOp,
"Store op must have a vector of offsets argument");
811 VectorType offsetsTy = cast<VectorType>(offsets.
getType());
812 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
813 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
816 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
817 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
820 for (
int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
821 if (storeVecTy.getShape()[i] != 1) {
823 storeScatterOp,
"Only unit dimensions allowed for the leading "
824 "dimensions of the store vector!");
828 auto layoutPayload = storeScatterOp.getLayoutAttr();
831 auto layoutMask = layoutOffsets;
833 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
834 getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
835 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
836 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
837 FailureOr<VectorType> distMaskByWarpOpOrFailure =
838 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
839 if (failed(distStoreVecByWarpOpOrFailure) ||
840 failed(distOffsetsByWarpOpOrFailure) ||
841 failed(distMaskByWarpOpOrFailure)) {
844 "Some vector operands have no layouts, using defaults instead.");
847 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
848 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
849 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
854 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
856 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
857 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
862 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
863 distPayloadTy.getElementType());
865 VectorType distOffsetsTy1D = VectorType::get(
866 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
867 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
868 distMaskTy.getElementType());
871 Value distPayloadVal = resolveDistributedTy(
872 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
873 Value distOffsetVal = resolveDistributedTy(
874 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
875 Value distMaskVal = resolveDistributedTy(
876 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
879 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
882 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
883 rewriter, newWarpOp.getLoc(),
TypeRange{}, newStoreScatterOpOperands,
884 storeScatterOp->getAttrs());
886 rewriter.
eraseOp(storeScatterOp);
896 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
897 if (failed(maybeCoords))
899 assert(maybeCoords.value().size() == 1 &&
900 "Expected one set of distributed offsets");
904 newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
910 using gpu::WarpDistributionPattern::WarpDistributionPattern;
911 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
913 gpu::YieldOp yield = warpOp.getTerminator();
914 Operation *lastNode = yield->getPrevNode();
915 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
920 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
922 if (!producedByLastLoad)
924 warpOp,
"The last op is not xegpu::LoadMatrixOp");
927 VectorType sgPayloadTy =
928 dyn_cast<VectorType>(matrixOp.getResult().getType());
929 VectorType warpResultTy =
930 cast<VectorType>(warpOp.getResult(operandIdx).getType());
933 matrixOp,
"the matrix op payload must be a vector type");
935 auto loc = matrixOp.getLoc();
936 auto offsets = matrixOp.getMixedOffsets();
939 "the load op must have offsets");
943 auto layout = matrixOp.getLayoutAttr();
946 matrixOp,
"the matrix operation lacks layout attribute");
948 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
949 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
950 if (failed(distPayloadByWarpOpOrFailure))
952 matrixOp,
"Failed to distribute matrix op payload based on layout.");
955 const unsigned offsetsStartIdx = operands.size();
956 operands.append(offsetsAsValues);
959 llvm::map_to_vector(operands, [](
Value v) {
return v.
getType(); });
962 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
963 rewriter, warpOp, operands, operandTypes, newRetIndices);
965 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
968 ShapedType::kDynamic);
972 ValueRange(newOperands).drop_front(offsetsStartIdx);
977 if (!matrixOp.getSubgroupBlockIoAttr()) {
978 newCoords = computeDistributedCoordinatesForMatrixOp(
979 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
982 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
983 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
984 newOperands[0],
ValueRange(newCoords), newConstOffsetsAttr,
985 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
988 newWarpOp.getResult(operandIdx),
989 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
996 using gpu::WarpDistributionPattern::WarpDistributionPattern;
997 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
999 gpu::YieldOp yield = warpOp.getTerminator();
1000 Operation *lastNode = yield->getPrevNode();
1001 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
1005 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1008 matrixOp,
"the matrix op payload must be a vector type");
1010 auto loc = matrixOp.getLoc();
1011 auto offsets = matrixOp.getMixedOffsets();
1012 if (offsets.empty())
1014 "the store op must have offsets");
1018 auto layout = matrixOp.getLayoutAttr();
1021 matrixOp,
"the matrix operation lacks layout attribute");
1023 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1024 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
1025 if (failed(distPayloadByWarpOpOrFailure))
1027 matrixOp,
"Failed to distribute matrix op payload based on layout.");
1030 const unsigned offsetsStartIdx = operands.size();
1031 operands.append(offsetsAsValues);
1034 llvm::map_to_vector(operands, [](
Value v) {
return v.
getType(); });
1035 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1038 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1039 rewriter, warpOp, operands, operandTypes, newRetIndices);
1041 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
1044 ShapedType::kDynamic);
1048 ValueRange(newOperands).drop_front(offsetsStartIdx);
1053 if (!matrixOp.getSubgroupBlockIoAttr()) {
1054 newCoords = computeDistributedCoordinatesForMatrixOp(
1055 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1059 xegpu::StoreMatrixOp::create(
1060 rewriter, loc,
TypeRange{}, newOperands[0], newOperands[1],
1062 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
1097 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1098 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1103 return isa<xegpu::LoadGatherOp>(op) &&
1104 warpOp.getTerminator()->getPrevNode() == op;
1106 if (!producedByLastLoad)
1108 warpOp,
"The last op is not xegpu::LoadGatherOp");
1112 Value offsets = loadGatherOp.getOffsets();
1113 if (!isa<VectorType>(offsets.getType()) ||
1114 !isa<VectorType>(loadGatherOp.getMask().getType()))
1117 "Load op must have vector arguments for offsets and mask");
1118 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1119 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1120 VectorType resultVecTy =
1121 cast<VectorType>(loadGatherOp.getResult().getType());
1123 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1124 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1125 for (
int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1126 if (resultVecTy.getShape()[i] != 1) {
1128 loadGatherOp,
"Only unit dimensions allowed for the leading "
1129 "dimensions of the load vector!");
1133 auto layoutPayload = loadGatherOp.getLayoutAttr();
1134 auto layoutOffsets =
1136 auto layoutMask = layoutOffsets;
1138 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1139 getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1140 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1141 getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1142 if (failed(distOffsetsByWarpOpOrFailure) ||
1143 failed(distMaskByWarpOpOrFailure)) {
1146 "Some vector operands have no layouts, using defaults instead.");
1153 VectorType distResultTy =
1154 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1155 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1156 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1159 distOffsetsTy, distMaskTy};
1161 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1162 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1167 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1168 distResultTy.getElementType());
1170 VectorType distOffsetsTy1D =
1171 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1173 VectorType distMaskTy1D =
1174 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1177 Value distOffsetVal = resolveDistributedTy(
1178 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1179 Value distmaskVal = resolveDistributedTy(
1180 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1183 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1185 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1186 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1187 loadGatherOp->getAttrs());
1189 Value distributedVal = newWarpOp.getResult(operandIdx);
1193 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1205 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1206 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1209 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1212 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->
getNumRegions())
1214 int operandIdx = -1;
1217 warpOp, [&](
Operation *op) {
return warpRegionPreYieldOp == op; });
1222 warpOp.getResult(operandIdx).getType())
1224 "The op result is not uniform.");
1228 bool uniformValuesOnly =
1230 return !xegpu::getDistributeLayoutAttr(v);
1232 uniformValuesOnly &=
1234 return !xegpu::getDistributeLayoutAttr(opr);
1236 if (!uniformValuesOnly)
1238 "Some values are not uniform.");
1241 llvm::to_vector_of<Value>(warpRegionPreYieldOp->
getOperands());
1244 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1245 rewriter, warpOp, operands, operandTypes, newRetIndices);
1249 for (
auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1250 operandMapper.
map(warpRegionPreYieldOp->
getOperand(oldOperandIdx),
1251 newWarpOp->getResult(newOperandIdx));
1252 Operation *clonedOp = rewriter.
clone(*warpRegionPreYieldOp, operandMapper);
1254 rewriter.
eraseOp(warpRegionPreYieldOp);
1256 assert(operandIdx != -1 &&
"Expected a warp result for the operation");
1320 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1321 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1324 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1330 VectorType sourceType = reductionOp.getSourceVectorType();
1331 int64_t sourceRank = sourceType.getRank();
1335 "Only 2D+ reductions are supported.");
1337 for (
int64_t i = 0; i < sourceRank - 2; ++i) {
1338 if (sourceType.getShape()[i] != 1)
1340 warpOp,
"Only unit dimensions allowed for the leading dimensions.");
1343 int64_t rowIdx = sourceRank - 2;
1344 int64_t columnIdx = sourceRank - 1;
1346 if (reductionDims.size() != 1)
1348 "Only 1 reduction dim is supported.");
1349 int64_t reductionDim = reductionDims[0];
1351 if (reductionDim != rowIdx && reductionDim != columnIdx)
1353 warpOp,
"Reduction dim must be among the last 2 dimensions.");
1354 VectorType distributedResultType =
1355 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1356 VectorType resultType = cast<VectorType>(reductionOp.getType());
1357 xegpu::DistributeLayoutAttr sourceLayout =
1360 FailureOr<VectorType> sourceDistTypeOrFailure =
1361 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1362 if (failed(sourceDistTypeOrFailure))
1364 warpOp,
"Failed to distribute the source vector type.");
1365 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1367 bool rowDistributed =
1368 sourceDistType.getShape()[rowIdx] != sourceType.getShape()[rowIdx];
1369 bool columnDistributed = sourceDistType.getShape()[columnIdx] !=
1370 sourceType.getShape()[columnIdx];
1371 if (rowDistributed && columnDistributed)
1373 warpOp,
"Expecting source to be distributed in a single dimension.");
1375 rowDistributed ? rowIdx : (columnDistributed ? columnIdx : -1);
1376 if (sourceDistDim == -1)
1378 warpOp,
"Expecting a distributed source vector.");
1379 bool resultDistributed =
1380 distributedResultType.getNumElements() < resultType.getNumElements();
1394 bool isReductionLaneLocal =
1395 (sourceDistDim == rowIdx && reductionDim == columnIdx) ||
1396 (sourceDistDim == columnIdx && reductionDim == rowIdx);
1397 if (isReductionLaneLocal && !resultDistributed)
1399 warpOp,
"Expecting a distributed result for lane-local reduction.");
1401 if (!isReductionLaneLocal && resultDistributed)
1404 "Expecting a broadcasted result for non-lane-local reduction.");
1408 if (isReductionLaneLocal) {
1411 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1412 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1413 {sourceDistType, distributedResultType}, newRetIndices);
1418 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1430 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1506 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1517 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1518 VectorType destType =
1519 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1521 xegpu::DistributeLayoutAttr sourceLayout =
1523 xegpu::DistributeLayoutAttr resultLayout =
1526 FailureOr<VectorType> sourceDistType;
1527 Type sourceElemOrDistType;
1531 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1534 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1536 broadcastOp.emitWarning()
1537 <<
"Broadcast input layout must be a slice of result layout.";
1540 if (rankDiff == 0) {
1541 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1543 broadcastUnitDimsSet.end());
1544 assert(sourceLayout.isEqualTo(
1545 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1546 "The sg_data for unit dimensions should be set as 1");
1547 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1551 getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1552 if (failed(sourceDistType)) {
1554 warpOp,
"Failed to distribute the source vector type.");
1556 sourceElemOrDistType = sourceDistType.value();
1562 warpOp,
"Broadcast from scalar must not have a layout attribute.");
1564 sourceElemOrDistType = broadcastOp.getSourceType();
1566 FailureOr<VectorType> destDistType =
1567 getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
1568 if (failed(destDistType)) {
1570 warpOp,
"Failed to distribute the dest vector type.");
1575 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1578 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1580 Value newBroadcast = distributedSource;
1582 if (sourceElemOrDistType != destDistType.value()) {
1585 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1586 destDistType.value(), distributedSource);
1597 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1608 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1609 xegpu::DistributeLayoutAttr sourceLayout =
1611 xegpu::DistributeLayoutAttr resultLayout =
1613 if (!sourceLayout || !resultLayout)
1616 "the source or result of shape_cast op lacks distribution layout");
1618 FailureOr<VectorType> sourceDistTypeOrFailure =
1619 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1620 shapeCastOp.getSourceVectorType());
1621 if (failed(sourceDistTypeOrFailure))
1623 warpOp,
"failed to get distributed vector type for source");
1624 VectorType sourceDistType = sourceDistTypeOrFailure.value();
1628 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1631 Value source = newWarpOp.getResult(newRetIndices[0]);
1633 Value newShapeCast = vector::ShapeCastOp::create(
1634 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1645struct VectorExtractStridedSliceDistribution
1647 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1651 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1657 auto distributedType =
1658 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1660 auto extractResultType = cast<VectorType>(operand->
get().
getType());
1661 auto distributedDims =
1662 getDistributedDims(extractResultType, distributedType);
1666 VectorType updatedSourceType = extractOp.getSourceVectorType();
1668 extractOp.getSizes(), [](
Attribute attr) { return attr; });
1670 extractOp.getOffsets(), [](
Attribute attr) { return attr; });
1672 extractOp.getStrides(), [](
Attribute attr) { return attr; });
1676 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1677 for (
int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1679 extractOp.getSourceVectorType().getDimSize(i)));
1681 updatedStrides.push_back(
1687 if (distributedDims.size() > 0) {
1688 if (distributedDims.size() != 1)
1690 warpOp,
"Source can not be distributed in multiple dimensions.");
1691 int64_t distributedDim = distributedDims[0];
1692 int sourceDistrDimSize =
1693 extractOp.getSourceVectorType().getShape()[distributedDim];
1695 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1697 warpOp,
"the source of extract_strided_slice op lacks distribution "
1699 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1702 int subgroupSize = sourceLaneLayout[distributedDim];
1705 if (sourceDistrDimSize % subgroupSize != 0)
1708 "Source size along distributed dimension is not a multiple of "
1710 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1712 if (!llvm::all_of(sourceLaneData, [](
int64_t v) {
return v == 1; }))
1714 warpOp,
"Expecting unit lane data in source layout");
1718 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1719 if (distrDimOffset % subgroupSize != 0)
1721 warpOp,
"Offset along distributed dimension "
1722 "is not a multiple of subgroup size.");
1723 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1724 sourceLayout, extractOp.getSourceVectorType())
1728 distributedType.getDimSize(distributedDim));
1731 updatedOffsets[distributedDim] =
1738 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1741 Value source = newWarpOp.getResult(newRetIndices[0]);
1743 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1744 rewriter, extractOp.getLoc(), distributedType, source,
1745 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1746 ArrayAttr::get(rewriter.
getContext(), updatedSizes),
1747 ArrayAttr::get(rewriter.
getContext(), updatedStrides));
1757struct VectorInsertStridedSliceDistribution
1759 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1764 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1765 warpOp.getTerminator()->getPrevNode() == op;
1772 auto distributedType =
1773 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1775 auto insertResultType = cast<VectorType>(operand->
get().
getType());
1776 auto destDistributedDims =
1777 getDistributedDims(insertResultType, distributedType);
1782 insertOp.getOffsets(), [](
Attribute attr) { return attr; });
1783 VectorType updatedSourceType = insertOp.getSourceVectorType();
1784 VectorType updatedDestType = insertOp.getDestVectorType();
1785 if (destDistributedDims.size() > 0) {
1787 if (destDistributedDims.size() != 1)
1790 "Expecting source to be distributed in a single dimension.");
1791 int64_t destDistributedDim = destDistributedDims[0];
1793 VectorType srcType = insertOp.getSourceVectorType();
1794 VectorType destType = insertOp.getDestVectorType();
1798 int64_t sourceDistributedDim =
1799 destDistributedDim - (destType.getRank() - srcType.getRank());
1800 if (sourceDistributedDim < 0)
1803 "distributed dimension must be in the last k (i.e. source "
1804 "rank) dims of dest vector");
1805 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1809 if (!destLayout || !sourceLayout ||
1810 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1811 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1813 warpOp,
"the source or dest of insert_strided_slice op lacks "
1814 "distribution layout");
1818 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1821 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1822 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1823 if (!llvm::all_of(destLaneData, [](
int64_t v) {
return v == 1; }) ||
1824 !llvm::all_of(sourceLaneData, [](
int64_t v) {
return v == 1; }))
1826 warpOp,
"Expecting unit lane data in source and dest layouts");
1828 if (srcDistrDimSize % subgroupSize != 0)
1830 warpOp,
"Distributed dimension size in source is not a multiple of "
1835 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1836 if (destDistrDimOffset % subgroupSize != 0)
1839 "Offset along distributed dimension in dest is not a multiple of "
1842 updatedSourceType = getDistVecTypeBasedOnLaneLayout(
1843 sourceLayout, insertOp.getSourceVectorType())
1845 updatedDestType = getDistVecTypeBasedOnLaneLayout(
1846 destLayout, insertOp.getDestVectorType())
1850 updatedOffsets[destDistributedDim] =
1857 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1858 {updatedSourceType, updatedDestType}, newRetIndices);
1861 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1862 Value dest = newWarpOp.getResult(newRetIndices[1]);
1864 Value newInsertOp = vector::InsertStridedSliceOp::create(
1865 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1866 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1867 insertOp.getStrides());
1877struct MemrefExtractAlignedPointerAsIndexDistribution final
1879 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1880 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1883 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1887 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1892 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1893 rewriter, warpOp, extractOp.getSource(),
1894 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1896 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1897 rewriter, newWarpOp.getLoc(), extractOp.
getType(),
1898 newWarpOp.getResult(newRetIndices[0]));
1899 Value resultVal = newWarpOp.getResult(operandIdx);
1911 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1912 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1915 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1918 warpOp,
"warp result is not a vector::BitCast op");
1921 VectorType distributedSourceType =
1922 getDistVecTypeBasedOnLaneLayout(
1924 bitcastOp.getSourceVectorType())
1925 .value_or(VectorType());
1926 if (!distributedSourceType)
1928 bitcastOp,
"Failed to distribute the source vector type in "
1929 "vector::BitCast op");
1930 VectorType distributedResultType =
1931 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1933 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1934 rewriter, warpOp, bitcastOp.getSource(),
1935 TypeRange{distributedSourceType}, newRetIndices);
1937 auto newBitcastOp = vector::BitCastOp::create(
1938 rewriter, newWarpOp.getLoc(), distributedResultType,
1939 newWarpOp.getResult(newRetIndices[0]));
1940 Value distributedVal = newWarpOp.getResult(operandIdx);
1955 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1956 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1959 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1962 warpOp,
"warp result is not a vector::Transpose op");
1965 xegpu::DistributeLayoutAttr sourceLayout =
1967 xegpu::DistributeLayoutAttr resultLayout =
1969 if (!sourceLayout || !resultLayout)
1972 "the source or result vector of the transpose op lacks layout "
1974 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1975 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1978 if (sourceRank != 2 || resultRank != 2)
1980 transposeOp,
"the source or result vector of the transpose op "
1981 "does not have 2D layout");
1984 if (!resultLayout.isTransposeOf(sourceLayout, perm,
1988 "the source or result vector layouts must be 2D transposes of each "
1990 FailureOr<VectorType> distributedSourceTypeOrFailure =
1991 getDistVecTypeBasedOnLaneLayout(sourceLayout,
1992 transposeOp.getSourceVectorType());
1993 if (failed(distributedSourceTypeOrFailure))
1995 transposeOp,
"Failed to distribute the source vector type in "
1996 "vector::Transpose op");
1998 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1999 rewriter, warpOp, transposeOp.getVector(),
2000 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
2002 auto newTransposeOp = vector::TransposeOp::create(
2003 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2005 Value distributedVal = newWarpOp.getResult(operandIdx);
2016 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2017 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2019 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2022 warpOp,
"warp result is not a vector::StepOp op");
2025 xegpu::DistributeLayoutAttr resultLayout =
2029 stepOp,
"the result vector of the step op lacks layout "
2031 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2034 stepOp,
"the result layout must be a slice layout");
2035 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2037 stepOp,
"expecting 1 dim in the effective result layout");
2040 auto loc = stepOp.getLoc();
2041 auto stepResultVecTy = stepOp.getResult().getType();
2042 Value distributedVal = warpOp.getResult(operandIdx);
2043 VectorType newVecTy = cast<VectorType>(distributedVal.
getType());
2045 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2046 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2047 if (failed(laneDataBlockCoords))
2049 stepOp,
"failed to compute lane data block coordinates");
2051 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2052 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2053 assert(
static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
2054 newVecTy.getNumElements() / laneDataBlockLength);
2063 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2064 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2065 stepVals.push_back(laneDataBlockStartCoord);
2066 for (
int i = 1; i < laneDataBlockLength; ++i) {
2068 stepVals.push_back(arith::AddIOp::create(
2069 rewriter, loc, laneDataBlockStartCoord, offset));
2072 assert(
static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
2073 "Expecting the number of step values to match the number of "
2074 "elements in the vector");
2076 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2082struct ConvertLayoutDistribution
2088 auto inputLayout = op.getInputLayoutAttr();
2089 auto targetLayout = op.getTargetLayoutAttr();
2090 Type valType = op.getResult().getType();
2092 if (!inputLayout || !targetLayout)
2099 auto resShape = cast<VectorType>(valType).getShape();
2101 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
2104 op,
"lowering incompatible convert_layout not yet supported");
2114struct XeGPUSubgroupDistributePass final
2115 :
public xegpu::impl::XeGPUSubgroupDistributeBase<
2116 XeGPUSubgroupDistributePass> {
2117 void runOnOperation()
override;
2123 patterns.
add<CreateNdDescDistribution, StoreNdDistribution,
2124 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2125 GpuBarrierDistribution, VectorMultiReductionDistribution,
2126 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2127 VectorBitcastDistribution, LoadMatrixDistribution,
2128 StoreMatrixDistribution, ConvertLayoutDistribution,
2129 MemrefExtractAlignedPointerAsIndexDistribution>(
2131 PatternHierarchy::Regular);
2135 .
add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2136 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2137 VectorStepSliceDistribution, SinkUniformOps>(
2139 PatternHierarchy::AboveRegular);
2147void XeGPUSubgroupDistributePass::runOnOperation() {
2154 signalPassFailure();
2165 signalPassFailure();
2172 getOperation()->walk([&](Operation *op) {
2173 if (
auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2174 vector::moveScalarUniformCode(warpOp);
2183 auto distributionFn = [](Value val) {
2184 VectorType vecType = dyn_cast<VectorType>(val.getType());
2185 int64_t vecRank = vecType ? vecType.getRank() : 0;
2194 assert(layout.getRank() == vecRank &&
2195 "Expecting vector and layout rank to match");
2199 SmallVector<unsigned int> distributedDims;
2200 for (
auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2201 if (v > 1 && vecType.getShape()[i] % v == 0)
2202 distributedDims.push_back(i);
2208 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2209 int64_t warpSz) {
return Value(); };
2211 vector::populateDistributeReduction(
2213 PatternHierarchy::Regular);
2215 vector::populatePropagateWarpVectorDistributionPatterns(
2216 patterns, distributionFn, shuffleFn,
2217 PatternHierarchy::Regular);
2219 signalPassFailure();
2229 bool foundWarpOp =
false;
2230 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2240 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2246 Value input = op.getOperand(0);
2247 Value output = op.getResult(0);
2250 xegpu::TensorDescType inputDescType =
2251 mlir::dyn_cast<xegpu::TensorDescType>(input.
getType());
2252 xegpu::TensorDescType outputDescType =
2253 mlir::dyn_cast<xegpu::TensorDescType>(output.
getType());
2254 assert(inputDescType && outputDescType &&
2255 "Unrealized conversion cast must have tensor descriptor types");
2260 if (inputDescType.getLayout()) {
2261 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2263 argument.setType(output.
getType());
2265 if (
auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2266 argument.getOwner()->getParentOp())) {
2267 auto result = loopOp.getTiedLoopResult(argument);
2276 if (outputDescType.getLayout())
2279 if (op->use_empty())
static Type getElementType(Type type)
Determine the element type of type.
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.
virtual int getSubgroupSize() const =0