37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-subgroup-distribute"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 "resolve_simt_type_mismatch";
// Relative priority tiers for the distribution patterns in this pass.
// `AboveRegular` (2) outranks `Regular` (1); presumably these values feed the
// pattern-benefit argument at registration time — TODO confirm against the
// (not visible here) populate/registration site.
68enum PatternHierarchy :
unsigned { Regular = 1, AboveRegular = 2 };
// Reconcile a value `orig` produced during distribution with the type
// `expected` demanded by its consumer. Only two kinds of type mismatch are
// supported (anything else is a programmer error, see llvm_unreachable).
// NOTE(review): fragment — interior lines are elided in this view.
85static Value resolveDistributedTy(
Value orig, T expected,
// Vector-typed values: bridge the mismatch with a vector.shape_cast.
91 if (isa<VectorType>(orig.
getType())) {
93 vector::ShapeCastOp::create(rewriter, orig.
getLoc(), expected, orig);
94 return castOp.getResult();
// Tensor-descriptor values: insert an unrealized_conversion_cast as a
// placeholder (presumably cleaned up later — the file defines a
// "resolve_simt_type_mismatch" marker above; confirm against the driver).
98 if (isa<xegpu::TensorDescType>(orig.
getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.
getLoc(),
102 return castOp.getResult(0);
// No other type may reach this helper.
104 llvm_unreachable(
"Unsupported type for reconciliation");
// Collect the indices of the dimensions whose size differs between the
// sequential (`originalType`) and distributed (`distributedType`) vector
// types — i.e. the dimensions that were split across lanes. Requires both
// types to have the same rank (asserted below).
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
115 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
// A size mismatch at dim i means dim i was distributed.
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
120 return distributedDims;
// Rewrite a gpu.func so its entire body is wrapped inside a single
// gpu.warp_execute_on_lane_0 region, making the later per-op distribution
// patterns applicable. Fails gracefully (match failures) on unsupported
// functions. NOTE(review): fragment — interior lines are elided in this view.
153 gpuFuncOp,
"Subgroup distribution requires target attribute attached "
154 "to set the warp size");
// Only single-block function bodies are supported.
155 if (!gpuFuncOp.getBody().hasOneBlock())
157 gpuFuncOp,
"expected gpu.func to have a single block");
// Skip trivially-empty functions (body is just a bare gpu.return).
160 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
161 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
// Skip functions that already contain a warp_execute_on_lane_0 op
// (already processed or hand-written).
165 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
166 return isa<gpu::WarpExecuteOnLane0Op>(op);
169 gpu::ReturnOp origReturnOp = dyn_cast_if_present<gpu::ReturnOp>(
170 gpuFuncOp.getBlocks().back().getTerminator());
173 gpuFuncOp,
"expected gpu.func terminator to be gpu.return");
// Clone the function shell, preserving workgroup/private attributions
// and all attributes.
176 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributionBBArgs(),
179 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
181 auto newGpuFunc = gpu::GPUFuncOp::create(
182 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
184 privateAttributionsTypes);
185 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
// Materialize the lane id and open a warp op returning the function's
// results; the warp region receives the original arguments.
189 auto laneId = gpu::LaneIdOp::create(
191 mlir::IntegerAttr());
192 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
193 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
194 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
196 newGpuFunc.getArgumentTypes());
197 Block &warpBodyBlock = warpOp.getBodyRegion().
front();
// Replace the original gpu.return with a gpu.yield of its operands so the
// warp op carries the results out of the region.
200 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
201 origReturnOp.getOperands());
202 rewriter.
eraseOp(origReturnOp);
205 warpOp.getBodyRegion().begin());
// The new function returns the warp op's results; swap in the new func.
209 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
210 rewriter.
replaceOp(gpuFuncOp, newGpuFunc);
// Distribute a xegpu.create_nd_tdesc that is yielded from the warp region:
// hoist the descriptor creation outside the warp op, dropping the layout
// from its tensor-descriptor type. NOTE(review): fragment — interior lines
// are elided in this view.
248 using gpu::WarpDistributionPattern::WarpDistributionPattern;
249 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
252 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
255 warpOp,
"warp result is not a xegpu::CreateNdDesc op");
// The descriptor must carry a layout attribute to be distributable.
259 xegpu::DistributeLayoutAttr layout = descOp.getType().getLayoutAttr();
262 descOp,
"the tensor descriptor lacks layout attribute");
// Yield the create_nd_tdesc operands out of the warp region so the op can
// be re-created after it.
265 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
266 rewriter, warpOp, descOp->getOperands(),
267 descOp.getOperandTypes(), newRetIndices);
270 newRetIndices, [&](
size_t i) {
return newWarpOp.getResult(i); });
// The distributed descriptor type is the original minus its layout attr.
272 xegpu::TensorDescType distributedTensorDescTy =
273 descOp.getType().dropLayouts();
275 Value newDescOp = xegpu::CreateNdDescOp::create(
276 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
// Reconcile the new descriptor's type with the warp result's type.
279 Value distributedVal = newWarpOp.getResult(operandIdx);
282 resolveDistributedTy(newDescOp, distributedVal.
getType(), rewriter);
// Distribute a xegpu.store_nd that is the last op before the warp region's
// terminator: sink it out of the warp op, distributing the stored vector by
// the tensor descriptor's layout. NOTE(review): fragment — interior lines
// are elided in this view.
321 using gpu::WarpDistributionPattern::WarpDistributionPattern;
322 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
// Only handle the store immediately preceding the gpu.yield terminator.
324 gpu::YieldOp yield = warpOp.getTerminator();
325 Operation *lastNode = yield->getPrevNode();
326 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
334 "the store op must have offsets");
339 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
340 xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
343 storeOp,
"the source tensor descriptor lacks layout attribute");
// Type the stored value as the warp op will yield it (lane-distributed).
345 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
347 if (failed(distributedTypeByWarpOpOrFailure))
349 "Failed to distribute the type");
350 VectorType distributedTypeByWarpOp =
351 distributedTypeByWarpOpOrFailure.value();
// Yield value, descriptor, and offsets out of the warp region.
355 storeOp.getTensorDesc()};
357 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
358 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
359 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
360 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
// Reconcile the yielded payload to the store op's own distributed type
// (may differ from the warp-yield type).
370 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
372 if (failed(storeNdDistributedValueTyOrFailure))
374 storeOp,
"Failed to get distributed vector type for the store op");
375 newStoreOperands.push_back(resolveDistributedTy(
376 newWarpOp.getResult(newRetIndices[0]),
377 storeNdDistributedValueTyOrFailure.value(), rewriter));
// Descriptor loses its layout attribute once distributed.
380 xegpu::TensorDescType distributedTensorDescTy =
381 storeOp.getTensorDescType().dropLayouts();
382 newStoreOperands.push_back(
383 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
384 distributedTensorDescTy, rewriter));
// Remaining returns (indices >= 2) are the offset values, passed through.
386 for (
size_t i = 2; i < newRetIndices.size(); ++i)
387 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
390 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(),
TypeRange{},
391 newStoreOperands, storeOp->getAttrs());
// Distribute a xegpu.load_nd yielded from the warp region: sink the load
// after the warp op, loading the lane-distributed fragment instead of the
// full subgroup vector. NOTE(review): fragment — interior lines are elided
// in this view.
435 using gpu::WarpDistributionPattern::WarpDistributionPattern;
436 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
// Match only a load_nd that is the op immediately before the terminator.
439 if (!isa<xegpu::LoadNdOp>(op))
444 gpu::YieldOp yield = warpOp.getTerminator();
445 return yield->getPrevNode() == op;
450 warpOp,
"warp result is not a xegpu::LoadNd op");
// A target attribute is needed to decide transpose handling (per message).
456 loadOp,
"xegpu::LoadNdOp require target attribute attached to "
457 "determine transpose "
465 "the load op must have offsets");
471 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
472 xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
475 loadOp,
"the source tensor descriptor lacks layout attribute");
// The warp op has already fixed the distributed result type.
478 VectorType distributedTypeByWarpOp =
479 cast<VectorType>(warpOp.getResult(operandIdx).getType());
// Yield descriptor and offsets out of the warp region.
484 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
485 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
486 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
487 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
492 FailureOr<VectorType> loadNdDistValueTyOrFailure =
494 if (failed(loadNdDistValueTyOrFailure))
496 loadOp,
"Failed to get distributed vector type for the load op");
// Rebuild the load outside the warp op with a layout-free descriptor.
497 xegpu::TensorDescType distributedTensorDescTy =
498 loadOp.getTensorDescType().dropLayouts();
502 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
503 distributedTensorDescTy, rewriter)};
505 for (
size_t i = 1; i < newRetIndices.size(); ++i)
506 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
507 auto newLoadOp = xegpu::LoadNdOp::create(
508 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
509 newLoadOperands, loadOp->getAttrs());
515 newLoadOp.setTranspose(
// Reconcile the new load's type with the warp result consumers expect.
517 Value distributedVal = newWarpOp.getResult(operandIdx);
521 Value tyResolvedVal = resolveDistributedTy(
522 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
// Distribute a xegpu.dpas yielded from the warp region. A, B and the
// accumulator/result each carry their own layout attribute; all three must
// be present and distributable. NOTE(review): fragment — interior lines are
// elided in this view.
563 using gpu::WarpDistributionPattern::WarpDistributionPattern;
564 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
566 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
569 "warp result is not a xegpu::Dpas op");
// Per-operand layouts: A, B, and C/D (acc + result).
574 xegpu::LayoutAttr layoutA =
575 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
576 xegpu::LayoutAttr layoutB =
577 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
578 xegpu::LayoutAttr layoutOut =
579 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
581 if (!layoutA || !layoutB || !layoutOut)
584 "the xegpu::Dpas op lacks layout attribute for A, B or output");
// Lane-level distributed types as the warp op will yield them.
586 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
587 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
588 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
589 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
590 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
591 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
593 if (failed(distLhsTypeByWarpOpOrFailure) ||
594 failed(distRhsTypeByWarpOpOrFailure) ||
595 failed(distResultTypeByWarpOpOrFailure))
598 "Failed to distribute the A, B or output types in xegpu::Dpas op");
// Yield A and B (and the accumulator, when present) from the warp region.
603 distLhsTypeByWarpOpOrFailure.value(),
604 distRhsTypeByWarpOpOrFailure.value()};
606 if (dpasOp.getAcc()) {
607 newYieldValues.push_back(dpasOp.getAcc());
608 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
611 SmallVector<size_t> newRetIndices;
612 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
613 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
// The dpas op itself may expect different distributed shapes than the warp
// yield produced; compute the expected types for reconciliation.
615 FailureOr<VectorType> expectedDistLhsTyOrFailure =
617 FailureOr<VectorType> expectedDistRhsTyOrFailure =
619 FailureOr<VectorType> expectedDistResultTyOrFailure =
622 if (
failed(expectedDistLhsTyOrFailure) ||
623 failed(expectedDistRhsTyOrFailure) ||
624 failed(expectedDistResultTyOrFailure))
627 "Failed to get distributed vector type for the dpas operands.");
630 SmallVector<Value> newDpasOperands;
631 SmallVector<VectorType> newDpasOperandExpectedTypes;
// Expected types are ordered A, B, then (optional) accumulator.
634 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
635 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
636 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
638 newDpasOperandExpectedTypes.push_back(distributedResultTy);
// Reconcile each warp-op result to the shape the dpas op expects.
640 for (
unsigned i = 0; i < newRetIndices.size(); i++) {
641 newDpasOperands.push_back(
642 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
643 newDpasOperandExpectedTypes[i], rewriter));
645 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
646 distributedResultTy, newDpasOperands,
// Reconcile the new result back to the type the warp result consumers see.
649 Value distributedVal = newWarpOp.getResult(operandIdx);
652 resolveDistributedTy(newDpasOp.getResult(),
653 distResultTypeByWarpOpOrFailure.value(), rewriter);
// Sink a xegpu.prefetch_nd (last op before the warp terminator) out of the
// warp region. Prefetch has no results, so only the descriptor and offsets
// need to be yielded and the op re-created with a layout-free descriptor.
// NOTE(review): fragment — interior lines are elided in this view.
688 using gpu::WarpDistributionPattern::WarpDistributionPattern;
689 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
690 PatternRewriter &rewriter)
const override {
691 gpu::YieldOp yield = warpOp.getTerminator();
692 Operation *lastNode = yield->getPrevNode();
693 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
697 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
701 "the prefetch op must have offsets");
702 SmallVector<Value> offsetsAsValues =
704 SmallVector<Type> offsetTypes = llvm::map_to_vector(
705 offsetsAsValues, [](Value v) {
return v.
getType(); });
707 xegpu::DistributeLayoutAttr layout =
708 prefetchOp.getTensorDescType().getLayoutAttr();
711 prefetchOp,
"the source tensor descriptor lacks layout attribute");
// Yield descriptor + offsets out of the warp region.
713 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
714 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
715 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
716 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
717 SmallVector<size_t> newRetIndices;
718 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
719 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
// Distributed descriptor type drops the layout attribute.
722 xegpu::TensorDescType newTensorDescTy =
723 prefetchOp.getTensorDescType().dropLayouts();
725 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
726 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
// Remaining returns (indices >= 1) are the offsets, passed through.
728 for (
size_t i = 1; i < newRetIndices.size(); ++i)
729 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
730 Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
731 rewriter, newWarpOp.getLoc(),
TypeRange{}, newPrefetchOperands,
732 prefetchOp->getAttrs());
// Sink a gpu.barrier (last op before the warp terminator) out of the warp
// region by cloning it after the warp op with identical operands/attrs.
742 using gpu::WarpDistributionPattern::WarpDistributionPattern;
743 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
744 PatternRewriter &rewriter)
const override {
745 gpu::YieldOp yield = warpOp.getTerminator();
746 Operation *lastNode = yield->getPrevNode();
748 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
// Re-create the barrier outside the warp region, preserving everything.
753 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
754 barrierOp->getResultTypes(),
755 barrierOp->getOperands(), barrierOp->getAttrs());
// Distribute a xegpu.store_scatter (last op before the warp terminator):
// payload, offsets and mask vectors are distributed per their layouts and
// flattened to 1-D before re-creating the store outside the warp region.
// NOTE(review): fragment — interior lines are elided in this view.
795 using gpu::WarpDistributionPattern::WarpDistributionPattern;
796 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
797 PatternRewriter &rewriter)
const override {
798 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
799 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
// Offsets must be a vector (one offset per lane).
802 Value offsets = storeScatterOp.getOffsets();
803 if (!isa<VectorType>(offsets.
getType()))
805 storeScatterOp,
"Store op must have a vector of offsets argument");
806 VectorType offsetsTy = cast<VectorType>(offsets.
getType());
807 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
808 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
// With chunk size 1 the payload is logically 1-D, otherwise 2-D; any extra
// leading dims must be unit dims.
811 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
812 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
815 for (
int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
816 if (storeVecTy.getShape()[i] != 1) {
818 storeScatterOp,
"Only unit dimensions allowed for the leading "
819 "dimensions of the store vector!");
// Offsets and mask share the same layout (mask mirrors offsets).
823 auto layoutPayload = storeScatterOp.getLayoutAttr();
826 auto layoutMask = layoutOffsets;
828 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
830 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
832 FailureOr<VectorType> distMaskByWarpOpOrFailure =
834 if (
failed(distStoreVecByWarpOpOrFailure) ||
835 failed(distOffsetsByWarpOpOrFailure) ||
836 failed(distMaskByWarpOpOrFailure)) {
839 "Some vector operands have no layouts, using defaults instead.");
842 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
843 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
844 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
// Yield payload, dest, offsets and mask out of the warp region.
846 SmallVector<size_t> newRetIndices;
847 SmallVector<Value> operands = storeScatterOp->getOperands();
848 SmallVector<Type> operandTypesToYield = {
849 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
851 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
852 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
// Flatten each distributed vector to 1-D for the lane-level op.
857 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
858 distPayloadTy.getElementType());
860 VectorType distOffsetsTy1D = VectorType::get(
861 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
862 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
863 distMaskTy.getElementType());
866 Value distPayloadVal = resolveDistributedTy(
867 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
868 Value distOffsetVal = resolveDistributedTy(
869 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
870 Value distMaskVal = resolveDistributedTy(
871 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
873 SmallVector<Value> newStoreScatterOpOperands = {
874 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
// Re-create the store (no results) and drop the original.
877 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
878 rewriter, newWarpOp.getLoc(),
TypeRange{}, newStoreScatterOpOperands,
879 storeScatterOp->getAttrs());
881 rewriter.
eraseOp(storeScatterOp);
// Ask the layout for this lane's distributed coordinates; exactly one set of
// offsets is expected, then cast the OpFoldResults to Values.
891 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
894 assert(maybeCoords.value().size() == 1 &&
895 "Expected one set of distributed offsets");
899 newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Distribute a xegpu.load_matrix yielded from the warp region: re-create it
// outside the warp op with a distributed payload type and per-lane offsets
// (unless the subgroup_block_io attribute makes offsets uniform).
// NOTE(review): fragment — interior lines are elided in this view.
905 using gpu::WarpDistributionPattern::WarpDistributionPattern;
906 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
907 PatternRewriter &rewriter)
const override {
908 gpu::YieldOp yield = warpOp.getTerminator();
909 Operation *lastNode = yield->getPrevNode();
910 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
// The warp result must be produced by exactly this trailing load_matrix.
914 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
915 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
917 if (!producedByLastLoad)
919 warpOp,
"The last op is not xegpu::LoadMatrixOp");
922 VectorType sgPayloadTy =
923 dyn_cast<VectorType>(matrixOp.getResult().getType());
924 VectorType warpResultTy =
925 cast<VectorType>(warpOp.getResult(operandIdx).getType());
928 matrixOp,
"the matrix op payload must be a vector type");
930 auto loc = matrixOp.getLoc();
931 auto offsets = matrixOp.getMixedOffsets();
934 "the load op must have offsets");
935 SmallVector<Value> offsetsAsValues =
938 auto layout = matrixOp.getLayoutAttr();
941 matrixOp,
"the matrix operation lacks layout attribute");
943 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
945 if (
failed(distPayloadByWarpOpOrFailure))
947 matrixOp,
"Failed to distribute matrix op payload based on layout.");
// Yield mem-desc followed by the offset values out of the warp region.
949 SmallVector<Value> operands = {matrixOp.getMemDesc()};
950 const unsigned offsetsStartIdx = operands.size();
951 operands.append(offsetsAsValues);
953 SmallVector<Type> operandTypes =
954 llvm::map_to_vector(operands, [](Value v) {
return v.
getType(); });
956 SmallVector<size_t> newRetIndices;
957 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
958 rewriter, warpOp, operands, operandTypes, newRetIndices);
959 SmallVector<Value> newOperands = llvm::map_to_vector(
960 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
// All offsets become dynamic in the rebuilt op.
962 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
963 ShapedType::kDynamic);
967 ValueRange(newOperands).drop_front(offsetsStartIdx);
969 SmallVector<Value> newCoords = currentOffsets;
// Without subgroup_block_io, each lane reads its own slice: adjust the
// coordinates per lane using the layout.
972 if (!matrixOp.getSubgroupBlockIoAttr()) {
973 newCoords = computeDistributedCoordinatesForMatrixOp(
974 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
// Rebuild the load with the distributed payload type and no layout attr.
977 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
978 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
979 newOperands[0],
ValueRange(newCoords), newConstOffsetsAttr,
980 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
983 newWarpOp.getResult(operandIdx),
984 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
// Distribute a xegpu.store_matrix (last op before the warp terminator);
// mirrors the load_matrix pattern: distribute the payload, make all offsets
// dynamic, and adjust per-lane coordinates unless subgroup_block_io is set.
// NOTE(review): fragment — interior lines are elided in this view.
991 using gpu::WarpDistributionPattern::WarpDistributionPattern;
992 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
993 PatternRewriter &rewriter)
const override {
994 gpu::YieldOp yield = warpOp.getTerminator();
995 Operation *lastNode = yield->getPrevNode();
996 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
1000 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1003 matrixOp,
"the matrix op payload must be a vector type");
1005 auto loc = matrixOp.getLoc();
1006 auto offsets = matrixOp.getMixedOffsets();
1007 if (offsets.empty())
1009 "the store op must have offsets");
1010 SmallVector<Value> offsetsAsValues =
1013 auto layout = matrixOp.getLayoutAttr();
1016 matrixOp,
"the matrix operation lacks layout attribute");
1018 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1020 if (
failed(distPayloadByWarpOpOrFailure))
1022 matrixOp,
"Failed to distribute matrix op payload based on layout.");
// Yield data, mem-desc, then the offsets out of the warp region; the data
// slot (index 0) is retyped to its distributed form.
1024 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1025 const unsigned offsetsStartIdx = operands.size();
1026 operands.append(offsetsAsValues);
1028 SmallVector<Type> operandTypes =
1029 llvm::map_to_vector(operands, [](Value v) {
return v.
getType(); });
1030 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1032 SmallVector<size_t> newRetIndices;
1033 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1034 rewriter, warpOp, operands, operandTypes, newRetIndices);
1035 SmallVector<Value> newOperands = llvm::map_to_vector(
1036 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
// All offsets become dynamic in the rebuilt op.
1038 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1039 ShapedType::kDynamic);
1043 ValueRange(newOperands).drop_front(offsetsStartIdx);
1045 SmallVector<Value> newCoords = currentOffsets;
// Per-lane coordinate adjustment unless subgroup_block_io is present.
1048 if (!matrixOp.getSubgroupBlockIoAttr()) {
1049 newCoords = computeDistributedCoordinatesForMatrixOp(
1050 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
// Rebuild the store (no results) with the layout attribute dropped.
1054 xegpu::StoreMatrixOp::create(
1055 rewriter, loc,
TypeRange{}, newOperands[0], newOperands[1],
1057 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
// Distribute a xegpu.load_gather yielded from the warp region: offsets and
// mask vectors are distributed per their layouts, flattened to 1-D, and the
// gather is re-created outside the warp op. NOTE(review): fragment —
// interior lines are elided in this view.
1092 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1093 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1094 PatternRewriter &rewriter)
const override {
// The warp result must come from the load_gather just before the yield.
1095 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1098 return isa<xegpu::LoadGatherOp>(op) &&
1099 warpOp.getTerminator()->getPrevNode() == op;
1101 if (!producedByLastLoad)
1103 warpOp,
"The last op is not xegpu::LoadGatherOp");
1107 Value offsets = loadGatherOp.getOffsets();
1108 if (!isa<VectorType>(offsets.getType()) ||
1109 !isa<VectorType>(loadGatherOp.getMask().getType()))
1112 "Load op must have vector arguments for offsets and mask");
1113 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1114 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1115 VectorType resultVecTy =
1116 cast<VectorType>(loadGatherOp.getResult().getType());
// With chunk size 1 the result is logically 1-D, otherwise 2-D; any extra
// leading dims must be unit dims.
1118 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1119 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1120 for (
int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1121 if (resultVecTy.getShape()[i] != 1) {
1123 loadGatherOp,
"Only unit dimensions allowed for the leading "
1124 "dimensions of the load vector!");
// Mask shares the offsets' layout.
1128 auto layoutPayload = loadGatherOp.getLayoutAttr();
1129 auto layoutOffsets =
1131 auto layoutMask = layoutOffsets;
1133 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1135 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1137 if (
failed(distOffsetsByWarpOpOrFailure) ||
1138 failed(distMaskByWarpOpOrFailure)) {
1141 "Some vector operands have no layouts, using defaults instead.");
// Yield source, offsets and mask out of the warp region.
1144 SmallVector<size_t> newRetIndices;
1145 SmallVector<Value> operands = loadGatherOp->getOperands();
1148 VectorType distResultTy =
1149 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1150 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1151 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1153 SmallVector<Type> operandTypesToYield = {operands[0].getType(),
1154 distOffsetsTy, distMaskTy};
1156 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1157 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
// Flatten distributed vectors to 1-D for the lane-level gather.
1162 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1163 distResultTy.getElementType());
1165 VectorType distOffsetsTy1D =
1166 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1168 VectorType distMaskTy1D =
1169 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1172 Value distOffsetVal = resolveDistributedTy(
1173 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1174 Value distmaskVal = resolveDistributedTy(
1175 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1177 SmallVector<Value> newLoadGatherOperands = {
1178 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1180 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1181 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1182 loadGatherOp->getAttrs());
// Reconcile the 1-D result back to the warp result's distributed type.
1184 Value distributedVal = newWarpOp.getResult(operandIdx);
1188 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter);
// Sink a fully-uniform op (no distribute-layout on any operand or result,
// result type unchanged by distribution) out of the warp region by cloning
// it after the warp op with remapped operands. NOTE(review): fragment —
// interior lines are elided in this view.
1200 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1201 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1202 PatternRewriter &rewriter)
const override {
// Only handle region-free ops immediately preceding the terminator.
1204 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
1207 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->
getNumRegions())
1209 int operandIdx = -1;
1211 OpOperand *operand = getWarpResult(
1212 warpOp, [&](Operation *op) {
return warpRegionPreYieldOp == op; });
// Uniformity check #1: the warp result type equals the op's result type.
1217 warpOp.getResult(operandIdx).getType())
1219 "The op result is not uniform.");
// Uniformity check #2: no result or operand carries a distribute layout.
1223 bool uniformValuesOnly =
1224 llvm::all_of(warpRegionPreYieldOp->
getResults(), [](Value v) {
1225 return !xegpu::getDistributeLayoutAttr(v);
1227 uniformValuesOnly &=
1228 llvm::all_of(warpRegionPreYieldOp->
getOpOperands(), [](OpOperand &opr) {
1229 return !xegpu::getDistributeLayoutAttr(opr);
1231 if (!uniformValuesOnly)
1233 "Some values are not uniform.");
// Yield the op's operands out of the warp region unchanged.
1234 SmallVector<size_t> newRetIndices;
1235 SmallVector<Value> operands =
1236 llvm::to_vector_of<Value>(warpRegionPreYieldOp->
getOperands());
1237 SmallVector<Type> operandTypes =
1239 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1240 rewriter, warpOp, operands, operandTypes, newRetIndices);
// Clone the op outside the warp region with operands remapped to the new
// warp results, then erase the original inside the region.
1243 IRMapping operandMapper;
1244 for (
auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1245 operandMapper.
map(warpRegionPreYieldOp->
getOperand(oldOperandIdx),
1246 newWarpOp->getResult(newOperandIdx));
1247 Operation *clonedOp = rewriter.
clone(*warpRegionPreYieldOp, operandMapper);
1249 rewriter.
eraseOp(warpRegionPreYieldOp);
1251 assert(operandIdx != -1 &&
"Expected a warp result for the operation");
// Distribute a vector.multi_reduction yielded from the warp region.
// Supports 2D+ sources with unit leading dims, exactly one reduction dim
// among the last two, and a source distributed along exactly one dimension.
// A "lane-local" reduction (reduction dim ≠ distributed dim) can be done
// independently per lane; otherwise a cross-lane reduction is needed.
// NOTE(review): fragment — interior lines are elided in this view.
1315 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1316 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1317 PatternRewriter &rewriter)
const override {
1318 OpOperand *yieldOperand =
1319 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1325 VectorType sourceType = reductionOp.getSourceVectorType();
1326 int64_t sourceRank = sourceType.getRank();
1330 "Only 2D+ reductions are supported.");
// Leading (rank-2) dims must all be unit.
1332 for (int64_t i = 0; i < sourceRank - 2; ++i) {
1333 if (sourceType.getShape()[i] != 1)
1335 warpOp,
"Only unit dimensions allowed for the leading dimensions.");
1338 int64_t rowIdx = sourceRank - 2;
1339 int64_t columnIdx = sourceRank - 1;
1340 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1341 if (reductionDims.size() != 1)
1343 "Only 1 reduction dim is supported.");
1344 int64_t reductionDim = reductionDims[0];
1346 if (reductionDim != rowIdx && reductionDim != columnIdx)
1348 warpOp,
"Reduction dim must be among the last 2 dimensions.");
1349 VectorType distributedResultType =
1350 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1351 VectorType resultType = cast<VectorType>(reductionOp.getType());
1352 xegpu::DistributeLayoutAttr sourceLayout =
1355 FailureOr<VectorType> sourceDistTypeOrFailure =
1357 if (
failed(sourceDistTypeOrFailure))
1359 warpOp,
"Failed to distribute the source vector type.");
1360 VectorType sourceDistType = sourceDistTypeOrFailure.value();
// Determine which of the last-two dims the source is distributed along
// (exactly one must be).
1362 bool rowDistributed =
1363 sourceDistType.getShape()[rowIdx] != sourceType.getShape()[rowIdx];
1364 bool columnDistributed = sourceDistType.getShape()[columnIdx] !=
1365 sourceType.getShape()[columnIdx];
1366 if (rowDistributed && columnDistributed)
1368 warpOp,
"Expecting source to be distributed in a single dimension.");
1369 int64_t sourceDistDim =
1370 rowDistributed ? rowIdx : (columnDistributed ? columnIdx : -1);
1371 if (sourceDistDim == -1)
1373 warpOp,
"Expecting a distributed source vector.");
// Result is distributed iff it has fewer elements than the full result.
1374 bool resultDistributed =
1375 distributedResultType.getNumElements() < resultType.getNumElements();
// Lane-local: each lane reduces its own fragment; the result must then be
// distributed. Cross-lane: result must be broadcast (not distributed).
1389 bool isReductionLaneLocal =
1390 (sourceDistDim == rowIdx && reductionDim == columnIdx) ||
1391 (sourceDistDim == columnIdx && reductionDim == rowIdx);
1392 if (isReductionLaneLocal && !resultDistributed)
1394 warpOp,
"Expecting a distributed result for lane-local reduction.");
1396 if (!isReductionLaneLocal && resultDistributed)
1399 "Expecting a broadcasted result for non-lane-local reduction.");
// Lane-local path: yield source + acc, reduce outside the warp op.
1403 if (isReductionLaneLocal) {
1405 SmallVector<size_t> newRetIndices;
1406 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1407 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1408 {sourceDistType, distributedResultType}, newRetIndices);
1413 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
// Non-lane-local path (elided here) emits a cross-lane reduction.
1425 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
// Distribute a vector.broadcast yielded from the warp region. The source may
// be a vector (its layout must be a slice of the result layout) or a scalar
// (which must carry no layout). NOTE(review): fragment — interior lines are
// elided in this view.
1501 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1503 PatternRewriter &rewriter)
const override {
1504 OpOperand *yieldOperand =
1512 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1513 VectorType destType =
1514 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1516 xegpu::DistributeLayoutAttr sourceLayout =
1518 xegpu::DistributeLayoutAttr resultLayout =
1521 FailureOr<VectorType> sourceDistType;
1522 Type sourceElemOrDistType;
// Vector source: its layout must be a slice of the result layout.
1526 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1529 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1531 broadcastOp.emitWarning()
1532 <<
"Broadcast input layout must be a slice of result layout.";
// Equal-rank broadcast: broadcast happens along unit dims; normalize the
// source layout over those dims (sg_data there must already be 1).
1535 if (rankDiff == 0) {
1536 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1537 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1538 broadcastUnitDimsSet.end());
1539 assert(sourceLayout.isEqualTo(
1540 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1541 "The sg_data for unit dimensions should be set as 1");
1542 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1547 if (
failed(sourceDistType)) {
1549 warpOp,
"Failed to distribute the source vector type.");
1551 sourceElemOrDistType = sourceDistType.value();
// Scalar source: must not carry a layout; pass the scalar through as-is.
1557 warpOp,
"Broadcast from scalar must not have a layout attribute.");
1559 sourceElemOrDistType = broadcastOp.getSourceType();
1561 FailureOr<VectorType> destDistType =
1563 if (
failed(destDistType)) {
1565 warpOp,
"Failed to distribute the dest vector type.");
// Yield the source out of the warp region and re-broadcast outside it,
// unless the distributed source already matches the distributed dest.
1568 SmallVector<size_t> newRetIndices;
1570 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1573 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1575 Value newBroadcast = distributedSource;
1577 if (sourceElemOrDistType != destDistType.value()) {
1580 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1581 destDistType.value(), distributedSource);
// Distribute a vector.shape_cast yielded from the warp region: distribute
// the source by its layout, then re-create the shape_cast outside the warp
// op with the distributed result type. NOTE(review): fragment — interior
// lines are elided in this view.
1592 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1594 PatternRewriter &rewriter)
const override {
1595 OpOperand *yieldOperand =
1603 cast<VectorType>(warpOp.getResult(operandNumber).getType());
// Both sides must carry a distribute layout.
1604 xegpu::DistributeLayoutAttr sourceLayout =
1606 xegpu::DistributeLayoutAttr resultLayout =
1608 if (!sourceLayout || !resultLayout)
1611 "the source or result of shape_cast op lacks distribution layout");
1613 FailureOr<VectorType> sourceDistTypeOrFailure =
1615 shapeCastOp.getSourceVectorType());
1616 if (
failed(sourceDistTypeOrFailure))
1618 warpOp,
"failed to get distributed vector type for source");
1619 VectorType sourceDistType = sourceDistTypeOrFailure.value();
// Yield the source, then shape_cast the distributed fragment outside.
1621 SmallVector<size_t> newRetIndices;
1623 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1626 Value source = newWarpOp.getResult(newRetIndices[0]);
1628 Value newShapeCast = vector::ShapeCastOp::create(
1629 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
// Distribute a vector.extract_strided_slice yielded from the warp region.
// The extract result may be distributed along at most one dimension; the
// offsets/sizes/strides are normalized to full rank and the offset along
// the distributed dim is rescaled per lane. NOTE(review): fragment —
// interior lines are elided in this view.
1640struct VectorExtractStridedSliceDistribution
1642 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1644 PatternRewriter &rewriter)
const override {
1645 OpOperand *operand =
1646 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1652 auto distributedType =
1653 cast<VectorType>(warpOp.getResult(operandIdx).getType());
// Find which dims of the extract result were distributed across lanes.
1655 auto extractResultType = cast<VectorType>(operand->
get().
getType());
1656 auto distributedDims =
1657 getDistributedDims(extractResultType, distributedType);
// Copy sizes/offsets/strides as mutable attribute lists.
1661 VectorType updatedSourceType = extractOp.getSourceVectorType();
1662 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1663 extractOp.getSizes(), [](Attribute attr) { return attr; });
1664 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1665 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1666 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1667 extractOp.getStrides(), [](Attribute attr) { return attr; });
// Pad sizes/offsets/strides out to the source rank (the op allows them to
// be shorter than the rank): trailing dims take the full dim size.
1671 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1672 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1674 extractOp.getSourceVectorType().getDimSize(i)));
1676 updatedStrides.push_back(
// If the result is distributed, it must be along exactly one dim, and the
// source layout must evenly distribute that dim with unit lane data.
1682 if (distributedDims.size() > 0) {
1683 if (distributedDims.size() != 1)
1685 warpOp,
"Source can not be distributed in multiple dimensions.");
1686 int64_t distributedDim = distributedDims[0];
1687 int sourceDistrDimSize =
1688 extractOp.getSourceVectorType().getShape()[distributedDim];
1690 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1692 warpOp,
"the source of extract_strided_slice op lacks distribution "
1694 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
// Lane count along the distributed dim must divide the source dim size.
1697 int subgroupSize = sourceLaneLayout[distributedDim];
1700 if (sourceDistrDimSize % subgroupSize != 0)
1703 "Source size along distributed dimension is not a multiple of "
1705 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1707 if (!llvm::all_of(sourceLaneData, [](int64_t v) {
return v == 1; }))
1709 warpOp,
"Expecting unit lane data in source layout");
// Offset along the distributed dim (rescaled per lane below, elided).
1712 int64_t distrDimOffset =
1713 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1714 if (distrDimOffset % subgroupSize != 0)
1716 warpOp,
"Offset along distributed dimension "
1717 "is not a multiple of subgroup size.");
1719 sourceLayout, extractOp.getSourceVectorType())
1723 distributedType.getDimSize(distributedDim));
1726 updatedOffsets[distributedDim] =
1731 SmallVector<size_t> newRetIndices;
1733 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1736 Value source = newWarpOp.getResult(newRetIndices[0]);
1738 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1739 rewriter, extractOp.getLoc(), distributedType, source,
1740 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1741 ArrayAttr::get(rewriter.
getContext(), updatedSizes),
1742 ArrayAttr::get(rewriter.
getContext(), updatedStrides));
1752struct VectorInsertStridedSliceDistribution
1754 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1756 PatternRewriter &rewriter)
const override {
1757 OpOperand *operand =
getWarpResult(warpOp, [&](Operation *op) {
1759 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1760 warpOp.getTerminator()->getPrevNode() == op;
1767 auto distributedType =
1768 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1770 auto insertResultType = cast<VectorType>(operand->
get().
getType());
1771 auto destDistributedDims =
1772 getDistributedDims(insertResultType, distributedType);
1776 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1777 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1778 VectorType updatedSourceType = insertOp.getSourceVectorType();
1779 VectorType updatedDestType = insertOp.getDestVectorType();
1780 if (destDistributedDims.size() > 0) {
1782 if (destDistributedDims.size() != 1)
1785 "Expecting source to be distributed in a single dimension.");
1786 int64_t destDistributedDim = destDistributedDims[0];
1788 VectorType srcType = insertOp.getSourceVectorType();
1789 VectorType destType = insertOp.getDestVectorType();
1793 int64_t sourceDistributedDim =
1794 destDistributedDim - (destType.getRank() - srcType.getRank());
1795 if (sourceDistributedDim < 0)
1798 "distributed dimension must be in the last k (i.e. source "
1799 "rank) dims of dest vector");
1800 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1804 if (!destLayout || !sourceLayout ||
1805 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1806 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1808 warpOp,
"the source or dest of insert_strided_slice op lacks "
1809 "distribution layout");
1813 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1816 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1817 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1818 if (!llvm::all_of(destLaneData, [](int64_t v) {
return v == 1; }) ||
1819 !llvm::all_of(sourceLaneData, [](int64_t v) {
return v == 1; }))
1821 warpOp,
"Expecting unit lane data in source and dest layouts");
1823 if (srcDistrDimSize % subgroupSize != 0)
1825 warpOp,
"Distributed dimension size in source is not a multiple of "
1829 int64_t destDistrDimOffset =
1830 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1831 if (destDistrDimOffset % subgroupSize != 0)
1834 "Offset along distributed dimension in dest is not a multiple of "
1838 sourceLayout, insertOp.getSourceVectorType())
1841 destLayout, insertOp.getDestVectorType())
1845 updatedOffsets[destDistributedDim] =
1850 SmallVector<size_t> newRetIndices;
1852 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1853 {updatedSourceType, updatedDestType}, newRetIndices);
1856 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1857 Value dest = newWarpOp.getResult(newRetIndices[1]);
1859 Value newInsertOp = vector::InsertStridedSliceOp::create(
1860 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1861 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1862 insertOp.getStrides());
1872struct MemrefExtractAlignedPointerAsIndexDistribution final
1874 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1875 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1876 PatternRewriter &rewriter)
const override {
1877 OpOperand *operand = getWarpResult(
1878 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1882 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1886 SmallVector<size_t> newRetIndices;
1887 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1888 rewriter, warpOp, extractOp.getSource(),
1889 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1891 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1892 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1893 newWarpOp.getResult(newRetIndices[0]));
1894 Value resultVal = newWarpOp.getResult(operandIdx);
1906 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1907 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1908 PatternRewriter &rewriter)
const override {
1909 OpOperand *operand =
1910 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1913 warpOp,
"warp result is not a vector::BitCast op");
1916 VectorType distributedSourceType =
1919 bitcastOp.getSourceVectorType())
1920 .value_or(VectorType());
1921 if (!distributedSourceType)
1923 bitcastOp,
"Failed to distribute the source vector type in "
1924 "vector::BitCast op");
1925 VectorType distributedResultType =
1926 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1927 SmallVector<size_t> newRetIndices;
1928 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1929 rewriter, warpOp, bitcastOp.getSource(),
1930 TypeRange{distributedSourceType}, newRetIndices);
1932 auto newBitcastOp = vector::BitCastOp::create(
1933 rewriter, newWarpOp.getLoc(), distributedResultType,
1934 newWarpOp.getResult(newRetIndices[0]));
1935 Value distributedVal = newWarpOp.getResult(operandIdx);
1950 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1951 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1952 PatternRewriter &rewriter)
const override {
1953 OpOperand *operand =
1954 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1957 warpOp,
"warp result is not a vector::Transpose op");
1960 xegpu::DistributeLayoutAttr sourceLayout =
1962 xegpu::DistributeLayoutAttr resultLayout =
1964 if (!sourceLayout || !resultLayout)
1967 "the source or result vector of the transpose op lacks layout "
1969 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1970 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1973 if (sourceRank != 2 || resultRank != 2)
1975 transposeOp,
"the source or result vector of the transpose op "
1976 "does not have 2D layout");
1977 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1979 if (!resultLayout.isTransposeOf(sourceLayout, perm,
1980 xegpu::LayoutKind::Lane))
1983 "the source or result vector layouts must be 2D transposes of each "
1985 FailureOr<VectorType> distributedSourceTypeOrFailure =
1987 transposeOp.getSourceVectorType());
1988 if (
failed(distributedSourceTypeOrFailure))
1990 transposeOp,
"Failed to distribute the source vector type in "
1991 "vector::Transpose op");
1992 SmallVector<size_t> newRetIndices;
1993 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1994 rewriter, warpOp, transposeOp.getVector(),
1995 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1997 auto newTransposeOp = vector::TransposeOp::create(
1998 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2000 Value distributedVal = newWarpOp.getResult(operandIdx);
2011 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2012 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2013 PatternRewriter &rewriter)
const override {
2014 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2017 warpOp,
"warp result is not a vector::StepOp op");
2020 xegpu::DistributeLayoutAttr resultLayout =
2024 stepOp,
"the result vector of the step op lacks layout "
2026 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2029 stepOp,
"the result layout must be a slice layout");
2030 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2032 stepOp,
"expecting 1 dim in the effective result layout");
2035 auto loc = stepOp.getLoc();
2036 auto stepResultVecTy = stepOp.getResult().getType();
2037 Value distributedVal = warpOp.getResult(operandIdx);
2038 VectorType newVecTy = cast<VectorType>(distributedVal.
getType());
2040 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2041 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2042 if (
failed(laneDataBlockCoords))
2044 stepOp,
"failed to compute lane data block coordinates");
2046 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2047 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2048 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
2049 newVecTy.getNumElements() / laneDataBlockLength);
2050 SmallVector<Value> stepVals;
2058 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2059 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2060 stepVals.push_back(laneDataBlockStartCoord);
2061 for (
int i = 1; i < laneDataBlockLength; ++i) {
2063 stepVals.push_back(arith::AddIOp::create(
2064 rewriter, loc, laneDataBlockStartCoord, offset));
2067 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
2068 "Expecting the number of step values to match the number of "
2069 "elements in the vector");
2071 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2077struct ConvertLayoutDistribution
2082 PatternRewriter &rewriter)
const override {
2083 auto inputLayout = op.getInputLayoutAttr();
2084 auto targetLayout = op.getTargetLayoutAttr();
2085 Type valType = op.getResult().getType();
2087 if (!inputLayout || !targetLayout)
2094 auto resShape = cast<VectorType>(valType).getShape();
2095 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
2096 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
2097 xegpu::LayoutKind::Lane)) {
2099 op,
"lowering incompatible convert_layout not yet supported");
2109struct XeGPUSubgroupDistributePass final
2111 XeGPUSubgroupDistributePass> {
2112 void runOnOperation()
override;
2118 patterns.
add<CreateNdDescDistribution, StoreNdDistribution,
2119 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2120 GpuBarrierDistribution, VectorMultiReductionDistribution,
2121 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2122 VectorBitcastDistribution, LoadMatrixDistribution,
2123 StoreMatrixDistribution, ConvertLayoutDistribution,
2124 MemrefExtractAlignedPointerAsIndexDistribution>(
2126 PatternHierarchy::Regular);
2130 .
add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2131 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2132 VectorStepSliceDistribution, SinkUniformOps>(
2134 PatternHierarchy::AboveRegular);
2142void XeGPUSubgroupDistributePass::runOnOperation() {
2149 signalPassFailure();
2160 signalPassFailure();
2167 getOperation()->walk([&](Operation *op) {
2168 if (
auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2169 vector::moveScalarUniformCode(warpOp);
2178 auto distributionFn = [](Value val) {
2179 VectorType vecType = dyn_cast<VectorType>(val.getType());
2180 int64_t vecRank = vecType ? vecType.getRank() : 0;
2189 assert(layout.getRank() == vecRank &&
2190 "Expecting vector and layout rank to match");
2194 SmallVector<unsigned int> distributedDims;
2195 for (
auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2196 if (v > 1 && vecType.getShape()[i] % v == 0)
2197 distributedDims.push_back(i);
2203 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2204 int64_t warpSz) {
return Value(); };
2206 vector::populateDistributeReduction(
2208 PatternHierarchy::Regular);
2210 vector::populatePropagateWarpVectorDistributionPatterns(
2211 patterns, distributionFn, shuffleFn,
2212 PatternHierarchy::Regular);
2214 signalPassFailure();
2224 bool foundWarpOp =
false;
2225 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2235 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2241 Value input = op.getOperand(0);
2242 Value output = op.getResult(0);
2245 xegpu::TensorDescType inputDescType =
2246 mlir::dyn_cast<xegpu::TensorDescType>(input.
getType());
2247 xegpu::TensorDescType outputDescType =
2248 mlir::dyn_cast<xegpu::TensorDescType>(output.
getType());
2249 assert(inputDescType && outputDescType &&
2250 "Unrealized conversion cast must have tensor descriptor types");
2255 if (inputDescType.getLayout()) {
2256 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2258 argument.setType(output.
getType());
2260 if (
auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2261 argument.getOwner()->getParentOp())) {
2262 auto result = loopOp.getTiedLoopResult(argument);
2271 if (outputDescType.getLayout())
2274 if (op->use_empty())
static Type getElementType(Type type)
Determine the element type of type.
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
virtual int getSubgroupSize() const =0