#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

static const char *const resolveSIMTTypeMismatch =
    "resolve_simt_type_mismatch";
enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
/// Reconcile the type of `orig` with the `expected` type by inserting a cast:
/// a vector.shape_cast for vectors, and a tagged
/// builtin.unrealized_conversion_cast for tensor descriptors.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If orig and expected types are the same, return orig.
  if (orig.getType() == expected)
    return orig;
  // If orig is a vector type, create a shape cast op to reconcile the types.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If orig is a tensor descriptor type, create an unrealized conversion
  // cast op and tag it so it can be resolved after distribution.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
}
/// Return the dimensions in which `distributedType` differs from
/// `originalType`, i.e. the dimensions that were distributed across lanes.
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
                                               VectorType distributedType) {
  assert(originalType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  SmallVector<int64_t> distributedDims;
  for (int64_t i = 0; i < originalType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != originalType.getDimSize(i))
      distributedDims.push_back(i);
  }
  return distributedDims;
}
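
/// Pattern that moves the body of a gpu.func into a
/// gpu.warp_execute_on_lane_0 op so the distribution patterns below can be
/// applied to it. Illustrative sketch (editor's example, not verbatim from a
/// test; the warp size of 16 is hypothetical):
///
///   gpu.func @foo(%arg0: memref<4x8xf32>) {
///     ...
///     gpu.return
///   }
///
/// is rewritten to:
///
///   gpu.func @foo(%arg0: memref<4x8xf32>) {
///     %laneid = gpu.lane_id
///     gpu.warp_execute_on_lane_0(%laneid)[16]
///         args(%arg0 : memref<4x8xf32>) {
///       ...
///     }
///     gpu.return
///   }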
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // The subgroup (warp) size is derived from the chip named by the target
    // attribute.
    std::optional<std::string> chipStr = xegpu::getChipStr(gpuFuncOp);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          gpuFuncOp,
          "Subgroup distribution requires target attribute attached "
          "to set the warp size");
    const int subgroupSize = getUArch(chipStr.value())->getSubgroupSize();
    // If the function only contains a single void return, skip it.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function body was already moved inside a warp op, skip it.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as
    // the original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(rewriter, newGpuFunc.getLoc(),
                                        /*upperBound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the warp op that returns its results.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
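
/// Distribute a create_nd_tdesc op feeding into the yield op of an enclosing
/// gpu.warp_execute_on_lane_0 region. The descriptor is recreated outside the
/// warp op with the layout dropped. Illustrative sketch (editor's example;
/// the layout and shapes are hypothetical):
///
///   #layout = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid)[8]
///           -> (!xegpu.tensor_desc<4x8xf32, #layout>) {
///     %td = xegpu.create_nd_tdesc %arg0 : memref<4x8xf32>
///             -> !xegpu.tensor_desc<4x8xf32, #layout>
///     gpu.yield %td
///   }
///
/// is rewritten to:
///
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[8] -> (...) {
///     ...
///     gpu.yield %arg0 : memref<4x8xf32>
///   }
///   %td = xegpu.create_nd_tdesc %r#0 : memref<4x8xf32>
///           -> !xegpu.tensor_desc<4x8xf32>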
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, descOp->getOperands(),
        descOp.getOperandTypes(), newRetIndices);

    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    // The distributed tensor descriptor type does not carry layout info.
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type with the type expected by the consumers.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
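
/// Distribute a store_nd op at the end of the enclosing
/// gpu.warp_execute_on_lane_0 region. The stored value, the tensor descriptor
/// and any offsets are yielded from the warp op, and the store is recreated
/// outside of it with lane-level types. Illustrative sketch (editor's
/// example; layout and shapes are hypothetical):
///
///   #layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid)[16] {
///     ...
///     xegpu.store_nd %val, %td[%c0, %c0]
///       : vector<4x16xf32>, !xegpu.tensor_desc<4x16xf32, #layout>
///   }
///
/// is rewritten to:
///
///   %r:4 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (...) {
///     gpu.yield %val, %td, %c0, %c0 : ...
///   }
///   xegpu.store_nd %r#0, %r#1[%r#2, %r#3]
///     : vector<4x1xf32>, !xegpu.tensor_desc<4x16xf32>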
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new store op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor is distributed by dropping its layout.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect the offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));

    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    xegpu::removeLayoutAttrs(newStoreOp);
    rewriter.eraseOp(storeOp);
    return success();
  }
};
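
/// Distribute a load_nd op feeding into the yield of the enclosing
/// gpu.warp_execute_on_lane_0 region. The load is recreated outside the warp
/// op and returns each lane's fragment of the loaded vector. Illustrative
/// sketch (editor's example; layout and shapes are hypothetical):
///
///   #layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xf32>) {
///     %ld = xegpu.load_nd %td[%c0, %c0]
///       : !xegpu.tensor_desc<4x16xf32, #layout> -> vector<4x16xf32>
///     gpu.yield %ld
///   }
///
/// is rewritten to:
///
///   %r:3 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (...) {
///     gpu.yield %td, %c0, %c0 : ...
///   }
///   %ld = xegpu.load_nd %r#0[%r#1, %r#2]
///     : !xegpu.tensor_desc<4x16xf32> -> vector<4x1xf32>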
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the same load op is the last operation in the warp op body.
      // This ensures that the load is not sunk earlier, violating any barrier
      // synchronizations.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    std::optional<std::string> chipStr = xegpu::getChipStr(loadOp);
    if (!chipStr)
      return rewriter.notifyMatchFailure(
          loadOp,
          "xegpu::LoadNdOp require target attribute attached to "
          "determine transpose "
          "requirement");
    const auto *targetUArch = getUArch(chipStr.value());
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (requireTranspose(layout, targetUArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the warp
    // op and the (xegpu-specific) distributed type supported by the load op.
    // Resolve such mismatches by inserting a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
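
/// Distribute a dpas op feeding into the yield of the enclosing warp region.
/// The A, B (and optional accumulator) operands are yielded with their
/// distributed types and the dpas is recreated outside the warp op on the
/// per-lane fragments. Illustrative sketch (editor's example; all shapes are
/// hypothetical):
///
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
///     %d = xegpu.dpas %a, %b
///       : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
///     gpu.yield %d
///   }
///
/// is rewritten to:
///
///   %r:3 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (...) {
///     gpu.yield %a, %b : ...
///   }
///   %d = xegpu.dpas %r#1, %r#2
///     : vector<8x1xf16>, vector<16x1xf16> -> vector<8x1xf32>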
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");
    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::LayoutAttr layoutA =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
    xegpu::LayoutAttr layoutB =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
    xegpu::LayoutAttr layoutOut =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The dpas accumulator operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
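
/// Distribute a prefetch_nd op at the end of the enclosing warp region. The
/// tensor descriptor and offsets are yielded and the prefetch is recreated
/// outside the warp op with the layout dropped. Illustrative sketch (editor's
/// example):
///
///   gpu.warp_execute_on_lane_0(%laneid)[16] {
///     xegpu.prefetch_nd %td[%c0] : !xegpu.tensor_desc<32xf16, #layout>
///   }
///
/// is rewritten to:
///
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
///           -> (!xegpu.tensor_desc<32xf16, #layout>, index) {
///     gpu.yield %td, %c0 : ...
///   }
///   xegpu.prefetch_nd %r#0[%r#1] : !xegpu.tensor_desc<32xf16>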
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::map_to_vector(
        offsetsAsValues, [](Value v) { return v.getType(); });

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        prefetchOp->getAttrs());
    xegpu::removeLayoutAttrs(newPrefetchOp);
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Distribute a gpu.barrier at the end of the enclosing warp region. A
/// barrier is uniform across lanes, so it is simply recreated outside the
/// warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    // The last node must be a gpu::BarrierOp.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Move the barrier op outside of the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
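
/// Distribute a scattered store (xegpu.store with a vector of offsets and a
/// mask) that terminates the warp region. The payload, offsets and mask are
/// yielded with distributed types and the store is recreated outside the warp
/// op on 1D per-lane fragments. Illustrative sketch (editor's example; chunk
/// size 1 and all shapes are hypothetical):
///
///   gpu.warp_execute_on_lane_0(%laneid)[16] {
///     xegpu.store %val, %src[%offsets], %mask
///       : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
///   }
///
/// is rewritten to:
///
///   %r:4 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (...) {
///     gpu.yield %val, %src, %offsets, %mask : ...
///   }
///   xegpu.store %r#0, %r#1[%r#2], %r#3
///     : vector<1xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>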
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
    int chunkSize = storeScatterOp.getChunkSize().value_or(1);
    // The effective payload rank is 1 when chunkSize == 1 and 2 otherwise;
    // all leading dimensions must be unit.
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
      if (storeVecTy.getShape()[i] != 1)
        return rewriter.notifyMatchFailure(
            storeScatterOp, "Only unit dimensions allowed for the leading "
                            "dimensions of the store vector!");
    }

    auto layoutPayload =
        xegpu::getDistributeLayoutAttr(storeScatterOp.getValue());
    auto layoutOffsets =
        xegpu::getDistributeLayoutAttr(storeScatterOp.getOffsets());
    auto layoutMask = xegpu::getDistributeLayoutAttr(storeScatterOp.getMask());
    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = storeScatterOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    // The payload, offsets and mask are flattened to 1D vectors for the
    // lane-level op.
    VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
                                             distPayloadTy.getElementType());
    VectorType distOffsetsTy1D = VectorType::get(
        {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
    VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
                                              distMaskTy.getElementType());
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distPayloadVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
    Value distOffsetVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
    Value distMaskVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);

    SmallVector<Value> newStoreScatterOpOperands = {
        distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
        distMaskVal};
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
/// Helper to compute the per-lane coordinates for load/store matrix ops by
/// adding the layout-derived distributed coordinates to the original offsets.
static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
    RewriterBase &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
  SmallVector<Value> newCoords;
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(maybeCoords))
    return newCoords;
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
  return newCoords;
}
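
/// Distribute a load_matrix op that terminates the warp region. Unless the
/// op carries the subgroup_block_io attribute (in which case the offsets are
/// subgroup-uniform), each lane's coordinates are offset by the
/// layout-derived distributed coordinates computed by the helper above.
/// Illustrative sketch (editor's example; shapes, layout and syntax details
/// are hypothetical):
///
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
///     %m = xegpu.load_matrix %mdesc[%c0, %c0] <{layout = #layout}>
///       : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x16xf32>
///     gpu.yield %m
///   }
///
/// is rewritten so the load happens outside the warp op at lane-dependent
/// offsets, each lane producing a vector<2x1xf32> fragment.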
struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadMatrixOp");
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();

    VectorType sgPayloadTy =
        dyn_cast<VectorType>(matrixOp.getResult().getType());
    VectorType warpResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");

    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);

    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");

    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> operands = {matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);

    SmallVector<Type> operandTypes =
        llvm::map_to_vector(operands, [](Value v) { return v.getType(); });

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);

    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newCoords = currentOffsets;
    // With subgroup_block_io the offsets are subgroup-uniform; otherwise
    // compute lane-dependent coordinates from the layout.
    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        newWarpOp.getResult(operandIdx),
        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
    return success();
  }
};
/// Distribute a store_matrix op that terminates the warp region, mirroring
/// the load_matrix pattern above.
struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");

    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);

    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");

    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);

    SmallVector<Type> operandTypes =
        llvm::map_to_vector(operands, [](Value v) { return v.getType(); });
    operandTypes[0] = *distPayloadByWarpOpOrFailure;

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);

    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newCoords = currentOffsets;
    // With subgroup_block_io the offsets are subgroup-uniform; otherwise
    // compute lane-dependent coordinates from the layout.
    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::StoreMatrixOp::create(
        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
        ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(matrixOp);
    return success();
  }
};
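
/// Distribute a gathered load (xegpu.load with a vector of offsets and a
/// mask) that terminates the warp region and feeds the yield. The offsets and
/// mask are yielded with distributed types and the load is recreated outside
/// the warp op on 1D per-lane fragments. Illustrative sketch (editor's
/// example; chunk size 1 and all shapes are hypothetical):
///
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
///     %ld = xegpu.load %src[%offsets], %mask
///       : memref<256xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
///     gpu.yield %ld
///   }
///
/// is rewritten to:
///
///   %r:3 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (...) {
///     gpu.yield %src, %offsets, %mask : ...
///   }
///   %ld = xegpu.load %r#0[%r#1], %r#2
///     : memref<256xf32>, vector<1xindex>, vector<1xi1> -> vector<1xf32>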
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // Check that the yield operand was produced by the *last* scattered
      // load op, to avoid sinking it before barriers (maintain memory order).
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");

    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have a vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    VectorType resultVecTy =
        cast<VectorType>(loadGatherOp.getResult().getType());
    int chunkSize = loadGatherOp.getChunkSize().value_or(1);
    // The effective result rank is 1 when chunkSize == 1 and 2 otherwise;
    // all leading dimensions must be unit.
    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
    for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
      if (resultVecTy.getShape()[i] != 1)
        return rewriter.notifyMatchFailure(
            loadGatherOp, "Only unit dimensions allowed for the leading "
                          "dimensions of the load vector!");
    }

    auto layoutOffsets =
        xegpu::getDistributeLayoutAttr(loadGatherOp.getOffsets());
    auto layoutMask = xegpu::getDistributeLayoutAttr(loadGatherOp.getMask());
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = loadGatherOp->getOperands();
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType distResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
    VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
                                             distOffsetsTy, distMaskTy};
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    // The result, offsets and mask are flattened to 1D vectors for the
    // lane-level op.
    VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
                                             distResultTy.getElementType());
    VectorType distOffsetsTy1D =
        VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
                        distOffsetsTy.getElementType());
    VectorType distMaskTy1D =
        VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
                        distMaskTy.getElementType());
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distOffsetVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
    Value distMaskVal = resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);

    SmallVector<Value> newLoadGatherOperands = {
        newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distMaskVal};

    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the result type and replace all uses.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
    return success();
  }
};
/// Sink an operation whose operands and results are all subgroup-uniform
/// (i.e. carry no distribution layout) out of the warp region.
struct SinkUniformOps final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
    // Ops with regions are not supported.
    if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
      return failure();
    int operandIdx = -1;
    OpOperand *operand = getWarpResult(
        warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
    if (operand) {
      operandIdx = operand->getOperandNumber();
      // A uniform result must be yielded with an unchanged type.
      if (operand->get().getType() != warpOp.getResult(operandIdx).getType())
        return rewriter.notifyMatchFailure(warpOp,
                                           "The op result is not uniform.");
    }
    // All results and operands must be uniform, i.e. they must not carry a
    // distribution layout.
    bool uniformValuesOnly =
        llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
          return !xegpu::getDistributeLayoutAttr(v);
        });
    uniformValuesOnly &=
        llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
          return !xegpu::getDistributeLayoutAttr(opr);
        });
    if (!uniformValuesOnly)
      return rewriter.notifyMatchFailure(warpOp, "Some values are not uniform.");
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands =
        llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
    SmallVector<Type> operandTypes =
        llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    // Clone the op outside the warp region, remapping its operands to the
    // warp op results.
    rewriter.setInsertionPointAfter(newWarpOp);
    IRMapping operandMapper;
    for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
      operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
                        newWarpOp->getResult(newOperandIdx));
    Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
    rewriter.eraseOp(warpRegionPreYieldOp);
    if (operand) {
      assert(operandIdx != -1 && "Expected a warp result for the operation");
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
                                  clonedOp->getResult(0));
    }
    return success();
  }
};
static Value lowerToVectorReductions(TypedValue<VectorType> src, Value acc,
                                     vector::CombiningKind kind,
                                     int64_t reductionDim, Location loc,
                                     PatternRewriter &rewriter) {
  // Expecting a 2D source vector.
  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
  VectorType sourceType = src.getType();
  int64_t sourceH = sourceType.getShape()[0];
  int64_t sourceW = sourceType.getShape()[1];
  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
  // Create a zero-initialized constant vector to hold the reduction result.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value reductionResult = arith::ConstantOp::create(
      rewriter, loc, acc.getType(),
      DenseElementsAttr::get(cast<ShapedType>(acc.getType()), zeroAttr));
  // For each slice of the source, extract the slice vector, do a reduction
  // and insert the reduced value back into the result vector.
  for (int i = 0; i < nSlices; ++i) {
    SmallVector<int64_t> sliceOffsets, sliceSizes;
    if (reductionDim == 1) {
      sliceOffsets = {i, 0};
      sliceSizes = {1, sourceW};
    } else {
      sliceOffsets = {0, i};
      sliceSizes = {sourceH, 1};
    }
    vector::ExtractStridedSliceOp extractOp =
        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                              sliceSizes, {1, 1});
    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get({nSliceElements}, sourceType.getElementType()),
        extractOp.getResult());
    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
    Value reduction = vector::ReductionOp::create(
        rewriter, loc, kind, slice.getResult(), accExtract);
    reductionResult =
        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
  }
  return reductionResult;
}
/// Distribute a vector.multi_reduction op feeding into the yield of the
/// enclosing warp region.
struct VectorMultiReductionDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D reductions are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    // Only a single reduction dimension is supported.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(reductionOp.getSource());

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only single-dimension distribution is supported.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // If the reduction happens along a dimension that each lane owns
    // entirely, the reduction is lane-local and the result is distributed;
    // otherwise the reduction needs cross-lane shuffles and the result is
    // broadcasted.
    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");
    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Handle the lane-local reduction case: the reduction is fully parallel
    // across lanes, so it can be done after the warp op on the distributed
    // source.
    if (isReductionLaneLocal) {
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value result = lowerToVectorReductions(
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          newWarpOp->getResult(newRetIndices[1]),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
      return success();
    }
    // In the non-lane-local case, the reduction is lowered inside the warp
    // op; the per-lane partial results are later combined by the generic
    // vector distribution patterns.
    rewriter.setInsertionPointAfter(reductionOp);
    Value result = lowerToVectorReductions(
        reductionOp.getSource(), reductionOp.getAcc(),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
    return success();
  }
};
/// Distribute a vector.broadcast op feeding into the yield of the enclosing
/// warp region. The source may be a scalar or a vector.
struct VectorBroadcastDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!yieldOperand)
      return failure();
    auto broadcastOp =
        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType destType =
        dyn_cast<VectorType>(broadcastOp.getResult().getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(broadcastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
    FailureOr<VectorType> sourceDistType;
    Type sourceElemOrDistType;
    if (sourceType) {
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      // For rank-increasing broadcasts, the source layout must be a slice of
      // the result layout.
      bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
      if (rankDiff > 0 && !isSliceOf)
        return rewriter.notifyMatchFailure(
            warpOp, "Broadcast input layout must be a slice of result layout.");
      // For same-rank broadcasts, adjust the layouts for the broadcasted
      // unit dims before distribution.
      if (rankDiff == 0) {
        auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                               broadcastUnitDimsSet.end());
        bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
        if (!isEqualTo)
          return rewriter.notifyMatchFailure(
              warpOp, "For same-rank broadcast, source must be identical to "
                      "adjusted result layouts with unit dims.");
        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }
      sourceDistType =
          getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
      if (failed(sourceDistType)) {
        return rewriter.notifyMatchFailure(
            warpOp, "Failed to distribute the source vector type.");
      }
      sourceElemOrDistType = sourceDistType.value();
    } else {
      // A scalar source must not carry a layout.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            warpOp, "Broadcast from scalar must not have a layout attribute.");
      sourceElemOrDistType = broadcastOp.getSourceType();
    }
    FailureOr<VectorType> destDistType =
        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType)) {
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the dest vector type.");
    }

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
    // If the distributed source and dest types already match, the broadcast
    // becomes a no-op after distribution.
    Value newBroadcast = distributedSource;
    if (sourceElemOrDistType != destDistType.value()) {
      newBroadcast =
          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
                                      destDistType.value(), distributedSource);
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newBroadcast);
    return success();
  }
};
/// Distribute a vector.shape_cast op feeding into the yield of the enclosing
/// warp region.
struct VectorShapeCastDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yieldOperand)
      return failure();
    auto shapeCastOp =
        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Create a new warp op that yields the source of the shape_cast op.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new shape_cast op outside the warp op.
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
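
/// Distribute a vector.extract_strided_slice op feeding into the yield of
/// the enclosing warp region. Offsets and sizes along the distributed
/// dimension are rescaled from subgroup space to lane-local space.
/// Illustrative sketch (editor's example; lane_layout = [16, 1] so 16 lanes
/// distribute dim 0, all values hypothetical):
///
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x8xf32>) {
///     %e = vector.extract_strided_slice %src
///       {offsets = [16, 0], sizes = [16, 8], strides = [1, 1]}
///       : vector<32x8xf32> to vector<16x8xf32>
///     gpu.yield %e
///   }
///
/// is rewritten so the extract operates on the distributed source, with the
/// offset 16 rescaled by the 16 lanes to 1:
///
///   %e = vector.extract_strided_slice %src_dist
///     {offsets = [1, 0], sizes = [1, 8], strides = [1, 1]}
///     : vector<2x8xf32> to vector<1x8xf32>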
struct VectorExtractStridedSliceDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    auto extractOp =
        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandIdx = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // Distributed dimensions are those that differ between the original and
    // distributed types.
    auto extractResultType = cast<VectorType>(operand->get().getType());
    auto distributedDims =
        getDistributedDims(extractResultType, distributedType);
    VectorType updatedSourceType = extractOp.getSourceVectorType();
    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        extractOp.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        extractOp.getStrides(), [](Attribute attr) { return attr; });
    // Sizes, offsets and strides may be shorter than the source rank; pad
    // them with full-dim sizes, zero offsets and unit strides.
    int64_t sourceRank = extractOp.getSourceVectorType().getRank();
    for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(rewriter.getI64IntegerAttr(
          extractOp.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
    }
    // If the result is distributed, adjust the source type, offsets and
    // sizes along the distributed dimension.
    if (distributedDims.size() > 0) {
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp, "Source can not be distributed in multiple dimensions.");
      int64_t distributedDim = distributedDims[0];
      int sourceDistrDimSize =
          extractOp.getSourceVectorType().getShape()[distributedDim];
      auto sourceLayout = xegpu::getDistributeLayoutAttr(extractOp.getSource());
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp,
            "the source of extract_strided_slice op lacks distribution "
            "layout");
      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
      int subgroupSize = sourceLaneLayout[distributedDim];
      // The source must be evenly distributed across the lanes.
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Source size along distributed dimension is not a multiple of "
            "subgroup size.");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source layout");
      // The offset along the distributed dim must be a multiple of subgroup
      // size.
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp, "Offset along distributed dimension "
                    "is not a multiple of subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, extractOp.getSourceVectorType())
                              .value();
      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));
      // Rescale the offset to lane-local coordinates.
      updatedOffsets[distributedDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }
    // Yield the source from the warp op and create the distributed extract
    // outside of it.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    Value newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, source,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
    return success();
  }
};
/// Distribute a vector.insert_strided_slice op that terminates the warp
/// region, analogously to the extract case above: offsets along the
/// distributed dimension are rescaled to lane-local coordinates.
struct VectorInsertStridedSliceDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      // Make sure the insert op is the last op in the warp op body.
      return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!operand)
      return failure();
    auto insertOp =
        cast<vector::InsertStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandNumber = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed dimensions are those that differ between the dest and
    // distributed types.
    auto insertResultType = cast<VectorType>(operand->get().getType());
    auto destDistributedDims =
        getDistributedDims(insertResultType, distributedType);
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        insertOp.getOffsets(), [](Attribute attr) { return attr; });
    VectorType updatedSourceType = insertOp.getSourceVectorType();
    VectorType updatedDestType = insertOp.getDestVectorType();
    if (destDistributedDims.size() > 0) {
      // Only single-dimension distribution is supported.
      if (destDistributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting source to be distributed in a single dimension.");
      int64_t destDistributedDim = destDistributedDims[0];
      VectorType srcType = insertOp.getSourceVectorType();
      VectorType destType = insertOp.getDestVectorType();
      // The distributed dimension must be covered by the source vector, i.e.
      // it must be within the last `srcType.getRank()` dims of the dest.
      int64_t sourceDistributedDim =
          destDistributedDim - (destType.getRank() - srcType.getRank());
      if (sourceDistributedDim < 0)
        return rewriter.notifyMatchFailure(
            warpOp, "distributed dimension must be in the last k (i.e. source "
                    "rank) dims of dest vector");
      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
      auto sourceLayout =
          xegpu::getDistributeLayoutAttr(insertOp.getValueToStore());
      auto destLayout = xegpu::getDistributeLayoutAttr(insertOp.getDest());
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source or dest of insert_strided_slice op lacks "
                    "distribution layout");
      int subgroupSize =
          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
      // Lane data must be all unit in both layouts.
      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source and dest layouts");
      if (srcDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Distributed dimension size in source is not a multiple of "
            "subgroup size.");
      // The offset along the distributed dim must be a multiple of subgroup
      // size.
      int64_t destDistrDimOffset =
          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
      if (destDistrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Offset along distributed dimension in dest is not a multiple of "
            "subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, insertOp.getSourceVectorType())
                              .value();
      updatedDestType = getDistVecTypeBasedOnLaneLayout(
                            destLayout, insertOp.getDestVectorType())
                            .value();
      // Rescale the offset to lane-local coordinates.
      updatedOffsets[destDistributedDim] =
          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
    }
    // Yield the source and dest vectors from the warp op and create the
    // distributed insert outside of it.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {updatedSourceType, updatedDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
    Value dest = newWarpOp.getResult(newRetIndices[1]);
    Value newInsertOp = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newInsertOp);
    return success();
  }
};
/// Distribute a memref.extract_aligned_pointer_as_index op feeding into the
/// yield of the enclosing warp region: the pointer is uniform, so the op is
/// simply recreated outside the warp op.
struct MemrefExtractAlignedPointerAsIndexDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value resultVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
    return success();
  }
};
/// Distribute a vector.bitcast op feeding into the yield of the enclosing
/// warp region.
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedSourceType =
        getDistVecTypeBasedOnLaneLayout(
            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
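
/// Distribute a vector.transpose op feeding into the yield of the enclosing
/// warp region. The source and result layouts must be 2D transposes of each
/// other: each lane then holds a transposed fragment, so the transpose can be
/// recreated outside the warp op on the distributed source. Illustrative
/// sketch (editor's example; layouts are hypothetical):
///
///   source layout: lane_layout = [1, 16]; result layout: lane_layout = [16, 1]
///   %t = vector.transpose %src_dist, [1, 0]
///     : vector<1x8xf32> to vector<8x1xf32>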
struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    // Only 2D transposes are supported.
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    // The result layout must be a transpose of the source layout.
    if (!resultLayout.isTransposeOf(sourceLayout, perm))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               GpuBarrierDistribution, VectorMultiReductionDistribution,
               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
               VectorBitcastDistribution, LoadMatrixDistribution,
               StoreMatrixDistribution,
               MemrefExtractAlignedPointerAsIndexDistribution>(
      patterns.getContext(), PatternHierarchy::Regular);
  // These patterns are matched in preference to the regular ones, hence the
  // higher benefit.
  patterns
      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
           SinkUniformOps>(patterns.getContext(),
                           PatternHierarchy::AboveRegular);
}

void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Attach layout attributes to the operations.
  if (!xegpu::recoverTemporaryLayouts(getOperation())) {
    signalPassFailure();
    return;
  }
  // Step 2: Move all operations of a GPU function inside a
  // gpu.warp_execute_on_lane_0 operation.
  {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
    // At this point, the entire function body is inside the warp op. Move
    // any scalar uniform code (e.g. GPU index ops, scalar constants) outside
    // of the warp op to simplify the later distribution.
    getOperation()->walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
        vector::moveScalarUniformCode(warpOp);
    });
  }
  // Step 3: Apply subgroup-to-workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn is used by the generic vector distribution patterns to
  // determine the distributed dimensions of a vector value, based on its
  // lane layout.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    // If no layout is specified, assume the innermost dimension is
    // distributed.
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    if (!layout)
      return AffineMap::getMultiDimMapWithTargets(
          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
    assert(layout.getRank() == vecRank &&
           "Expecting vector and layout rank to match");
    // Dimensions with lane layout size > 1 that evenly divide the vector
    // shape are distributed.
    SmallVector<unsigned int> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1 && vecType.getShape()[i] % v == 0)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // TODO: shuffleFn is not used at this point.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };
  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
                          vector::CombiningKind kind, uint32_t size) {
    // First reduce on a single thread to get the per-lane reduction value.
    Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
    // Parallel reduction using butterfly shuffles.
    for (uint64_t i = 1; i < size; i <<= 1) {
      Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
                                              /*width=*/size,
                                              /*mode=*/gpu::ShuffleMode::XOR)
                           .getShuffleResult();
      laneVal = vector::makeArithReduction(builder, loc, kind, laneVal,
                                           shuffled);
    }
    return laneVal;
  };

  vector::populateDistributeReduction(patterns, warpReduction,
                                      PatternHierarchy::Regular);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn, PatternHierarchy::Regular);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  // Step 4: Clean up the UnrealizedConversionCastOps that were inserted for
  // tensor descriptor type mismatches (e.g. created by upstream scf.for
  // distribution patterns). This cleanup should only run if distribution
  // fully succeeded, i.e. no warp ops remain.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp)
    return;

  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // Only interested in the casts that were added to resolve SIMT type
    // mismatches.
    if (!op->hasAttr(resolveSIMTTypeMismatch))
      return WalkResult::skip();
    Value input = op.getOperand(0);
    Value output = op.getResult(0);
    // Both input and output must have tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");

    // tensor_desc<shape, layout> -> tensor_desc<shape> conversions. These
    // occur inside an scf.for body to resolve a block argument type to the
    // SIMT type.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }

    // tensor_desc<shape> -> tensor_desc<shape, layout> conversions. These
    // occur at the yield of an scf.for body to go back from the SIMT type to
    // the original type.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);

    if (op->use_empty())
      op->erase();
    return WalkResult::advance();
  });
}