37#include "llvm/ADT/ArrayRef.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/SmallVectorExtras.h"
44#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-subgroup-distribute"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 "resolve_simt_type_mismatch";
// Two-tier priority for the distribution patterns registered by this pass:
// AboveRegular (2) patterns are tried before Regular (1) ones.
// NOTE(review): presumably these values feed PatternBenefit at the
// registration site — confirm against the (elided) populate function.
68enum PatternHierarchy :
unsigned { Regular = 1, AboveRegular = 2 };
// Reconcile `orig`'s type with the type `expected` after SIMT distribution,
// inserting a cast op and returning the casted value.
// NOTE(review): the template header and orig. lines 86-90 are elided in this
// extract, so the full parameter list (e.g. `rewriter`) is not visible here.
85static Value resolveDistributedTy(
Value orig, T expected,
// Vector values are reconciled with a vector.shape_cast.
91 if (isa<VectorType>(orig.
getType())) {
93 vector::ShapeCastOp::create(rewriter, orig.
getLoc(), expected, orig);
94 return castOp.getResult();
// Tensor descriptors have no dedicated cast op; use an
// unrealized_conversion_cast, presumably resolved later by the pass
// (cf. the "resolve_simt_type_mismatch" marker) — confirm.
98 if (isa<xegpu::TensorDescType>(orig.
getType())) {
99 auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.
getLoc(),
102 return castOp.getResult(0);
// Any other type is a programming error.
104 llvm_unreachable(
"Unsupported type for reconciliation");
// Collect the dimensions along which `distributedType` differs in size from
// `originalType`, i.e. the dims that were distributed across lanes.
// NOTE(review): the leading parameter and the declaration of
// `distributedDims` (orig. line 114) are elided in this extract.
111 VectorType distributedType) {
112 assert(originalType.getRank() == distributedType.getRank() &&
113 "sequential and distributed vector types must have the same rank");
// A dim counts as distributed iff its size changed.
115 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
116 if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
117 distributedDims.push_back(i);
120 return distributedDims;
// Rewrite fragment: wrap the body of a gpu.func inside a new
// gpu.warp_execute_on_lane_0 region so the per-op patterns below can sink
// ops out of it. Several guard/diagnostic lines are elided in this extract.
153 gpuFuncOp,
"Subgroup distribution requires target attribute attached "
154 "to set the warp size");
// Skip trivially-empty functions (body is just a bare gpu.return).
156 if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
157 return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
// Skip functions already wrapped in a warp_execute_on_lane_0.
161 if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](
Operation &op) {
162 return isa<gpu::WarpExecuteOnLane0Op>(op);
165 gpu::ReturnOp origReturnOp = dyn_cast_if_present<gpu::ReturnOp>(
166 gpuFuncOp.getBlocks().back().getTerminator());
169 gpuFuncOp,
"expected gpu.func terminator to be gpu.return");
// Recreate the function so attributions/attrs carry over to the new body.
172 llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
175 llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
177 auto newGpuFunc = gpu::GPUFuncOp::create(
178 rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
180 privateAttributionsTypes);
181 newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
// Materialize the lane id and open the warp region over it; the warp op
// yields the original function results.
185 auto laneId = gpu::LaneIdOp::create(
187 mlir::IntegerAttr());
188 ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
189 auto warpOp = gpu::WarpExecuteOnLane0Op::create(
190 rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
192 newGpuFunc.getArgumentTypes());
193 Block &warpBodyBlock = warpOp.getBodyRegion().
front();
// Replace the original gpu.return with a gpu.yield inside the warp region.
196 gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
197 origReturnOp.getOperands());
198 rewriter.
eraseOp(origReturnOp);
201 warpOp.getBodyRegion().begin());
// Return the warp op's results from the new function and swap it in.
205 gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
206 rewriter.
replaceOp(gpuFuncOp, newGpuFunc);
// Pattern fragment: sink a xegpu.create_nd_tdesc out of the warp region.
// The struct header and failure-return lines are elided in this extract.
244 using gpu::WarpDistributionPattern::WarpDistributionPattern;
245 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
248 getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
251 warpOp,
"warp result is not a xegpu::CreateNdDesc op");
// A layout attribute is required to know how the descriptor distributes.
255 xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
258 descOp,
"the tensor descriptor lacks layout attribute");
260 if (descOp.getMixedOffsets().size())
262 descOp,
"xegpu::CreateNdDescOp must not have offsets");
// Yield the desc op's operands from a new warp op so they are available
// outside the region.
266 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
267 rewriter, warpOp, descOp->getOperands(),
268 descOp.getOperandTypes(), newRetIndices);
271 newRetIndices, [&](
size_t i) {
return newWarpOp.getResult(i); });
// Rebuild the descriptor outside the warp op with layouts dropped (SIMT).
273 xegpu::TensorDescType distributedTensorDescTy =
274 descOp.getType().dropLayouts();
276 Value newDescOp = xegpu::CreateNdDescOp::create(
277 rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
// Reconcile the new descriptor's type with the warp-op result type.
280 Value distributedVal = newWarpOp.getResult(operandIdx);
283 resolveDistributedTy(newDescOp, distributedVal.
getType(), rewriter);
// Pattern fragment: sink a xegpu.store_nd (the op just before gpu.yield)
// out of the warp region. Struct header and some guards are elided.
322 using gpu::WarpDistributionPattern::WarpDistributionPattern;
323 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
325 gpu::YieldOp yield = warpOp.getTerminator();
326 Operation *lastNode = yield->getPrevNode();
327 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
335 "the store op must have offsets");
// Layout on the tensor descriptor drives the value-type distribution.
340 xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
341 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
344 storeOp,
"the source tensor descriptor lacks layout attribute");
346 FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
348 if (failed(distributedTypeByWarpOpOrFailure))
350 "Failed to distribute the type");
351 VectorType distributedTypeByWarpOp =
352 distributedTypeByWarpOpOrFailure.value();
// Yield value, tensor desc and offsets from a new warp op.
356 storeOp.getTensorDesc()};
358 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
359 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
360 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
361 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
// Operand 0: the stored value, reconciled to the store's expected
// distributed vector type.
371 FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
373 if (failed(storeNdDistributedValueTyOrFailure))
375 storeOp,
"Failed to get distributed vector type for the store op");
376 newStoreOperands.push_back(resolveDistributedTy(
377 newWarpOp.getResult(newRetIndices[0]),
378 storeNdDistributedValueTyOrFailure.value(), rewriter));
// Operand 1: the tensor descriptor with layouts dropped.
381 xegpu::TensorDescType distributedTensorDescTy =
382 storeOp.getTensorDescType().dropLayouts();
383 newStoreOperands.push_back(
384 resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
385 distributedTensorDescTy, rewriter));
// Remaining operands (offsets) pass through unchanged.
387 for (
size_t i = 2; i < newRetIndices.size(); ++i)
388 newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
391 xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(),
TypeRange{},
392 newStoreOperands, storeOp->getAttrs());
// Pattern fragment: sink a xegpu.load_nd out of the warp region. Only
// matches when the load is the op immediately before the terminator.
436 using gpu::WarpDistributionPattern::WarpDistributionPattern;
437 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
440 if (!isa<xegpu::LoadNdOp>(op))
445 gpu::YieldOp yield = warpOp.getTerminator();
446 return yield->getPrevNode() == op;
451 warpOp,
"warp result is not a xegpu::LoadNd op");
457 loadOp,
"xegpu::LoadNdOp require target attribute attached to "
458 "determine transpose "
466 "the load op must have offsets");
// The tensor-descriptor layout is required to distribute the result.
472 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
473 xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
476 loadOp,
"the source tensor descriptor lacks layout attribute");
479 VectorType distributedTypeByWarpOp =
480 cast<VectorType>(warpOp.getResult(operandIdx).getType());
// Yield the descriptor and offsets so they are usable outside the region.
485 newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
486 newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
487 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
488 rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
493 FailureOr<VectorType> loadNdDistValueTyOrFailure =
495 if (failed(loadNdDistValueTyOrFailure))
497 loadOp,
"Failed to get distributed vector type for the load op");
// Rebuild the load outside the warp op: layout-free descriptor plus
// pass-through offsets.
498 xegpu::TensorDescType distributedTensorDescTy =
499 loadOp.getTensorDescType().dropLayouts();
503 resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
504 distributedTensorDescTy, rewriter)};
506 for (
size_t i = 1; i < newRetIndices.size(); ++i)
507 newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
508 auto newLoadOp = xegpu::LoadNdOp::create(
509 rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
510 newLoadOperands, loadOp->getAttrs());
// Transpose flag computation is elided here (orig. lines 511-517) —
// NOTE(review): confirm how setTranspose's argument is derived.
516 newLoadOp.setTranspose(
518 Value distributedVal = newWarpOp.getResult(operandIdx);
// Reconcile the loaded value with the warp-op's distributed result type.
522 Value tyResolvedVal = resolveDistributedTy(
523 newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
// Pattern fragment: distribute a xegpu.dpas (matrix multiply-accumulate)
// yielded by the warp op. A, B and output each carry their own layout.
564 using gpu::WarpDistributionPattern::WarpDistributionPattern;
565 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
567 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
570 "warp result is not a xegpu::Dpas op");
575 xegpu::LayoutAttr layoutA =
576 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
577 xegpu::LayoutAttr layoutB =
578 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
579 xegpu::LayoutAttr layoutOut =
580 dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
582 if (!layoutA || !layoutB || !layoutOut)
585 "the xegpu::Dpas op lacks layout attribute for A, B or output");
// Types as the warp op would distribute them, per lane layout.
587 FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
588 getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
589 FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
590 getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
591 FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
592 getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
594 if (failed(distLhsTypeByWarpOpOrFailure) ||
595 failed(distRhsTypeByWarpOpOrFailure) ||
596 failed(distResultTypeByWarpOpOrFailure))
599 "Failed to distribute the A, B or output types in xegpu::Dpas op");
604 distLhsTypeByWarpOpOrFailure.value(),
605 distRhsTypeByWarpOpOrFailure.value()};
// The accumulator is optional; yield it only when present.
607 if (dpasOp.getAcc()) {
608 newYieldValues.push_back(dpasOp.getAcc());
609 newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
612 SmallVector<size_t> newRetIndices;
613 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
614 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
// Types the dpas op itself expects for its distributed operands — these
// can differ in shape from the warp-op-yielded types (hence resolve below).
616 FailureOr<VectorType> expectedDistLhsTyOrFailure =
618 FailureOr<VectorType> expectedDistRhsTyOrFailure =
620 FailureOr<VectorType> expectedDistResultTyOrFailure =
623 if (
failed(expectedDistLhsTyOrFailure) ||
624 failed(expectedDistRhsTyOrFailure) ||
625 failed(expectedDistResultTyOrFailure))
628 "Failed to get distributed vector type for the dpas operands.");
631 SmallVector<Value> newDpasOperands;
632 SmallVector<VectorType> newDpasOperandExpectedTypes;
635 newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
636 newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
637 VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
639 newDpasOperandExpectedTypes.push_back(distributedResultTy);
// Reconcile each yielded value to the type the new dpas op expects.
641 for (
unsigned i = 0; i < newRetIndices.size(); i++) {
642 newDpasOperands.push_back(
643 resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
644 newDpasOperandExpectedTypes[i], rewriter));
646 auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
647 distributedResultTy, newDpasOperands,
// Reconcile the dpas result back to the warp-op's distributed result type.
650 Value distributedVal = newWarpOp.getResult(operandIdx);
653 resolveDistributedTy(newDpasOp.getResult(),
654 distResultTypeByWarpOpOrFailure.value(), rewriter);
// Pattern fragment: sink a xegpu.prefetch_nd (the op just before gpu.yield)
// out of the warp region. Prefetch has no results, so no value is yielded
// for it — only its descriptor and offsets.
689 using gpu::WarpDistributionPattern::WarpDistributionPattern;
690 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
691 PatternRewriter &rewriter)
const override {
692 gpu::YieldOp yield = warpOp.getTerminator();
693 Operation *lastNode = yield->getPrevNode();
694 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
698 SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
702 "the prefetch op must have offsets");
703 SmallVector<Value> offsetsAsValues =
705 SmallVector<Type> offsetTypes = llvm::map_to_vector(
706 offsetsAsValues, [](Value v) {
return v.
getType(); });
708 xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
711 prefetchOp,
"the source tensor descriptor lacks layout attribute");
// Yield descriptor + offsets from a new warp op.
713 SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
714 SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
715 newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
716 newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
717 SmallVector<size_t> newRetIndices;
718 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
719 rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
// Rebuild the prefetch outside with a layout-free descriptor type.
722 xegpu::TensorDescType newTensorDescTy =
723 prefetchOp.getTensorDescType().dropLayouts();
725 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
726 newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
728 for (
size_t i = 1; i < newRetIndices.size(); ++i)
729 newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
730 Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
731 rewriter, newWarpOp.getLoc(),
TypeRange{}, newPrefetchOperands,
732 prefetchOp->getAttrs());
// Pattern fragment: hoist a gpu.barrier that sits just before gpu.yield out
// of the warp region by recreating it outside with identical operands/attrs.
742 using gpu::WarpDistributionPattern::WarpDistributionPattern;
743 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
744 PatternRewriter &rewriter)
const override {
745 gpu::YieldOp yield = warpOp.getTerminator();
746 Operation *lastNode = yield->getPrevNode();
748 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
// Clone the barrier outside the region (insertion-point setup elided).
753 gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
754 barrierOp->getResultTypes(),
755 barrierOp->getOperands(), barrierOp->getAttrs());
// Pattern fragment: sink a xegpu.store (scatter) out of the warp region.
// Payload, offsets and mask are distributed per their layouts, then
// flattened to 1-D vectors for the SIMT form.
795 using gpu::WarpDistributionPattern::WarpDistributionPattern;
796 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
797 PatternRewriter &rewriter)
const override {
798 Operation *lastNode = warpOp.getTerminator()->getPrevNode();
799 auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
802 auto offsets = storeScatterOp.getOffsets();
803 if (!offsets || !isa<VectorType>(offsets.getType()))
805 storeScatterOp,
"Store op must have a vector of offsets argument");
806 VectorType offsetsTy = cast<VectorType>(offsets.getType());
807 VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
808 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
// With chunk_size == 1 the payload is effectively 1-D, otherwise 2-D;
// any extra leading dims must be unit.
811 int chunkSize = storeScatterOp.getChunkSize().value_or(1);
812 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
815 for (
int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
816 if (storeVecTy.getShape()[i] != 1) {
818 storeScatterOp,
"Only unit dimensions allowed for the leading "
819 "dimensions of the store vector!");
830 FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
832 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
834 FailureOr<VectorType> distMaskByWarpOpOrFailure =
836 if (
failed(distStoreVecByWarpOpOrFailure) ||
837 failed(distOffsetsByWarpOpOrFailure) ||
838 failed(distMaskByWarpOpOrFailure)) {
841 "Some vector operands have no layouts, using defaults instead.");
844 VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
845 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
846 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
// Yield payload, dest, offsets, mask; operands[1] is the destination.
848 SmallVector<size_t> newRetIndices;
849 SmallVector<Value> operands = storeScatterOp->getOperands();
850 SmallVector<Type> operandTypesToYield = {
851 distPayloadTy, operands[1].getType(), distOffsetsTy, distMaskTy};
853 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
854 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
// Flatten distributed payload/offsets/mask to rank-1 for the new store.
859 VectorType payloadTy1D = VectorType::get({distPayloadTy.getNumElements()},
860 distPayloadTy.getElementType());
862 VectorType distOffsetsTy1D = VectorType::get(
863 {distOffsetsTy.getNumElements()}, distOffsetsTy.getElementType());
864 VectorType distMaskTy1D = VectorType::get({distMaskTy.getNumElements()},
865 distMaskTy.getElementType());
868 Value distPayloadVal = resolveDistributedTy(
869 newWarpOp.getResult(newRetIndices[0]), payloadTy1D, rewriter);
870 Value distOffsetVal = resolveDistributedTy(
871 newWarpOp.getResult(newRetIndices[2]), distOffsetsTy1D, rewriter);
872 Value distMaskVal = resolveDistributedTy(
873 newWarpOp.getResult(newRetIndices[3]), distMaskTy1D, rewriter);
875 SmallVector<Value> newStoreScatterOpOperands = {
876 distPayloadVal, newWarpOp.getResult(newRetIndices[1]), distOffsetVal,
879 xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
880 rewriter, newWarpOp.getLoc(),
TypeRange{}, newStoreScatterOpOperands,
881 storeScatterOp->getAttrs());
// The store has no results; just drop the original op.
883 rewriter.
eraseOp(storeScatterOp);
// Helper fragment: compute per-lane coordinates from a layout.
// NOTE(review): surrounding signature is elided; presumably part of
// computeDistributedCoordinatesForMatrixOp used by the matrix patterns.
893 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
896 assert(maybeCoords.value().size() == 1 &&
897 "Expected one set of distributed offsets");
// Convert OpFoldResults to Values.
901 newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Pattern fragment: sink a xegpu.load_matrix out of the warp region,
// distributing the payload per its layout and rebasing the offsets to
// per-lane coordinates (unless subgroup_block_io is set).
907 using gpu::WarpDistributionPattern::WarpDistributionPattern;
908 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
909 PatternRewriter &rewriter)
const override {
910 gpu::YieldOp yield = warpOp.getTerminator();
911 Operation *lastNode = yield->getPrevNode();
912 auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
// Only fire when the warp op result is produced by this trailing load.
916 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
917 return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
919 if (!producedByLastLoad)
921 warpOp,
"The last op is not xegpu::LoadMatrixOp");
924 VectorType sgPayloadTy =
925 dyn_cast<VectorType>(matrixOp.getResult().getType());
926 VectorType warpResultTy =
927 cast<VectorType>(warpOp.getResult(operandIdx).getType());
930 matrixOp,
"the matrix op payload must be a vector type");
932 auto loc = matrixOp.getLoc();
933 auto offsets = matrixOp.getMixedOffsets();
936 "the load op must have offsets");
937 SmallVector<Value> offsetsAsValues =
940 auto layout = matrixOp.getLayoutAttr();
943 matrixOp,
"the matrix operation lacks layout attribute");
945 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
947 if (
failed(distPayloadByWarpOpOrFailure))
949 matrixOp,
"Failed to distribute matrix op payload based on layout.");
// Yield mem-desc + offsets; remember where the offsets start.
951 SmallVector<Value> operands = {matrixOp.getMemDesc()};
952 const unsigned offsetsStartIdx = operands.size();
953 operands.append(offsetsAsValues);
955 SmallVector<Type> operandTypes =
956 llvm::map_to_vector(operands, [](Value v) {
return v.
getType(); });
958 SmallVector<size_t> newRetIndices;
959 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
960 rewriter, warpOp, operands, operandTypes, newRetIndices);
961 SmallVector<Value> newOperands = llvm::map_to_vector(
962 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
// All offsets become dynamic on the rebuilt op.
964 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
965 ShapedType::kDynamic);
969 ValueRange(newOperands).drop_front(offsetsStartIdx);
971 SmallVector<Value> newCoords = currentOffsets;
// Without subgroup_block_io each lane loads its own slice, so offsets
// are rebased by the lane's distributed coordinates.
974 if (!matrixOp.getSubgroupBlockIoAttr()) {
975 newCoords = computeDistributedCoordinatesForMatrixOp(
976 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
979 xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
980 rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
981 newOperands[0],
ValueRange(newCoords), newConstOffsetsAttr,
982 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
// Reconcile the new result with the warp-op's distributed result type.
985 newWarpOp.getResult(operandIdx),
986 resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
// Pattern fragment: sink a xegpu.store_matrix out of the warp region —
// mirrors the load_matrix pattern above, but yields the data operand too.
993 using gpu::WarpDistributionPattern::WarpDistributionPattern;
994 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
995 PatternRewriter &rewriter)
const override {
996 gpu::YieldOp yield = warpOp.getTerminator();
997 Operation *lastNode = yield->getPrevNode();
998 auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
1002 VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
1005 matrixOp,
"the matrix op payload must be a vector type");
1007 auto loc = matrixOp.getLoc();
1008 auto offsets = matrixOp.getMixedOffsets();
1009 if (offsets.empty())
1011 "the store op must have offsets");
1012 SmallVector<Value> offsetsAsValues =
1015 auto layout = matrixOp.getLayoutAttr();
1018 matrixOp,
"the matrix operation lacks layout attribute");
1020 FailureOr<VectorType> distPayloadByWarpOpOrFailure =
1022 if (
failed(distPayloadByWarpOpOrFailure))
1024 matrixOp,
"Failed to distribute matrix op payload based on layout.");
// Yield data, mem-desc, then offsets; the data slot is retyped to the
// distributed payload type below.
1026 SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
1027 const unsigned offsetsStartIdx = operands.size();
1028 operands.append(offsetsAsValues);
1030 SmallVector<Type> operandTypes =
1031 llvm::map_to_vector(operands, [](Value v) {
return v.
getType(); });
1032 operandTypes[0] = *distPayloadByWarpOpOrFailure;
1034 SmallVector<size_t> newRetIndices;
1035 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1036 rewriter, warpOp, operands, operandTypes, newRetIndices);
1037 SmallVector<Value> newOperands = llvm::map_to_vector(
1038 newRetIndices, [&](
size_t idx) {
return newWarpOp.getResult(idx); });
// All offsets become dynamic on the rebuilt op.
1040 SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
1041 ShapedType::kDynamic);
1045 ValueRange(newOperands).drop_front(offsetsStartIdx);
1047 SmallVector<Value> newCoords = currentOffsets;
// Rebase offsets to per-lane coordinates unless subgroup_block_io is set.
1050 if (!matrixOp.getSubgroupBlockIoAttr()) {
1051 newCoords = computeDistributedCoordinatesForMatrixOp(
1052 rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
1056 xegpu::StoreMatrixOp::create(
1057 rewriter, loc,
TypeRange{}, newOperands[0], newOperands[1],
1059 matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
// Pattern fragment: sink a xegpu.load (gather) out of the warp region.
// Offsets and mask are vectors distributed by layout; the result is
// flattened to 1-D for the SIMT form and reconciled afterwards.
1094 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1095 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1096 PatternRewriter &rewriter)
const override {
1097 OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
1100 return isa<xegpu::LoadGatherOp>(op) &&
1101 warpOp.getTerminator()->getPrevNode() == op;
1103 if (!producedByLastLoad)
1105 warpOp,
"The last op is not xegpu::LoadGatherOp");
1109 auto offsets = loadGatherOp.getOffsets();
1110 if (!offsets || !isa<VectorType>(offsets.getType()) ||
1111 !isa<VectorType>(loadGatherOp.getMask().getType()))
1114 "Load op must have a vector arguments for offsets and mask");
1115 VectorType offsetsTy = cast<VectorType>(offsets.getType());
1116 VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
1117 VectorType resultVecTy =
1118 cast<VectorType>(loadGatherOp.getResult().getType());
// With chunk_size == 1 the result is effectively 1-D, otherwise 2-D;
// leading dims beyond that must be unit.
1120 int chunkSize = loadGatherOp.getChunkSize().value_or(1);
1121 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
1122 for (
int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
1123 if (resultVecTy.getShape()[i] != 1) {
1125 loadGatherOp,
"Only unit dimensions allowed for the leading "
1126 "dimensions of the load vector!");
1130 auto layoutOffsets =
1134 FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1136 FailureOr<VectorType> distMaskByWarpOpOrFailure =
1138 if (
failed(distOffsetsByWarpOpOrFailure) ||
1139 failed(distMaskByWarpOpOrFailure)) {
1142 "Some vector operands have no layouts, using defaults instead.");
// Yield source, offsets, mask from a new warp op.
1145 SmallVector<size_t> newRetIndices;
1146 SmallVector<Value> operands = loadGatherOp->getOperands();
1149 VectorType distResultTy =
1150 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1151 VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
1152 VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
1154 SmallVector<Type> operandTypesToYield = {operands[0].getType(),
1155 distOffsetsTy, distMaskTy};
1157 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1158 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
// Flatten distributed result/offsets/mask types to rank-1.
1163 VectorType loadVecTy1D = VectorType::get({distResultTy.getNumElements()},
1164 distResultTy.getElementType());
1166 VectorType distOffsetsTy1D =
1167 VectorType::get({distOffsetsByWarpOpOrFailure.value().getNumElements()},
1169 VectorType distMaskTy1D =
1170 VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
1173 Value distOffsetVal = resolveDistributedTy(
1174 newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
1175 Value distmaskVal = resolveDistributedTy(
1176 newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
1178 SmallVector<Value> newLoadGatherOperands = {
1179 newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
1181 xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1182 rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
1183 loadGatherOp->getAttrs());
// Reconcile the 1-D result back to the warp-op's distributed type.
1185 Value distributedVal = newWarpOp.getResult(operandIdx);
1189 resolveDistributedTy(newOp.getResult(), distResultTy, rewriter);
// Pattern fragment: sink a fully uniform op (no regions, no distribute
// layout on any operand or result) out of the warp region by cloning it
// outside with remapped operands.
1201 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1202 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1203 PatternRewriter &rewriter)
const override {
1205 Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
// Ops with regions cannot be sunk wholesale.
1208 if (!warpRegionPreYieldOp || warpRegionPreYieldOp->
getNumRegions())
1210 int operandIdx = -1;
1212 OpOperand *operand = getWarpResult(
1213 warpOp, [&](Operation *op) {
return warpRegionPreYieldOp == op; });
// The warp result type must equal the op's own result type (uniform).
1218 warpOp.getResult(operandIdx).getType())
1220 "The op result is not uniform.");
// Uniformity check: no result or operand carries a distribute layout.
1224 bool uniformValuesOnly =
1225 llvm::all_of(warpRegionPreYieldOp->
getResults(), [](Value v) {
1226 return !xegpu::getDistributeLayoutAttr(v);
1228 uniformValuesOnly &=
1229 llvm::all_of(warpRegionPreYieldOp->
getOpOperands(), [](OpOperand &opr) {
1230 return !xegpu::getDistributeLayoutAttr(opr);
1232 if (!uniformValuesOnly)
1234 "Some values are not uniform.");
// Yield the op's operands from a new warp op and clone it outside,
// mapping old operands to the yielded values.
1235 SmallVector<size_t> newRetIndices;
1236 SmallVector<Value> operands =
1237 llvm::to_vector_of<Value>(warpRegionPreYieldOp->
getOperands());
1238 SmallVector<Type> operandTypes =
1240 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1241 rewriter, warpOp, operands, operandTypes, newRetIndices);
1244 IRMapping operandMapper;
1245 for (
auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
1246 operandMapper.
map(warpRegionPreYieldOp->
getOperand(oldOperandIdx),
1247 newWarpOp->getResult(newOperandIdx));
1248 Operation *clonedOp = rewriter.
clone(*warpRegionPreYieldOp, operandMapper);
1250 rewriter.
eraseOp(warpRegionPreYieldOp);
1252 assert(operandIdx != -1 &&
"Expected a warp result for the operation");
// Pattern fragment: distribute a vector.multi_reduction yielded by the warp
// op. Supports a single reduction dim among the last two; two cases:
// lane-local reductions (each lane reduces its own slice) vs cross-lane.
1316 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1317 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1318 PatternRewriter &rewriter)
const override {
1319 OpOperand *yieldOperand =
1320 getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1326 VectorType sourceType = reductionOp.getSourceVectorType();
1327 int64_t sourceRank = sourceType.getRank();
1331 "Only 2D+ reductions are supported.");
// Leading dims beyond the trailing 2 must be unit.
1333 for (int64_t i = 0; i < sourceRank - 2; ++i) {
1334 if (sourceType.getShape()[i] != 1)
1336 warpOp,
"Only unit dimensions allowed for the leading dimensions.");
1339 int64_t rowIdx = sourceRank - 2;
1340 int64_t columnIdx = sourceRank - 1;
1341 ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1342 if (reductionDims.size() != 1)
1344 "Only 1 reduction dim is supported.");
1345 int64_t reductionDim = reductionDims[0];
1347 if (reductionDim != rowIdx && reductionDim != columnIdx)
1349 warpOp,
"Reduction dim must be among the last 2 dimensions.");
1350 VectorType distributedResultType =
1351 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1352 VectorType resultType = cast<VectorType>(reductionOp.getType());
1353 xegpu::DistributeLayoutAttr sourceLayout =
1356 FailureOr<VectorType> sourceDistTypeOrFailure =
1358 if (
failed(sourceDistTypeOrFailure))
1360 warpOp,
"Failed to distribute the source vector type.");
1361 VectorType sourceDistType = sourceDistTypeOrFailure.value();
// The source must be distributed in exactly one of the trailing 2 dims.
1363 bool rowDistributed =
1364 sourceDistType.getShape()[rowIdx] != sourceType.getShape()[rowIdx];
1365 bool columnDistributed = sourceDistType.getShape()[columnIdx] !=
1366 sourceType.getShape()[columnIdx];
1367 if (rowDistributed && columnDistributed)
1369 warpOp,
"Expecting source to be distributed in a single dimension.");
1370 int64_t sourceDistDim =
1371 rowDistributed ? rowIdx : (columnDistributed ? columnIdx : -1);
1372 if (sourceDistDim == -1)
1374 warpOp,
"Expecting a distributed source vector.");
1375 bool resultDistributed =
1376 distributedResultType.getNumElements() < resultType.getNumElements();
// Lane-local = reducing along the non-distributed dim: each lane has all
// the data it needs. Otherwise the reduction crosses lanes.
1390 bool isReductionLaneLocal =
1391 (sourceDistDim == rowIdx && reductionDim == columnIdx) ||
1392 (sourceDistDim == columnIdx && reductionDim == rowIdx);
1393 if (isReductionLaneLocal && !resultDistributed)
1395 warpOp,
"Expecting a distributed result for lane-local reduction.");
1397 if (!isReductionLaneLocal && resultDistributed)
1400 "Expecting a broadcasted result for non-lane-local reduction.");
// Lane-local case: yield source + acc and reduce outside the warp op.
1404 if (isReductionLaneLocal) {
1406 SmallVector<size_t> newRetIndices;
1407 auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1408 rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1409 {sourceDistType, distributedResultType}, newRetIndices);
1414 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
// Non-lane-local path (orig. lines 1415-1425 elided here).
1426 reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
// Pattern fragment: distribute a vector.broadcast yielded by the warp op.
// Handles both vector->vector (source layout must be a slice of the result
// layout) and scalar->vector (no source layout allowed) broadcasts.
1502 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1504 PatternRewriter &rewriter)
const override {
1505 OpOperand *yieldOperand =
1513 VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
1514 VectorType destType =
1515 dyn_cast<VectorType>(broadcastOp.getResult().getType());
1517 xegpu::DistributeLayoutAttr sourceLayout =
1519 xegpu::DistributeLayoutAttr resultLayout =
1522 FailureOr<VectorType> sourceDistType;
1523 Type sourceElemOrDistType;
// Vector source: its layout must be a slice of the result layout.
1527 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1530 bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
1532 broadcastOp.emitWarning()
1533 <<
"Broadcast input layout must be a slice of result layout.";
// Same-rank broadcast: broadcast dims are unit dims; normalize the
// source layout over them.
1536 if (rankDiff == 0) {
1537 auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
1538 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1539 broadcastUnitDimsSet.end());
1540 assert(sourceLayout.isEqualTo(
1541 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1542 "The sg_data for unit dimensions should be set as 1");
1543 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1548 if (
failed(sourceDistType)) {
1550 warpOp,
"Failed to distribute the source vector type.");
1552 sourceElemOrDistType = sourceDistType.value();
// Scalar source: must carry no layout; passed through as-is.
1558 warpOp,
"Broadcast from scalar must not have a layout attribute.");
1560 sourceElemOrDistType = broadcastOp.getSourceType();
1562 FailureOr<VectorType> destDistType =
1564 if (
failed(destDistType)) {
1566 warpOp,
"Failed to distribute the dest vector type.");
// Yield the source; rebuild the broadcast outside only if the types
// still differ after distribution.
1569 SmallVector<size_t> newRetIndices;
1571 rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
1574 Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
1576 Value newBroadcast = distributedSource;
1578 if (sourceElemOrDistType != destDistType.value()) {
1581 vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
1582 destDistType.value(), distributedSource);
// Pattern fragment: distribute a vector.shape_cast yielded by the warp op.
// Both source and result need distribution layouts; the cast is rebuilt
// outside on the distributed source.
1593 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1595 PatternRewriter &rewriter)
const override {
1596 OpOperand *yieldOperand =
1604 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1605 xegpu::DistributeLayoutAttr sourceLayout =
1607 xegpu::DistributeLayoutAttr resultLayout =
1609 if (!sourceLayout || !resultLayout)
1612 "the source or result of shape_cast op lacks distribution layout");
1614 FailureOr<VectorType> sourceDistTypeOrFailure =
1616 shapeCastOp.getSourceVectorType());
1617 if (
failed(sourceDistTypeOrFailure))
1619 warpOp,
"failed to get distributed vector type for source");
1620 VectorType sourceDistType = sourceDistTypeOrFailure.value();
// Yield the source and cast it to the distributed result type outside.
1622 SmallVector<size_t> newRetIndices;
1624 rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1627 Value source = newWarpOp.getResult(newRetIndices[0]);
1629 Value newShapeCast = vector::ShapeCastOp::create(
1630 rewriter, shapeCastOp.getLoc(), resultDistTy, source);
// Distribute a vector.extract_strided_slice yielded by the warp op.
// NOTE(review): this definition runs past the end of the extract; the
// rewrite tail (after offset adjustment) is not visible here.
1641struct VectorExtractStridedSliceDistribution
1643 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1645 PatternRewriter &rewriter)
const override {
1646 OpOperand *operand =
1647 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
// Figure out along which dims the extract result was distributed.
1653 auto distributedType =
1654 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1656 auto extractResultType = cast<VectorType>(operand->
get().
getType());
1657 auto distributedDims =
1658 getDistributedDims(extractResultType, distributedType);
// Normalize sizes/offsets/strides to full source rank, padding trailing
// dims with full extent / zero offset / unit stride (partially elided).
1662 VectorType updatedSourceType = extractOp.getSourceVectorType();
1663 SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
1664 extractOp.getSizes(), [](Attribute attr) { return attr; });
1665 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1666 extractOp.getOffsets(), [](Attribute attr) { return attr; });
1667 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1668 extractOp.getStrides(), [](Attribute attr) { return attr; });
1672 int64_t sourceRank = extractOp.getSourceVectorType().getRank();
1673 for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
1675 extractOp.getSourceVectorType().getDimSize(i)));
1677 updatedStrides.push_back(
// Distributed case: only a single distributed dim is supported, the
// source must divide evenly across lanes, and lane data must be unit.
1683 if (distributedDims.size() > 0) {
1684 if (distributedDims.size() != 1)
1686 warpOp,
"Source can not be distributed in multiple dimensions.");
1687 int64_t distributedDim = distributedDims[0];
1688 int sourceDistrDimSize =
1689 extractOp.getSourceVectorType().getShape()[distributedDim];
1691 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1693 warpOp,
"the source of extract_strided_slice op lacks distribution "
1695 auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1698 int subgroupSize = sourceLaneLayout[distributedDim];
1701 if (sourceDistrDimSize % subgroupSize != 0)
1704 "Source size along distributed dimension is not a multiple of "
1706 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1708 if (!llvm::all_of(sourceLaneData, [](int64_t v) {
return v == 1; }))
1710 warpOp,
"Expecting unit lane data in source layout");
1713 int64_t distrDimOffset =
1714 cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
1715 if (distrDimOffset % subgroupSize != 0)
1717 warpOp,
"Offset along distributed dimension "
1718 "is not a multiple of subgroup size.");
1720 sourceLayout, extractOp.getSourceVectorType())
1724 distributedType.getDimSize(distributedDim));
1727 updatedOffsets[distributedDim] =
1732 SmallVector<size_t> newRetIndices;
1734 rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
1737 Value source = newWarpOp.getResult(newRetIndices[0]);
1739 Value newExtractOp = vector::ExtractStridedSliceOp::create(
1740 rewriter, extractOp.getLoc(), distributedType, source,
1741 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1742 ArrayAttr::get(rewriter.
getContext(), updatedSizes),
1743 ArrayAttr::get(rewriter.
getContext(), updatedStrides));
1753struct VectorInsertStridedSliceDistribution
1755 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1757 PatternRewriter &rewriter)
const override {
1758 OpOperand *operand =
getWarpResult(warpOp, [&](Operation *op) {
1760 return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
1761 warpOp.getTerminator()->getPrevNode() == op;
1768 auto distributedType =
1769 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1771 auto insertResultType = cast<VectorType>(operand->
get().
getType());
1772 auto destDistributedDims =
1773 getDistributedDims(insertResultType, distributedType);
1777 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1778 insertOp.getOffsets(), [](Attribute attr) { return attr; });
1779 VectorType updatedSourceType = insertOp.getSourceVectorType();
1780 VectorType updatedDestType = insertOp.getDestVectorType();
1781 if (destDistributedDims.size() > 0) {
1783 if (destDistributedDims.size() != 1)
1786 "Expecting source to be distributed in a single dimension.");
1787 int64_t destDistributedDim = destDistributedDims[0];
1789 VectorType srcType = insertOp.getSourceVectorType();
1790 VectorType destType = insertOp.getDestVectorType();
1794 int64_t sourceDistributedDim =
1795 destDistributedDim - (destType.getRank() - srcType.getRank());
1796 if (sourceDistributedDim < 0)
1799 "distributed dimension must be in the last k (i.e. source "
1800 "rank) dims of dest vector");
1801 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1805 if (!destLayout || !sourceLayout ||
1806 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1807 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1809 warpOp,
"the source or dest of insert_strided_slice op lacks "
1810 "distribution layout");
1814 destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
1817 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1818 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1819 if (!llvm::all_of(destLaneData, [](int64_t v) {
return v == 1; }) ||
1820 !llvm::all_of(sourceLaneData, [](int64_t v) {
return v == 1; }))
1822 warpOp,
"Expecting unit lane data in source and dest layouts");
1824 if (srcDistrDimSize % subgroupSize != 0)
1826 warpOp,
"Distributed dimension size in source is not a multiple of "
1830 int64_t destDistrDimOffset =
1831 cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
1832 if (destDistrDimOffset % subgroupSize != 0)
1835 "Offset along distributed dimension in dest is not a multiple of "
1839 sourceLayout, insertOp.getSourceVectorType())
1842 destLayout, insertOp.getDestVectorType())
1846 updatedOffsets[destDistributedDim] =
1851 SmallVector<size_t> newRetIndices;
1853 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1854 {updatedSourceType, updatedDestType}, newRetIndices);
1857 Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
1858 Value dest = newWarpOp.getResult(newRetIndices[1]);
1860 Value newInsertOp = vector::InsertStridedSliceOp::create(
1861 rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
1862 ArrayAttr::get(rewriter.
getContext(), updatedOffsets),
1863 insertOp.getStrides());
1873struct MemrefExtractAlignedPointerAsIndexDistribution final
1875 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1876 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1877 PatternRewriter &rewriter)
const override {
1878 OpOperand *operand = getWarpResult(
1879 warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1883 "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1887 SmallVector<size_t> newRetIndices;
1888 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1889 rewriter, warpOp, extractOp.getSource(),
1890 TypeRange{extractOp.getSource().getType()}, newRetIndices);
1892 auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1893 rewriter, newWarpOp.getLoc(), extractOp.getType(),
1894 newWarpOp.getResult(newRetIndices[0]));
1895 Value resultVal = newWarpOp.getResult(operandIdx);
1907 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1908 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1909 PatternRewriter &rewriter)
const override {
1910 OpOperand *operand =
1911 getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1914 warpOp,
"warp result is not a vector::BitCast op");
1917 VectorType distributedSourceType =
1920 bitcastOp.getSourceVectorType())
1921 .value_or(VectorType());
1922 if (!distributedSourceType)
1924 bitcastOp,
"Failed to distribute the source vector type in "
1925 "vector::BitCast op");
1926 VectorType distributedResultType =
1927 cast<VectorType>(warpOp.getResult(operandIdx).getType());
1928 SmallVector<size_t> newRetIndices;
1929 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1930 rewriter, warpOp, bitcastOp.getSource(),
1931 TypeRange{distributedSourceType}, newRetIndices);
1933 auto newBitcastOp = vector::BitCastOp::create(
1934 rewriter, newWarpOp.getLoc(), distributedResultType,
1935 newWarpOp.getResult(newRetIndices[0]));
1936 Value distributedVal = newWarpOp.getResult(operandIdx);
1951 using gpu::WarpDistributionPattern::WarpDistributionPattern;
1952 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1953 PatternRewriter &rewriter)
const override {
1954 OpOperand *operand =
1955 getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1958 warpOp,
"warp result is not a vector::Transpose op");
1961 xegpu::DistributeLayoutAttr sourceLayout =
1963 xegpu::DistributeLayoutAttr resultLayout =
1965 if (!sourceLayout || !resultLayout)
1968 "the source or result vector of the transpose op lacks layout "
1970 int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1971 int64_t resultRank = transposeOp.getResultVectorType().getRank();
1974 if (sourceRank != 2 || resultRank != 2)
1976 transposeOp,
"the source or result vector of the transpose op "
1977 "does not have 2D layout");
1978 ArrayRef<int64_t> perm = transposeOp.getPermutation();
1980 if (!resultLayout.isTransposeOf(sourceLayout, perm,
1981 xegpu::LayoutKind::Lane))
1984 "the source or result vector layouts must be 2D transposes of each "
1986 FailureOr<VectorType> distributedSourceTypeOrFailure =
1988 transposeOp.getSourceVectorType());
1989 if (
failed(distributedSourceTypeOrFailure))
1991 transposeOp,
"Failed to distribute the source vector type in "
1992 "vector::Transpose op");
1993 SmallVector<size_t> newRetIndices;
1994 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1995 rewriter, warpOp, transposeOp.getVector(),
1996 TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1998 auto newTransposeOp = vector::TransposeOp::create(
1999 rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
2001 Value distributedVal = newWarpOp.getResult(operandIdx);
2012 using gpu::WarpDistributionPattern::WarpDistributionPattern;
2013 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
2014 PatternRewriter &rewriter)
const override {
2015 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
2018 warpOp,
"warp result is not a vector::StepOp op");
2021 xegpu::DistributeLayoutAttr resultLayout =
2025 stepOp,
"the result vector of the step op lacks layout "
2027 auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
2030 stepOp,
"the result layout must be a slice layout");
2031 if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
2033 stepOp,
"expecting 1 dim in the effective result layout");
2036 auto loc = stepOp.getLoc();
2037 auto stepResultVecTy = stepOp.getResult().getType();
2038 Value distributedVal = warpOp.getResult(operandIdx);
2039 VectorType newVecTy = cast<VectorType>(distributedVal.
getType());
2041 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
2042 rewriter, loc, warpOp.getLaneid(), stepResultVecTy.getShape());
2043 if (
failed(laneDataBlockCoords))
2045 stepOp,
"failed to compute lane data block coordinates");
2047 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
2048 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
2049 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
2050 newVecTy.getNumElements() / laneDataBlockLength);
2051 SmallVector<Value> stepVals;
2059 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
2060 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
2061 stepVals.push_back(laneDataBlockStartCoord);
2062 for (
int i = 1; i < laneDataBlockLength; ++i) {
2064 stepVals.push_back(arith::AddIOp::create(
2065 rewriter, loc, laneDataBlockStartCoord, offset));
2068 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
2069 "Expecting the number of step values to match the number of "
2070 "elements in the vector");
2072 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
2078struct ConvertLayoutDistribution
2083 PatternRewriter &rewriter)
const override {
2084 auto inputLayout = op.getInputLayoutAttr();
2085 auto targetLayout = op.getTargetLayoutAttr();
2086 auto resShape = cast<VectorType>(op.getResult().getType()).getShape();
2088 if (!inputLayout || !targetLayout)
2091 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
2092 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
2093 xegpu::LayoutKind::Lane)) {
2095 op,
"lowering incompatible convert_layout not yet supported");
2105struct XeGPUSubgroupDistributePass final
2107 XeGPUSubgroupDistributePass> {
2108 void runOnOperation()
override;
2114 patterns.
add<CreateNdDescDistribution, StoreNdDistribution,
2115 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
2116 GpuBarrierDistribution, VectorMultiReductionDistribution,
2117 LoadDistribution, StoreDistribution, VectorTransposeDistribution,
2118 VectorBitcastDistribution, LoadMatrixDistribution,
2119 StoreMatrixDistribution, ConvertLayoutDistribution,
2120 MemrefExtractAlignedPointerAsIndexDistribution>(
2122 PatternHierarchy::Regular);
2126 .
add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
2127 VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
2128 VectorStepSliceDistribution, SinkUniformOps>(
2130 PatternHierarchy::AboveRegular);
2138void XeGPUSubgroupDistributePass::runOnOperation() {
2145 signalPassFailure();
2156 signalPassFailure();
2163 getOperation()->walk([&](Operation *op) {
2164 if (
auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
2165 vector::moveScalarUniformCode(warpOp);
2174 auto distributionFn = [](Value val) {
2175 VectorType vecType = dyn_cast<VectorType>(val.getType());
2176 int64_t vecRank = vecType ? vecType.getRank() : 0;
2185 assert(layout.getRank() == vecRank &&
2186 "Expecting vector and layout rank to match");
2190 SmallVector<unsigned int> distributedDims;
2191 for (
auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
2192 if (v > 1 && vecType.getShape()[i] % v == 0)
2193 distributedDims.push_back(i);
2199 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
2200 int64_t warpSz) {
return Value(); };
2202 vector::populateDistributeReduction(
2204 PatternHierarchy::Regular);
2206 vector::populatePropagateWarpVectorDistributionPatterns(
2207 patterns, distributionFn, shuffleFn,
2208 PatternHierarchy::Regular);
2210 signalPassFailure();
2220 bool foundWarpOp =
false;
2221 getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
2231 getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
2237 Value input = op.getOperand(0);
2238 Value output = op.getResult(0);
2241 xegpu::TensorDescType inputDescType =
2242 mlir::dyn_cast<xegpu::TensorDescType>(input.
getType());
2243 xegpu::TensorDescType outputDescType =
2244 mlir::dyn_cast<xegpu::TensorDescType>(output.
getType());
2245 assert(inputDescType && outputDescType &&
2246 "Unrealized conversion cast must have tensor descriptor types");
2251 if (inputDescType.getLayout()) {
2252 auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
2254 argument.setType(output.
getType());
2256 if (
auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
2257 argument.getOwner()->getParentOp())) {
2258 auto result = loopOp.getTiedLoopResult(argument);
2267 if (outputDescType.getLayout())
2270 if (op->use_empty())
static Type getElementType(Type type)
Determine the element type of type.
static const char *const resolveSIMTTypeMismatch
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns)
Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU SIMT distribution into patterns.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.
virtual int getSubgroupSize() const =0