#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
52 "resolve_simt_type_mismatch";
static constexpr unsigned regularPatternBenefit = 1;
static constexpr unsigned highPatternBenefit = 2;
/// Helper function to get the distributed vector type for a source vector
/// type according to its lane layout. Only the trailing dimensions (as many
/// as the lane layout has entries) are distributed; leading dimensions are
/// left unchanged.
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();
  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
         "Expecting a valid layout.");
  SmallVector<int64_t> effectiveLaneLayout =
      layout.getEffectiveLaneLayoutAsInt();
  assert(static_cast<size_t>(originalType.getRank()) >=
             effectiveLaneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  SmallVector<int64_t> distributedShape(originalType.getShape());
  unsigned distributionStart =
      originalType.getRank() - effectiveLaneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // Each distributed dimension must be evenly divisible by the
    // corresponding lane layout entry.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
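// Worked example (illustrative values, not taken from a test): with
// lane_layout = [1, 16], vector<16x16xf32> distributes to vector<16x1xf32>,
// and with lane_layout = [16], vector<32xindex> distributes to
// vector<2xindex>. Distribution fails (returns failure()) when a distributed
// dimension is not divisible by its lane layout entry, e.g. vector<16x12xf32>
// with lane_layout = [1, 16].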
/// Helper function to resolve types if the distributed type for orig differs
/// from the expected type. For vectors a shape cast is inserted; for tensor
/// descriptors an unrealized conversion cast (resolved at the end of the
/// pass) is inserted.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If orig and expected types are the same, return orig.
  if (orig.getType() == expected)
    return orig;
  // If orig is a vector type, create a shape cast op to reconcile the types.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If orig is a tensor descriptor type, create an unrealized conversion
  // cast and mark it so that it can be removed later.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
/// Helper function to check if the layout requires a packed (vnni) load,
/// i.e. each lane owns multiple consecutive elements along dimension 0.
static bool requirePacked(const xegpu::LayoutAttr layout) {
  if (!layout)
    return false;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return false;
  return laneData[0] != 1;
}
/// Helper function to check if the layout requires a transposed load, i.e.
/// the lanes of the subgroup are laid out along dimension 0 of a 2D tensor.
static bool requireTranspose(const xegpu::LayoutAttr layout,
                             const uArch *uArch) {
  if (!layout)
    return false;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
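// Illustrative values (assumed, not from a test): lane_data = [2, 1]
// requires a packed load because each lane owns 2 consecutive elements along
// dim 0, while lane_data = [1, 1] does not. Similarly, with a subgroup size
// of 16, lane_layout = [16, 1] requires a transposed load, whereas
// lane_layout = [1, 16] does not.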
/// Helper function to compute the distributed dimensions, i.e. the dimensions
/// whose sizes differ between the sequential and distributed vector types.
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
                                               VectorType distributedType) {
  assert(originalType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  SmallVector<int64_t> distributedDims;
  for (int64_t i = 0; i < originalType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
      distributedDims.push_back(i);
    }
  }
  return distributedDims;
}
/// This pattern moves the body of a gpu::GPUFuncOp into a new
/// gpu::WarpExecuteOnLane0Op so that the distribution patterns below can be
/// applied.
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    const uArch *uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "Subgroup distribution requires target attribute attached "
                     "to set the warp size");
    // If the function only contains a single void return, skip it.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the body was already moved inside a warp_execute_on_lane_0, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as
    // the original function.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(rewriter, newGpuFunc.getLoc(),
                                        /*upper_bound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        uArch->getSubgroupSize(), newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original function with a YieldOp.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
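// A sketch of the rewrite performed by MoveFuncBodyToWarpOp above
// (illustrative types and names):
//
//   gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
//     ...
//     gpu.return %result : vector<8x16xf32>
//   }
//
// becomes
//
//   gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
//     %laneid = gpu.lane_id : index
//     %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
//       ...
//       gpu.yield %result : vector<8x16xf32>
//     }
//     gpu.return %0
//   }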
/// Distribute a create_nd_tdesc feeding into a vector.yield op of the
/// enclosing `gpu.warp_execute_on_lane_0` region. The descriptor is recreated
/// outside the warp op with its layout dropped.
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, descOp->getOperands(),
        descOp.getOperandTypes(), newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    // Create a new descriptor op outside the warp op with the layout dropped.
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type with the original type.
    Value typeResolved =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
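// A sketch of the rewrite performed by CreateNdDescDistribution above
// (illustrative types; #layout0 is some xegpu.layout):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid) ->
//       (!xegpu.tensor_desc<4x8xf32, #layout0>) {
//     ...
//     %td = xegpu.create_nd_tdesc %arg0 : memref<4x8xf32>
//       -> !xegpu.tensor_desc<4x8xf32, #layout0>
//     gpu.yield %td
//   }
//
// becomes
//
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (..., memref<4x8xf32>) {
//     ...
//     gpu.yield %td, %arg0
//   }
//   %td = xegpu.create_nd_tdesc %r#1 : memref<4x8xf32>
//     -> !xegpu.tensor_desc<4x8xf32>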
/// Distribute a store_nd op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region. Both the stored value and the tensor
/// descriptor are distributed according to the layout, and the new store op
/// is created outside the warp op.
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
    // Create a new store op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;
    // The value operand must match the type expected by the SIMT-level store
    // op, which can differ from the type yielded by the warp op.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor is distributed by dropping its layout.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect the offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));

    auto newStoreOp =
        xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newStoreOperands, storeOp->getAttrs());
    xegpu::removeLayoutAttrs(newStoreOp);
    rewriter.eraseOp(storeOp);
    return success();
  }
};
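// A sketch of the rewrite (illustrative types, assuming a 16-lane subgroup
// and lane_layout = [1, 16]): the stored value yielded by the warp op is
// distributed to vector<8x1xf32>, shape-cast to the vector<8xf32> expected
// by the SIMT-level store, and the layout is dropped from the descriptor:
//
//   gpu.warp_execute_on_lane_0(%laneid) -> () {
//     ...
//     xegpu.store_nd %v, %td[%x, %y] : vector<8x16xf32>,
//       !xegpu.tensor_desc<8x16xf32, #layout0>
//   }
//
// becomes
//
//   %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>, ...) {
//     gpu.yield %v, %td, %x, %y
//   }
//   %0 = vector.shape_cast %r#0 : vector<8x1xf32> to vector<8xf32>
//   xegpu.store_nd %0, %r#1[%r#2, %r#3] : vector<8xf32>,
//     !xegpu.tensor_desc<8x16xf32>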
/// Distribute a load_nd op feeding into a vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. The load is sunk outside the warp op
/// with a distributed result type; packed and transpose attributes are set
/// based on the layout and the target.
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Make sure the load op is the last operation in the warp op body so
      // that it is not sunk across any barrier synchronization.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    const uArch *uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          loadOp, "xegpu::LoadNdOp require target attribute attached to "
                  "determine transpose "
                  "requirement");
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);

    // Create a new load op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    xegpu::removeLayoutAttrs(newLoadOp);
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (requireTranspose(layout, uArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the
    // warp op and the (xegpu-specific) distributed type supported by the
    // load op; resolve it with a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
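// A sketch of the rewrite (illustrative types, 16-lane subgroup,
// lane_layout = [1, 16]):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>) {
//     ...
//     %ld = xegpu.load_nd %td[%x, %y] :
//       !xegpu.tensor_desc<8x16xf32, #layout0> -> vector<8x16xf32>
//     gpu.yield %ld
//   }
//
// becomes
//
//   %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
//     gpu.yield %td, %x, %y
//   }
//   %ld = xegpu.load_nd %r#0[%r#1, %r#2] :
//     !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
//   %0 = vector.shape_cast %ld : vector<8xf32> to vector<8x1xf32>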
/// Distribute a dpas op feeding into a vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. The dpas is sunk outside the warp op
/// with operand and result types distributed according to the A, B and
/// output layouts.
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::LayoutAttr layoutA =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutAAttr());
    xegpu::LayoutAttr layoutB =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutBAttr());
    xegpu::LayoutAttr layoutOut =
        dyn_cast<xegpu::LayoutAttr>(dpasOp.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    xegpu::removeLayoutAttrs(newDpasOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
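// A sketch of the expected types (illustrative shapes for an 8x16x16 dpas
// with a 16-lane subgroup): the A operand vector<8x16xf16> distributes to
// vector<8x1xf16> by the warp op and is resolved to the vector<8xf16>
// expected by the lane-level dpas; B distributes similarly (packed when its
// lane_data requires it), and the vector<8x16xf32> result comes back as
// vector<8xf32> per lane, resolved to vector<8x1xf32> for the warp op users.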
/// Distribute a prefetch_nd op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region. The prefetch is sunk outside the warp
/// op with the layout dropped from its tensor descriptor.
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                newPrefetchOperands, prefetchOp->getAttrs());
    xegpu::removeLayoutAttrs(prefetchOp);
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Sink a gpu::BarrierOp at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region to outside of the warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    // The last node must be a gpu::BarrierOp.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Move the barrier op outside of the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
/// Distribute a scattered store op. The offsets and mask operands are
/// expected to be 1D vectors distributed based on their lane layouts; the
/// payload is distributed and flattened to a 1D vector per lane.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(storeScatterOp,
                                         "Expected 1D offsets and mask vector");
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
    if (storeVecTy.getRank() > 2)
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Expected at most 2D result at SG level");

    std::string layoutPayloadName =
        xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(0));
    std::string layoutOffsetsName =
        xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(2));
    std::string layoutMaskName =
        xegpu::getTemporaryLayoutName(storeScatterOp->getOpOperand(3));

    xegpu::LayoutAttr layoutPayload =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
    xegpu::LayoutAttr layoutOffsets =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    // Distributed store payload type according to the lane layout.
    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
    // The expected distributed payload type is always flattened to 1D.
    VectorType expectedPayloadTy =
        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
                        distPayloadTyByWarpOp.getElementType());

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = storeScatterOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        distPayloadTyByWarpOp, operands[1].getType(),
        distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    // Flatten the payload to the 1D type expected by the lane-level store.
    newStoreScatterOpOperands[0] = resolveDistributedTy(
        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
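// For example (illustrative, 16-lane subgroup): a payload of
// vector<16x2xf32> with lane_layout = [16, 1] distributes to
// vector<1x2xf32> per lane, which is then flattened to the vector<2xf32>
// expected by the lane-level scattered store; offsets and mask of
// vector<16xindex> / vector<16xi1> distribute to vector<1xindex> /
// vector<1xi1>.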
/// Helper to compute distributed coordinates for a load/store matrix op from
/// the layout, the lane id and the subgroup-level offsets.
static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
  SmallVector<Value> newCoords;
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(maybeCoords))
    return {};
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
  return newCoords;
}
/// Distribute a load_matrix op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();

    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadMatrixOp");
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();

    VectorType sgPayloadTy =
        dyn_cast<VectorType>(matrixOp.getResult().getType());
    VectorType warpResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");

    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);

    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");

    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> operands = {matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);

    SmallVector<Type> operandTypes = llvm::to_vector(
        llvm::map_range(operands, [](Value v) { return v.getType(); }));

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);

    SmallVector<Value> newCoords = currentOffsets;
    // If the subgroup_block_io attribute is not set, lane-level offsets must
    // be computed from the layout.
    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        newWarpOp.getResult(operandIdx),
        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
    return success();
  }
};
/// Distribute a store_matrix op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();

    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");

    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);

    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");

    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");

    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);

    SmallVector<Type> operandTypes = llvm::to_vector(
        llvm::map_range(operands, [](Value v) { return v.getType(); }));
    operandTypes[0] = *distPayloadByWarpOpOrFailure;

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);

    SmallVector<Value> newCoords = currentOffsets;
    // If the subgroup_block_io attribute is not set, lane-level offsets must
    // be computed from the layout.
    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::StoreMatrixOp::create(
        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
        ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(matrixOp);
    return success();
  }
};
/// Distribute a scattered load op. The requirements mirror the scattered
/// store distribution: offsets and mask are 1D vectors distributed by their
/// lane layouts, and the loaded value is a flattened 1D vector per lane.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // Check that the yield operand is produced by the *last* scattered load
      // op so that it is not sunk across any barrier synchronization.
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");
    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();

    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have a vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(loadGatherOp,
                                         "Expected 1D offsets and mask vector");

    std::string layoutOffsetsName =
        xegpu::getTemporaryLayoutName(loadGatherOp->getOpOperand(1));
    std::string layoutMaskName =
        xegpu::getTemporaryLayoutName(loadGatherOp->getOpOperand(2));

    xegpu::LayoutAttr layoutOffsets =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);

    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = loadGatherOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};

    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType distResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // The distributed load op result is always flattened to a 1D vector.
    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
                                           distResultTy.getElementType());

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);

    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type and replace all uses.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
    return success();
  }
};
/// Helper to rewrite a 2D reduction into a sequence of 1D
/// vector::ReductionOps: each slice along the non-reduction dimension is
/// extracted, reduced to a scalar together with its accumulator element, and
/// inserted into the result vector.
static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                     TypedValue<VectorType> acc,
                                     vector::CombiningKind kind,
                                     int64_t reductionDim, Location loc,
                                     PatternRewriter &rewriter) {
  // Expecting a 2D source vector.
  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
  VectorType sourceType = src.getType();
  int64_t sourceH = sourceType.getShape()[0];
  int64_t sourceW = sourceType.getShape()[1];
  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
  // Create a zero-initialized vector to hold the reduction results.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value reductionResult = arith::ConstantOp::create(
      rewriter, loc, acc.getType(),
      DenseElementsAttr::get(acc.getType(), zeroAttr));
  // For each slice of the source, extract the slice vector, do a reduction
  // and insert the reduced value back into the result vector.
  for (int i = 0; i < nSlices; ++i) {
    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
    if (reductionDim == 1) {
      sliceOffsets = {i, 0};
      sliceSizes = {1, sourceW};
    } else {
      sliceOffsets = {0, i};
      sliceSizes = {sourceH, 1};
    }
    vector::ExtractStridedSliceOp extractOp =
        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                              sliceSizes, {1, 1});
    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get({nSliceElements}, sourceType.getElementType()),
        extractOp.getResult());
    // The reduction yields a scalar, so insert it back into the result
    // vector.
    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
    Value reduction = vector::ReductionOp::create(
        rewriter, loc, kind, slice.getResult(), accExtract);
    reductionResult =
        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
  }
  return reductionResult;
}
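// Worked example (illustrative): reducing vector<2x16xf32> along dim 1 with
// kind = add produces 2 slices; for each row i, the vector<1x16xf32> slice is
// shape-cast to vector<16xf32>, reduced with vector.reduction <add> together
// with the accumulator element acc[i], and the scalar is inserted at position
// i of the vector<2xf32> result.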
/// Distribute a vector::MultiDimReductionOp. Only 2D sources with a single
/// reduction dimension are supported. Depending on whether the reduction
/// dimension is lane-local or crosses lanes, the reduction is either sunk
/// outside the warp op or lowered in place for the upstream shuffle-based
/// reduction distribution to pick up.
struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D vectors are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    // Only a single reduction dimension is supported. This also ensures that
    // the result is a vector type.
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(reductionOp.getSource());

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only single-dimension distribution of the source is supported.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // The reduction is lane-local when the reduced dimension is not the
    // distributed dimension; in that case the result must be distributed.
    // Otherwise the reduction crosses lanes and the result is broadcasted.
    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");
    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");

    // Lane-local reduction: the reduction is fully done within each lane, so
    // sink the lowered reductions outside the warp op.
    if (isReductionLaneLocal) {
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value result = lowerToVectorReductions(
          cast<TypedValue<VectorType>>(newWarpOp.getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp.getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
      return success();
    }
    // Non-lane-local reduction: lower the multi-reduction to 1D reductions
    // inside the warp op; the upstream vector distribution patterns then
    // turn them into shuffle-based cross-lane reductions.
    rewriter.setInsertionPoint(reductionOp);
    Value result = lowerToVectorReductions(
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    rewriter.replaceOp(reductionOp, result);
    return success();
  }
};
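// For example (illustrative, 16-lane subgroup): reducing a vector<16x32xf32>
// source with lane_layout = [16, 1] along dim 1 is lane-local: each lane
// reduces its own vector<1x32xf32> fragment and the result distributes to
// one element per lane. Reducing the same source along dim 0 crosses lanes:
// the lowered 1D reductions stay inside the warp op and are later turned
// into shuffle-based reductions by the upstream distribution patterns.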
/// Distribute a vector::BroadcastOp. The source can be a scalar or a vector;
/// in the vector case the source layout must be a slice of (or, for
/// same-rank broadcasts, equal to) the result layout.
struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!yieldOperand)
      return failure();
    auto broadcastOp =
        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType destType =
        dyn_cast<VectorType>(broadcastOp.getResult().getType());

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(broadcastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());

    FailureOr<VectorType> sourceDistType;
    Type sourceElemOrDistType;
    if (sourceType) {
      // Vector-to-vector broadcast.
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Rank-increasing broadcast: source layout must be a slice of the
        // result layout.
        bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
        if (!isSliceOf)
          return rewriter.notifyMatchFailure(
              warpOp,
              "Broadcast input layout must be a slice of result layout.");
      }
      if (rankDiff == 0) {
        llvm::SetVector<int64_t> broadcastUnitDims =
            broadcastOp.computeBroadcastedUnitDims();
        bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
        if (!isEqualTo)
          return rewriter.notifyMatchFailure(
              warpOp, "For same-rank broadcast, source must be identical to "
                      "adjusted result layouts with unit dims.");
        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }
      sourceDistType =
          getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
      if (failed(sourceDistType)) {
        return rewriter.notifyMatchFailure(
            warpOp, "Failed to distribute the source vector type.");
      }
      sourceElemOrDistType = sourceDistType.value();
    } else {
      // Scalar-to-vector broadcast: the scalar must not carry a layout.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            warpOp, "Broadcast from scalar must not have a layout attribute.");
      sourceElemOrDistType = broadcastOp.getSourceType();
    }
    FailureOr<VectorType> destDistType =
        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType)) {
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the dest vector type.");
    }

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
    // If the distributed source and dest types differ, create a broadcast;
    // otherwise the broadcast is a per-lane no-op.
    Value newBroadcast = distributedSource;
    if (sourceElemOrDistType != destDistType.value()) {
      newBroadcast =
          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
                                      destDistType.value(), distributedSource);
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newBroadcast);
    return success();
  }
};
/// Distribute a vector::ShapeCastOp. Rank-changing shape casts are supported
/// as long as the lower-rank side's layout is a slice of the higher-rank
/// side's layout.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yieldOperand)
      return failure();
    auto shapeCastOp =
        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");

    // The lower-rank side's layout must be a slice of the higher-rank side's
    // layout.
    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank reducing but source layout is not a "
                  "slice of result layout");
    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank increasing but result layout is not a "
                  "slice of source layout");

    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Create a new warp op that yields the source of the shape cast.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new shape cast op outside the warp op.
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
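// For example (illustrative): a vector.shape_cast from vector<32xf32>
// (source layout #a) to vector<1x32xf32> (result layout #b) requires #a to
// be a slice of #b. With 16 lanes distributing the trailing dimension, the
// per-lane shape cast becomes vector<2xf32> -> vector<1x2xf32>.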
/// Distribute a vector::ExtractStridedSliceOp. Distribution is supported
/// only when the slice along the distributed dimension starts at a multiple
/// of the subgroup size and the source is evenly distributable.
struct VectorExtractStridedSliceDistribution
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    auto extractOp =
        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandIdx = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // Distributed dimensions are the dimensions that differ between the
    // sequential result type and the distributed type.
    auto extractResultType = cast<VectorType>(operand->get().getType());
    auto distributedDims =
        getDistributedDims(extractResultType, distributedType);
    VectorType updatedSourceType = extractOp.getSourceVectorType();
    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        extractOp.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        extractOp.getStrides(), [](Attribute attr) { return attr; });
    // Sizes, offsets and strides can be shorter than the source rank; pad
    // them with full sizes, zero offsets and unit strides.
    int64_t sourceRank = extractOp.getSourceVectorType().getRank();
    for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(rewriter.getI64IntegerAttr(
          extractOp.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
    }
    // If the result is distributed, adjust the source type, sizes and
    // offsets for the distributed dimension.
    if (distributedDims.size() > 0) {
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp, "Source can not be distributed in multiple dimensions.");
      int64_t distributedDim = distributedDims[0];
      int sourceDistrDimSize =
          extractOp.getSourceVectorType().getShape()[distributedDim];
      auto sourceLayout = xegpu::getDistributeLayoutAttr(extractOp.getSource());
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source of extract_strided_slice op lacks distribution "
                    "layout");
      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
      int subgroupSize = sourceLaneLayout[distributedDim];
      // The source size along the distributed dimension must be a multiple
      // of the subgroup size.
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Source size along distributed dimension is not a multiple of "
            "subgroup size.");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Lane data must be all ones for this distribution to be valid.
      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source layout");
      // The offset along the distributed dimension must be a multiple of the
      // subgroup size.
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp, "Offset along distributed dimension "
                    "is not a multiple of subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, extractOp.getSourceVectorType())
                              .value();
      // Update the size and offset of the distributed dimension.
      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));
      updatedOffsets[distributedDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }
    // Create a new warp op that yields the source of the extract op.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    // Create a new extract strided slice op outside the warp op.
    Value newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, source,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
    return success();
  }
};
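// Offset arithmetic example (illustrative, 16-lane subgroup): extracting a
// vector<32x16xf32> slice at offset [16, 0] from a vector<64x16xf32> source
// distributed along dim 0 becomes, per lane, extracting a vector<2x16xf32>
// slice at offset [16 / 16, 0] = [1, 0] from the lane-local
// vector<4x16xf32> source fragment.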
/// Distribute a vector::InsertStridedSliceOp. The requirements mirror the
/// extract case: the dest may be distributed in a single dimension, the
/// distributed dimension must fall within the source rank, and sizes and
/// offsets along it must be multiples of the subgroup size.
struct VectorInsertStridedSliceDistribution
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    auto insertOp =
        cast<vector::InsertStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandNumber = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed dimensions of the dest are the dimensions that differ
    // between the sequential result type and the distributed type.
    auto insertResultType = cast<VectorType>(operand->get().getType());
    auto destDistributedDims =
        getDistributedDims(insertResultType, distributedType);
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        insertOp.getOffsets(), [](Attribute attr) { return attr; });
    VectorType updatedSourceType = insertOp.getSourceVectorType();
    VectorType updatedDestType = insertOp.getDestVectorType();
    if (destDistributedDims.size() > 0) {
      // Only single-dimension distribution is supported.
      if (destDistributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Expecting source to be distributed in a single dimension.");
      int64_t destDistributedDim = destDistributedDims[0];
      VectorType srcType = insertOp.getSourceVectorType();
      VectorType destType = insertOp.getDestVectorType();
      // The distributed dimension must be covered by the source vector, i.e.
      // it must map to one of the last `srcType.getRank()` dims of the dest.
      int64_t sourceDistributedDim =
          destDistributedDim - (destType.getRank() - srcType.getRank());
      if (sourceDistributedDim < 0)
        return rewriter.notifyMatchFailure(
            warpOp, "distributed dimension must be in the last k (i.e. source "
                    "rank) dims of dest vector");
      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
      auto sourceLayout =
          xegpu::getDistributeLayoutAttr(insertOp.getValueToStore());
      auto destLayout = xegpu::getDistributeLayoutAttr(insertOp.getDest());
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source or dest of insert_strided_slice op lacks "
                    "distribution layout");
      int subgroupSize =
          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
      // Lane data must be all ones for this distribution to be valid.
      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source and dest layouts");
      // The source size along the distributed dimension must be a multiple
      // of the subgroup size.
      if (srcDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Distributed dimension size in source is not a multiple of "
            "subgroup size.");
      // The offset along the distributed dimension in dest must also be a
      // multiple of the subgroup size.
      int64_t destDistrDimOffset =
          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
      if (destDistrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Offset along distributed dimension in dest is not a multiple of "
            "subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, insertOp.getSourceVectorType())
                              .value();
      updatedDestType = getDistVecTypeBasedOnLaneLayout(
                            destLayout, insertOp.getDestVectorType())
                            .value();
      // Update the offset of the distributed dimension.
      updatedOffsets[destDistributedDim] =
          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
    }
    // Create a new warp op that yields the source and dest of the insert op.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {updatedSourceType, updatedDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
    Value dest = newWarpOp.getResult(newRetIndices[1]);
    // Create a new insert strided slice op outside the warp op.
    Value newInsertOp = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newInsertOp);
    return success();
  }
};
/// Distribute a memref::ExtractAlignedPointerAsIndexOp. The op is uniform
/// across lanes, so it is simply moved outside the warp op.
struct MemrefExtractAlignedPointerAsIndexDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
    return success();
  }
};
/// Distribute a vector::BitCastOp. The bitcast only changes the element type
/// and width, so it is applied per lane on the distributed source.
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedSourceType =
        getDistVecTypeBasedOnLaneLayout(
            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
/// Distribute a vector::TransposeOp. Source and result layouts must be 2D
/// transposes of each other; the transpose is then performed on the per-lane
/// vector outside the warp op.
struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    // Only 2D transposes are supported.
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    // The result layout must be a transpose of the source layout.
    if (!resultLayout.isTransposeOf(sourceLayout, perm))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               GpuBarrierDistribution, VectorMultiReductionDistribution,
               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
               VectorBitcastDistribution, LoadMatrixDistribution,
               StoreMatrixDistribution,
               MemrefExtractAlignedPointerAsIndexDistribution>(
      patterns.getContext(), regularPatternBenefit);
  // These patterns must be tried before the generic vector distribution
  // patterns, hence the higher benefit.
  patterns
      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
          patterns.getContext(), highPatternBenefit);
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Attach layouts to op operands; bail out on failure.
  if (!xegpu::recoverTemporaryLayouts(getOperation())) {
    signalPassFailure();
    return;
  }
  // Step 2: Move the body of each GPU function inside a
  // gpu.warp_execute_on_lane_0 op.
  {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
    // Move scalar uniform code (e.g. GPU index ops, scalar constants)
    // outside of the warp op to simplify the later distribution.
    getOperation()->walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
        vector::moveScalarUniformCode(warpOp);
    });
  }
  // Step 3: Apply subgroup-to-workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn is used by the generic vector distribution patterns to
  // determine the distributed type of a vector value based on its layout.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    // Without a layout, assume the innermost dimension is distributed.
    if (!layout)
      return AffineMap::getMultiDimMapWithTargets(
          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
    assert(layout.getRank() == vecRank &&
           "Expecting vector and layout rank to match");
    // A dimension is distributed if its lane layout entry is greater than 1
    // and evenly divides the dimension size.
    SmallVector<unsigned int> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1 && vecType.getShape()[i] % v == 0)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // TODO: shuffleFn is currently unused.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };
  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
                          vector::CombiningKind kind, uint32_t size) {
    // First reduce on a single thread to get the per-lane reduction value.
    Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
    // Parallel reduction using butterfly shuffles.
    for (uint64_t i = 1; i < size; i <<= 1) {
      Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
                                              /*width=*/size,
                                              gpu::ShuffleMode::XOR)
                           .getShuffleResult();
      laneVal =
          vector::makeArithReduction(builder, loc, kind, laneVal, shuffled);
    }
    return laneVal;
  };
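  // Butterfly reduction example (illustrative, size = 4): after the
  // lane-local vector.reduction, iteration i = 1 XOR-shuffles lane pairs
  // {0,1} and {2,3} and combines them; iteration i = 2 XOR-shuffles across
  // {0,2} and {1,3}, leaving every lane with the full reduction of all 4
  // lanes in log2(size) shuffle steps.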
  vector::populateDistributeReduction(patterns, warpReduction,
                                      regularPatternBenefit);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn,
      /*benefit=*/regularPatternBenefit);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
  // Step 4: Clean up the UnrealizedConversionCastOps that were inserted to
  // resolve SIMT type mismatches. Cleanup is only safe if distribution fully
  // succeeded, i.e. no warp ops remain.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp)
    return;

  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // Only consider the casts added for resolving SIMT type mismatches.
    if (!op->hasAttr(resolveSIMTTypeMismatch))
      return WalkResult::skip();

    Value input = op.getOperand(0);
    Value output = op.getResult(0);

    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");

    // tensor_desc<shape, layout> -> tensor_desc<shape> conversions: these
    // occur inside loop bodies to resolve a block argument to the SIMT type.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }
    // tensor_desc<shape> -> tensor_desc<shape, layout> conversions: these
    // occur at loop yields to convert the SIMT type back to the original.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);

    if (op->use_empty())
      op->erase();
    return WalkResult::advance();
  });
}