#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
52 "resolve_simt_type_mismatch";
// Pattern benefits. The XeGPU patterns run at regular benefit; a few vector
// distribution patterns below must win over their upstream counterparts and
// therefore use the higher benefit.
static constexpr unsigned regularPatternBenefit = 1;
static constexpr unsigned highPatternBenefit = 2;
/// Compute the per-lane (SIMT) vector type for `originalType` given the lane
/// layout of `layout`. Only the trailing dimensions covered by the lane
/// layout are distributed; leading dimensions are left intact.
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();
  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
         "Expecting a valid layout.");
  SmallVector<int64_t> effectiveLaneLayout =
      layout.getEffectiveLaneLayoutAsInt();
  assert(static_cast<size_t>(originalType.getRank()) >=
             effectiveLaneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  SmallVector<int64_t> distributedShape(originalType.getShape());
  // Only the trailing `effectiveLaneLayout.size()` dimensions are distributed.
  unsigned distributionStart =
      originalType.getRank() - effectiveLaneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // Each distributed dimension must be evenly divisible by its lane count.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
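// Illustrative example (assuming a 16-lane subgroup): vector<16x16xf32> with
// lane_layout = [1, 16] distributes to vector<16x1xf32> per lane
// (16/1 = 16 rows, 16/16 = 1 column), while lane_layout = [16, 1] yields
// vector<1x16xf32>. A shape not divisible by the lane layout, e.g.
// vector<16x12xf32> with [1, 16], makes the distribution fail.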
/// Reconcile `orig` with the type `expected` by materializing a cast. Vector
/// mismatches use vector.shape_cast; tensor descriptor mismatches use an
/// unrealized_conversion_cast that is resolved at the end of the pass.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // No mismatch, nothing to do.
  if (orig.getType() == expected)
    return orig;
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    // Mark the cast so the pass can resolve it during finalization.
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
/// Return true when an nd load must use the packed form, i.e. when each lane
/// owns more than one element along the non-contiguous dimension.
static bool requirePacked(const xegpu::LayoutAttr layout) {
  if (!layout)
    return false;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return false;
  return laneData[0] != 1;
}
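// Illustrative example: lane_data [1, 1] needs no packing, while
// lane_data [2, 1] (two elements per lane along dim 0, typical for 16-bit
// dpas B operands) requires the packed (VNNI) load form.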
/// Return true when an nd load requires the transpose effect: lanes are laid
/// out along dimension 0 (lane_layout [subgroupSize, 1]), so the load must
/// transpose rows into the per-lane vector.
static bool requireTranspose(const xegpu::LayoutAttr layout,
                             const xegpu::uArch::uArch *uArch) {
  if (!layout)
    return false;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] == static_cast<int64_t>(uArch->getSubgroupSize()) &&
         laneLayout[1] == 1;
}
/// Return the indices of the dimensions along which `distributedType` differs
/// from (i.e. was distributed relative to) `originalType`.
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
                                               VectorType distributedType) {
  assert(originalType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  SmallVector<int64_t> distributedDims;
  for (int64_t i = 0; i < originalType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != originalType.getDimSize(i))
      distributedDims.push_back(i);
  }
  return distributedDims;
}
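// E.g. originalType = vector<16x16xf32> and distributedType =
// vector<16x1xf32> gives {1}: only the second dimension was split across
// lanes.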
/// Move the body of a gpu.func into a gpu.warp_execute_on_lane_0 op so the
/// distribution patterns below can be applied.
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // The warp size comes from the chip's uArch, looked up via the GPU
    // module's target attribute.
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          gpuFuncOp, "Subgroup distribution requires target attribute attached "
                     "to set the warp size");
    // Skip functions whose body is just a void return.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // Skip functions that already contain a warp op.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Insert gpu.lane_id at the start of the new function body.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(rewriter, newGpuFunc.getLoc(),
                                        /*upper_bound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
        uArch->getSubgroupSize(), newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the terminating gpu.return with a gpu.yield inside the warp op.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Return the warp op's results from the new function.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
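// Illustrative effect of this pattern (IR sketch, names and types assumed):
//   gpu.func @f(%arg0: memref<8x16xf16>) -> vector<8x16xf32> {
//     ...
//     gpu.return %result : vector<8x16xf32>
//   }
// becomes
//   gpu.func @f(%arg0: memref<8x16xf16>) -> vector<8x16xf32> {
//     %laneid = gpu.lane_id
//     %0 = gpu.warp_execute_on_lane_0(%laneid)[16]
//          args(%arg0 : memref<8x16xf16>) -> (vector<8x16xf32>) {
//       ...
//       gpu.yield %result : vector<8x16xf32>
//     }
//     gpu.return %0 : vector<8x16xf32>
//   }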
/// Distribute an xegpu.create_nd_tdesc yielded by the warp op: the descriptor
/// is recreated outside the warp op with its layout dropped.
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
    if (descOp.getMixedOffsets().size())
      return rewriter.notifyMatchFailure(
          descOp, "xegpu::CreateNdDescOp must not have offsets");
    // Yield the descriptor's operands from the warp op, then recreate it
    // outside with a layout-free type.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, descOp->getOperands(),
        descOp.getOperandTypes(), newRetIndices);
    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the new descriptor with the type expected by consumers.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter));
    return success();
  }
};
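// Illustrative example: a descriptor yielded from the warp region as
// !xegpu.tensor_desc<4x8xf32, #layout0> is recreated outside as
// !xegpu.tensor_desc<4x8xf32> (layout dropped); an unrealized_conversion_cast
// tagged "resolve_simt_type_mismatch" bridges the two types until the end of
// the pass.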
/// Distribute an xegpu.store_nd (with offsets) that ends the warp body.
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(storeOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");
    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
                                           storeOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
    // Create a new store op outside the warp op with distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;
    // The value operand may need reconciliation: the type distributed by the
    // warp op can differ from the SIMT type expected by xegpu.store_nd.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor is distributed by dropping its layout.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
    // Collect the offsets.
    for (size_t i = 2; i < newRetIndices.size(); ++i)
      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                             newStoreOperands, storeOp->getAttrs());
    rewriter.eraseOp(storeOp);
    return success();
  }
};
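// Illustrative example (assuming 16 lanes): storing vector<8x16xf32> through
// !xegpu.tensor_desc<8x16xf32, #layout> is rewritten so the warp op yields
// the lane fragment vector<8x1xf32>, which is reconciled via
// vector.shape_cast to the SIMT value type expected by xegpu.store_nd on a
// layout-free descriptor.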
/// Distribute an xegpu.load_nd (with offsets) feeding the warp op's yield.
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // The load must be the last op in the warp body so that moving it out
      // preserves the order of memory side effects.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");
    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // The target uArch decides whether a transposed load is needed.
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(loadOp).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          loadOp, "xegpu::LoadNdOp require target attribute attached to "
                  "determine transpose "
                  "requirement");
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
    SmallVector<Type> newYieldedTypes = {tensorDescTy};
    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
    // Create a new load op outside the warp op with distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newLoadOperands{
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        newLoadOperands, loadOp->getAttrs());
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(requirePacked(layout));
    // Set the transpose attribute if the layout requires it.
    if (requireTranspose(layout, uArch))
      newLoadOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the load's SIMT result type with the warp-distributed type.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
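// Illustrative example: lane_data [2, 1] on the descriptor layout makes the
// new load packed (VNNI form for 16-bit data), while lane_layout [16, 1]
// (lanes assigned along rows) makes it a transposed load with permutation
// [1, 0].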
/// Distribute an xegpu.dpas yielded by the warp op using the layouts attached
/// to its A, B and output operands.
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");
    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");
    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");
    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The accumulator is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(layoutOut, dpasOp.getResultType());
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Reconcile the warp-distributed types with the SIMT types the new dpas
    // op expects.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);
    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the output type with the type expected by the warp op.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
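// Illustrative example (assuming 16 lanes): for an 8x16x16 dpas, the A
// operand vector<8x16xf16> distributes to vector<8x1xf16> per lane, the B
// operand vector<16x16xf16> to vector<16x1xf16>, and the accumulator/result
// vector<8x16xf32> to vector<8x1xf32>.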
/// Distribute an xegpu.prefetch_nd (with offsets) that ends the warp body.
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(prefetchOp,
                                         "the prefetch op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
    SmallVector<Type> offsetTypes = llvm::to_vector(
        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");
    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Recreate the prefetch outside the warp op with a layout-free type.
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    // Collect the offsets.
    for (size_t i = 1; i < newRetIndices.size(); ++i)
      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                newPrefetchOperands, prefetchOp->getAttrs());
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Sink a gpu.barrier that ends the warp body out of the warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    // The last node must be a gpu.barrier.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Move the barrier past the warp op.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
/// Distribute an xegpu.store (scatter) that ends the warp body. The payload,
/// offsets and mask are distributed according to their layouts.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    auto offsets = storeScatterOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()))
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Store op must have a vector of offsets argument");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(storeScatterOp,
                                         "Expected 1D offsets and mask vector");
    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
    if (storeVecTy.getRank() > 2)
      return rewriter.notifyMatchFailure(
          storeScatterOp, "Expected at most 2D result at SG level");
    // Layouts are attached per operand: payload (0), offsets (2), mask (3).
    std::string layoutPayloadName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
    std::string layoutOffsetsName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
    std::string layoutMaskName =
        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
    xegpu::LayoutAttr layoutPayload =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
    xegpu::LayoutAttr layoutOffsets =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distStoreVecByWarpOpOrFailure) ||
        failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          storeScatterOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    // The distributed store payload is always flattened to 1D.
    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
    VectorType expectedPayloadTy =
        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
                        distPayloadTyByWarpOp.getElementType());
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = storeScatterOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        distPayloadTyByWarpOp, operands[1].getType(),
        distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
    rewriter.setInsertionPointAfter(newWarpOp);
    // Flatten the payload to the expected 1D type if needed.
    newStoreScatterOpOperands[0] = resolveDistributedTy(
        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    // The copied layout attributes are stale after distribution; drop them.
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
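// Illustrative example (assuming 16 lanes): storing vector<16xf32> with
// vector<16xindex> offsets and a vector<16xi1> mask, all with lane layout
// [16], becomes a per-lane xegpu.store with a vector<1xf32> payload and one
// offset/mask element per lane.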
/// Compute the per-lane coordinates for a load/store matrix op by adding the
/// layout-derived distributed coordinates to the original offsets.
static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
  auto maybeCoords =
      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
  if (failed(maybeCoords))
    return {};
  assert(maybeCoords.value().size() == 1 &&
         "Expected one set of distributed offsets");
  // Add the per-lane coordinates to the original offsets, right-aligned.
  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
      getAsOpFoldResult(origOffsets));
  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
}
/// Distribute an xegpu.load_matrix that ends the warp body.
struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadMatrixOp");
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType sgPayloadTy =
        dyn_cast<VectorType>(matrixOp.getResult().getType());
    VectorType warpResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");
    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the load op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);
    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");
    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");
    SmallVector<Value> operands = {matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);
    SmallVector<Type> operandTypes = llvm::to_vector(
        llvm::map_range(operands, [](Value v) { return v.getType(); }));
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
    // All constant offsets become dynamic after distribution.
    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newCoords = currentOffsets;
    // With subgroup_block_io the block is read cooperatively; otherwise
    // compute per-lane coordinates from the layout.
    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    // Reconcile the new op's result type and replace all uses.
    rewriter.replaceAllUsesWith(
        newWarpOp.getResult(operandIdx),
        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
    return success();
  }
};
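// Illustrative example (assuming 16 lanes): loading vector<16x16xf32> from a
// matrix descriptor with lane_layout [1, 16] becomes a per-lane
// xegpu.load_matrix of vector<16x1xf32> whose offsets are shifted by the
// lane-id-derived coordinates computed above.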
/// Distribute an xegpu.store_matrix that ends the warp body.
struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
    if (!matrixOp)
      return failure();
    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
    if (!sgPayloadTy)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix op payload must be a vector type");
    auto loc = matrixOp.getLoc();
    auto offsets = matrixOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(matrixOp,
                                         "the store op must have offsets");
    SmallVector<Value> offsetsAsValues =
        vector::getAsValues(rewriter, loc, offsets);
    auto layout = matrixOp.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          matrixOp, "the matrix operation lacks layout attribute");
    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
    if (failed(distPayloadByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          matrixOp, "Failed to distribute matrix op payload based on layout.");
    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
    const unsigned offsetsStartIdx = operands.size();
    operands.append(offsetsAsValues);
    SmallVector<Type> operandTypes = llvm::to_vector(
        llvm::map_range(operands, [](Value v) { return v.getType(); }));
    // The payload is yielded with its distributed type.
    operandTypes[0] = *distPayloadByWarpOpOrFailure;
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
    // All constant offsets become dynamic after distribution.
    SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
                                         ShapedType::kDynamic);
    DenseI64ArrayAttr newConstOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(newConstOffsets);
    ValueRange currentOffsets =
        ValueRange(newOperands).drop_front(offsetsStartIdx);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newCoords = currentOffsets;
    // Compute per-lane coordinates unless subgroup_block_io is set.
    if (!matrixOp.getSubgroupBlockIoAttr()) {
      newCoords = computeDistributedCoordinatesForMatrixOp(
          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
          currentOffsets);
    }
    xegpu::StoreMatrixOp::create(
        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
        ValueRange(newCoords), newConstOffsetsAttr,
        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(matrixOp);
    return success();
  }
};
/// Distribute an xegpu.load (gather) whose result is yielded by the warp op.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
      // The load must directly precede the yield so that moving it out
      // preserves the order of memory side effects.
      return isa<xegpu::LoadGatherOp>(op) &&
             warpOp.getTerminator()->getPrevNode() == op;
    });
    if (!producedByLastLoad)
      return rewriter.notifyMatchFailure(
          warpOp, "The last op is not xegpu::LoadGatherOp");
    auto loadGatherOp =
        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
    auto offsets = loadGatherOp.getOffsets();
    if (!offsets || !isa<VectorType>(offsets.getType()) ||
        !isa<VectorType>(loadGatherOp.getMask().getType()))
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Load op must have a vector arguments for offsets and mask");
    VectorType offsetsTy = cast<VectorType>(offsets.getType());
    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
      return rewriter.notifyMatchFailure(loadGatherOp,
                                         "Expected 1D offsets and mask vector");
    // Layouts are attached per operand: offsets (1), mask (2).
    std::string layoutOffsetsName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
    std::string layoutMaskName =
        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
    xegpu::LayoutAttr layoutOffsets =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
    xegpu::LayoutAttr layoutMask =
        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
    FailureOr<VectorType> distMaskByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure)) {
      return rewriter.notifyMatchFailure(
          loadGatherOp,
          "Some vector operands have no layouts, using defaults instead.");
    }
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands = loadGatherOp->getOperands();
    SmallVector<Type> operandTypesToYield = {
        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
        distMaskByWarpOpOrFailure.value()};
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType distResultTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // The distributed load is always 1D.
    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
                                           distResultTy.getElementType());
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
        rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
        loadGatherOp->getAttrs());
    // The copied layout attributes are stale after distribution; drop them.
    xegpu::removeLayoutAttrs(newOp);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the 1D loaded vector with the warp-distributed type.
    rewriter.replaceAllUsesWith(
        distributedVal,
        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
    return success();
  }
};
/// Helper to lower a multi-dim reduction of a 2D vector to a sequence of
/// vector.extract_strided_slice + vector.reduction ops, one per slice along
/// the non-reduced dimension.
static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                     TypedValue<VectorType> acc,
                                     vector::CombiningKind kind,
                                     int64_t reductionDim, Location loc,
                                     PatternRewriter &rewriter) {
  // Expecting a 2D source vector.
  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
  VectorType sourceType = src.getType();
  int64_t sourceH = sourceType.getShape()[0];
  int64_t sourceW = sourceType.getShape()[1];
  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
  // Create a zero vector with the accumulator's type to hold the per-slice
  // reduction results.
  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
  Value reductionResult = arith::ConstantOp::create(
      rewriter, loc, acc.getType(),
      DenseElementsAttr::get(acc.getType(), zeroAttr));
  // For each slice: extract it, flatten it, reduce it together with the
  // matching accumulator element, and insert the scalar result back.
  for (int i = 0; i < nSlices; ++i) {
    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
    if (reductionDim == 1) {
      sliceOffsets = {i, 0};
      sliceSizes = {1, sourceW};
    } else {
      sliceOffsets = {0, i};
      sliceSizes = {sourceH, 1};
    }
    vector::ExtractStridedSliceOp extractOp =
        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                              sliceSizes, {1, 1});
    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get({nSliceElements}, sourceType.getElementType()),
        extractOp.getResult());
    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
    Value reduction = vector::ReductionOp::create(
        rewriter, loc, kind, slice.getResult(), accExtract);
    reductionResult =
        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
  }
  return reductionResult;
}
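// Illustrative example: reducing vector<2x16xf32> along dim 1 with kind=add
// extracts two 1x16 slices, shape_casts each to vector<16xf32>, reduces each
// into the matching accumulator element, and inserts the two scalars into a
// vector<2xf32> result.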
/// Distribute vector.multi_dim_reduction. Two cases are supported:
/// 1) The reduction is lane-local (each lane owns a full slice along the
///    reduction dim): sink per-slice vector.reduction ops past the warp op.
/// 2) The reduction crosses lanes: lower to vector.reduction inside the warp
///    op and let the upstream warp-reduction pattern distribute it.
struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
    if (!yieldOperand)
      return failure();
    auto reductionOp =
        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType sourceType = reductionOp.getSourceVectorType();
    // Only 2D reductions with a single reduction dimension are supported.
    if (sourceType.getRank() != 2)
      return rewriter.notifyMatchFailure(warpOp,
                                         "Only 2D reductions are supported.");
    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
    if (reductionDims.size() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only 1 reduction dimension is supported.");
    int64_t reductionDim = reductionDims[0];
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    VectorType resultType = cast<VectorType>(reductionOp.getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(reductionOp.getSource());
    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the source vector type.");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Only a single source dimension may be distributed.
    bool dim0Distributed =
        sourceDistType.getShape()[0] != sourceType.getShape()[0];
    bool dim1Distributed =
        sourceDistType.getShape()[1] != sourceType.getShape()[1];
    if (dim0Distributed && dim1Distributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting source to be distributed in a single dimension.");
    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
    if (sourceDistDim == -1)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed source vector.");
    bool resultDistributed =
        distributedResultType.getNumElements() < resultType.getNumElements();
    // The reduction is lane-local when the distributed dimension and the
    // reduction dimension are orthogonal. In that case the result must also
    // be distributed; otherwise it must be broadcasted to all lanes.
    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
                                (sourceDistDim == 1 && reductionDim == 0);
    if (isReductionLaneLocal && !resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp, "Expecting a distributed result for lane-local reduction.");
    if (!isReductionLaneLocal && resultDistributed)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Expecting a broadcasted result for non-lane-local reduction.");
    // Lane-local reduction: sink the reduction past the warp op and lower it
    // to per-slice vector.reduction ops.
    if (isReductionLaneLocal) {
      SmallVector<size_t> newRetIndices;
      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
          {sourceDistType, distributedResultType}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value result = lowerToVectorReductions(
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
      return success();
    }
    // Non-lane-local: rewrite the multi-reduction in terms of per-slice
    // vector.reduction ops inside the warp op; upstream patterns then
    // distribute those using subgroup shuffles.
    rewriter.setInsertionPoint(reductionOp);
    Value result = lowerToVectorReductions(
        cast<TypedValue<VectorType>>(reductionOp.getSource()),
        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
    return success();
  }
};
/// Distribute vector.broadcast yielded by the warp op.
struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!yieldOperand)
      return failure();
    auto broadcastOp =
        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType destType =
        dyn_cast<VectorType>(broadcastOp.getResult().getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(broadcastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
    FailureOr<VectorType> sourceDistType;
    Type sourceElemOrDistType;
    if (sourceType) {
      // Vector-to-vector broadcast.
      int64_t rankDiff = destType.getRank() - sourceType.getRank();
      if (rankDiff > 0) {
        // Rank-increasing broadcast: the source layout must be a slice of
        // the result layout.
        bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
        if (!isSliceOf)
          return rewriter.notifyMatchFailure(
              warpOp,
              "Broadcast input layout must be a slice of result layout.");
      }
      if (rankDiff == 0) {
        // Same-rank broadcast along unit dims.
        llvm::SetVector<int64_t> broadcastUnitDims =
            broadcastOp.computeBroadcastedUnitDims();
        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
        bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
        if (!isEqualTo)
          return rewriter.notifyMatchFailure(
              warpOp, "For same-rank broadcast, source must be identical to "
                      "adjusted result layouts with unit dims.");
        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
      }
      sourceDistType =
          getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
      if (failed(sourceDistType)) {
        return rewriter.notifyMatchFailure(
            warpOp, "Failed to distribute the source vector type.");
      }
      sourceElemOrDistType = sourceDistType.value();
    } else {
      // Scalar-to-vector broadcast.
      if (sourceLayout)
        return rewriter.notifyMatchFailure(
            warpOp, "Broadcast from scalar must not have a layout attribute.");
      sourceElemOrDistType = broadcastOp.getSourceType();
    }
    FailureOr<VectorType> destDistType =
        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
    if (failed(destDistType)) {
      return rewriter.notifyMatchFailure(
          warpOp, "Failed to distribute the dest vector type.");
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        TypeRange{sourceElemOrDistType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
    // If the distributed source and dest types agree, the broadcast folds away.
    Value newBroadcast = distributedSource;
    if (sourceElemOrDistType != destDistType.value()) {
      newBroadcast =
          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
                                      destDistType.value(), distributedSource);
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newBroadcast);
    return success();
  }
};
/// Distribute vector.shape_cast using the source/result layouts.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yieldOperand)
      return failure();
    auto shapeCastOp =
        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
    unsigned operandNumber = yieldOperand->getOperandNumber();
    auto resultDistTy =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          warpOp,
          "the source or result of shape_cast op lacks distribution layout");
    // For rank-changing shape casts, the lower-rank layout must be a slice
    // of the higher-rank one.
    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank increasing but source layout is not a "
                  "slice of result layout");
    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
      return rewriter.notifyMatchFailure(
          warpOp, "shape_cast is rank reducing but result layout is not a "
                  "slice of source layout");
    FailureOr<VectorType> sourceDistTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        shapeCastOp.getSourceVectorType());
    if (failed(sourceDistTypeOrFailure))
      return rewriter.notifyMatchFailure(
          warpOp, "failed to get distributed vector type for source");
    VectorType sourceDistType = sourceDistTypeOrFailure.value();
    // Yield the shape_cast source from the warp op and recreate the cast
    // outside with distributed types.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    Value newShapeCast = vector::ShapeCastOp::create(
        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newShapeCast);
    return success();
  }
};
/// Distribute vector.extract_strided_slice. Slices along a distributed
/// dimension must be aligned with the lane layout: both the source size and
/// the offset along that dimension must be multiples of the lane count.
struct VectorExtractStridedSliceDistribution
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    auto extractOp =
        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandIdx = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // Find which dimensions of the yielded extract result are distributed.
    auto extractResultType = cast<VectorType>(operand->get().getType());
    auto distributedDims =
        getDistributedDims(extractResultType, distributedType);
    VectorType updatedSourceType = extractOp.getSourceVectorType();
    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        extractOp.getOffsets(), [](Attribute attr) { return attr; });
    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
        extractOp.getStrides(), [](Attribute attr) { return attr; });
    // Sizes/offsets/strides may be shorter than the source rank; pad them to
    // full rank with full-size, zero-offset, unit-stride entries.
    int64_t sourceRank = extractOp.getSourceVectorType().getRank();
    for (int64_t i = extractOp.getSizes().size(); i < sourceRank; ++i) {
      updatedSizes.push_back(rewriter.getI64IntegerAttr(
          extractOp.getSourceVectorType().getDimSize(i)));
      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
    }
    if (distributedDims.size() > 0) {
      // Only a single distributed dimension is supported for now.
      if (distributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp, "Source can not be distributed in multiple dimensions.");
      int64_t distributedDim = distributedDims[0];
      int sourceDistrDimSize =
          extractOp.getSourceVectorType().getShape()[distributedDim];
      auto sourceLayout = xegpu::getDistributeLayoutAttr(extractOp.getSource());
      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source of extract_strided_slice op lacks distribution "
                    "layout");
      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
      // Lane count along the distributed dimension.
      int subgroupSize = sourceLaneLayout[distributedDim];
      if (sourceDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Source size along distributed dimension is not a multiple of "
            "subgroup size.");
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      // Only unit lane data is supported for now.
      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source layout");
      // The offset along the distributed dim must be lane-aligned.
      int64_t distrDimOffset =
          cast<IntegerAttr>(updatedOffsets[distributedDim]).getInt();
      if (distrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp, "Offset along distributed dimension "
                    "is not a multiple of subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, extractOp.getSourceVectorType())
                              .value();
      // Update the size and offset along the distributed dim to per-lane
      // values.
      updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));
      updatedOffsets[distributedDim] =
          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
    }
    // Yield the source from the warp op and recreate the extract outside.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value source = newWarpOp.getResult(newRetIndices[0]);
    Value newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, source,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        ArrayAttr::get(rewriter.getContext(), updatedSizes),
        ArrayAttr::get(rewriter.getContext(), updatedStrides));
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
    return success();
  }
};
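// Illustrative example (assuming lane_layout [1, 16]): extracting a 16x32
// slice at offset [0, 32] from a 16x64 source (distributed as 16x4 per lane)
// becomes extracting a 16x2 slice at offset [0, 2] from the per-lane vector.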
/// Distribute vector.insert_strided_slice. Inserts along a distributed
/// dimension must be aligned with the lane layout.
struct VectorInsertStridedSliceDistribution
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    auto insertOp =
        cast<vector::InsertStridedSliceOp>(operand->get().getDefiningOp());
    unsigned operandNumber = operand->getOperandNumber();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Find which dimensions of the yielded insert result are distributed.
    auto insertResultType = cast<VectorType>(operand->get().getType());
    auto destDistributedDims =
        getDistributedDims(insertResultType, distributedType);
    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
        insertOp.getOffsets(), [](Attribute attr) { return attr; });
    VectorType updatedSourceType = insertOp.getSourceVectorType();
    VectorType updatedDestType = insertOp.getDestVectorType();
    if (destDistributedDims.size() > 0) {
      // Only a single distributed dimension is supported for now.
      if (destDistributedDims.size() != 1)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Expecting source to be distributed in a single dimension.");
      int64_t destDistributedDim = destDistributedDims[0];
      VectorType srcType = insertOp.getSourceVectorType();
      VectorType destType = insertOp.getDestVectorType();
      // The distributed dimension must also exist in the (possibly lower
      // rank) source vector.
      int64_t sourceDistributedDim =
          destDistributedDim - (destType.getRank() - srcType.getRank());
      if (sourceDistributedDim < 0)
        return rewriter.notifyMatchFailure(
            warpOp, "distributed dimension must be in the last k (i.e. source "
                    "rank) dims of dest vector");
      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
      auto sourceLayout =
          xegpu::getDistributeLayoutAttr(insertOp.getValueToStore());
      auto destLayout = xegpu::getDistributeLayoutAttr(insertOp.getDest());
      if (!destLayout || !sourceLayout ||
          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
        return rewriter.notifyMatchFailure(
            warpOp, "the source or dest of insert_strided_slice op lacks "
                    "distribution layout");
      int subgroupSize =
          destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
      // Only unit lane data is supported for now.
      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
      if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
          !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
        return rewriter.notifyMatchFailure(
            warpOp, "Expecting unit lane data in source and dest layouts");
      if (srcDistrDimSize % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Distributed dimension size in source is not a multiple of "
            "subgroup size.");
      // The insert offset along the distributed dim must be lane-aligned.
      int64_t destDistrDimOffset =
          cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
      if (destDistrDimOffset % subgroupSize != 0)
        return rewriter.notifyMatchFailure(
            warpOp,
            "Offset along distributed dimension in dest is not a multiple of "
            "subgroup size.");
      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
                              sourceLayout, insertOp.getSourceVectorType())
                              .value();
      updatedDestType = getDistVecTypeBasedOnLaneLayout(
                            destLayout, insertOp.getDestVectorType())
                            .value();
      // Adjust the offset along the distributed dim to the per-lane value.
      updatedOffsets[destDistributedDim] =
          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
    }
    // Yield the source and dest from the warp op, recreate the insert outside.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {updatedSourceType, updatedDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
    Value dest = newWarpOp.getResult(newRetIndices[1]);
    Value newInsertOp = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newInsertOp);
    return success();
  }
};
/// Sink memref.extract_aligned_pointer_as_index out of the warp op (the
/// pointer is uniform across lanes).
struct MemrefExtractAlignedPointerAsIndexDistribution final
    : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp,
          "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
    auto extractOp =
        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
    unsigned operandIdx = operand->getOperandNumber();
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, extractOp.getSource(),
        TypeRange{extractOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, newWarpOp.getLoc(), extractOp.getType(),
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
    return success();
  }
};
/// Distribute vector.bitcast yielded by the warp op.
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::BitCast op");
    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedSourceType =
        getDistVecTypeBasedOnLaneLayout(
            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
            bitcastOp.getSourceVectorType())
            .value_or(VectorType());
    if (!distributedSourceType)
      return rewriter.notifyMatchFailure(
          bitcastOp, "Failed to distribute the source vector type in "
                     "vector::BitCast op");
    VectorType distributedResultType =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, bitcastOp.getSource(),
        TypeRange{distributedSourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newBitcastOp = vector::BitCastOp::create(
        rewriter, newWarpOp.getLoc(), distributedResultType,
        newWarpOp.getResult(newRetIndices[0]));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
    return success();
  }
};
/// Distribute vector.transpose yielded by the warp op. Only 2D transposes
/// whose result layout is the transpose of the source layout are supported.
struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector::Transpose op");
    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
    unsigned operandIdx = operand->getOperandNumber();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
    if (!sourceLayout || !resultLayout)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector of the transpose op lacks layout "
          "attribute");
    int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
    int64_t resultRank = transposeOp.getResultVectorType().getRank();
    if (sourceRank != 2 || resultRank != 2)
      return rewriter.notifyMatchFailure(
          transposeOp, "the source or result vector of the transpose op "
                       "does not have 2D layout");
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    // The result layout must be the permutation of the source layout.
    if (!resultLayout.isTransposeOf(sourceLayout, perm))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "the source or result vector layouts must be 2D transposes of each "
          "other");
    FailureOr<VectorType> distributedSourceTypeOrFailure =
        getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                        transposeOp.getSourceVectorType());
    if (failed(distributedSourceTypeOrFailure))
      return rewriter.notifyMatchFailure(
          transposeOp, "Failed to distribute the source vector type in "
                       "vector::Transpose op");
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, transposeOp.getVector(),
        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newTransposeOp = vector::TransposeOp::create(
        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
        perm);
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
    return success();
  }
};
/// The pass driver.
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               GpuBarrierDistribution, VectorMultiReductionDistribution,
               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
               VectorBitcastDistribution, LoadMatrixDistribution,
               StoreMatrixDistribution,
               MemrefExtractAlignedPointerAsIndexDistribution>(
      patterns.getContext(), regularPatternBenefit);
  // These vector patterns must win over their generic upstream counterparts,
  // hence the higher benefit.
  patterns
      .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
          patterns.getContext(), highPatternBenefit);
}

void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: Attach layouts to vector-typed operands (load/store matrix ops
  // carry their layout on the op itself and are skipped here).
  getOperation()->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      if (!isa<VectorType>(operand.get().getType()))
        continue;
      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
        continue;
      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
      if (!layout) {
        op->emitError("Could not find layout attribute for operand ")
            << operand.getOperandNumber() << " of operation " << op->getName();
        signalPassFailure();
        return;
      }
      xegpu::setDistributeLayoutAttr(operand, layout);
    }
  });
  // Step 2: Move each gpu.func body inside a gpu.warp_execute_on_lane_0 op.
  {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
  // Move any scalar uniform code outside of the warp ops.
  getOperation()->walk([&](Operation *op) {
    if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
      vector::moveScalarUniformCode(warpOp);
  });
  // Step 3: Apply subgroup-to-workitem distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // distributionFn maps a vector value to the affine map describing which of
  // its dimensions are distributed, based on the attached lane layout.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (!vecRank)
      return AffineMap::get(val.getContext());
    // Without a layout, assume the innermost dimension is distributed.
    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
    if (!layout)
      return AffineMap::getMultiDimMapWithTargets(
          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
    assert(layout.getRank() == vecRank &&
           "Expecting vector and layout rank to match");
    // A dimension is distributed if its lane count is > 1 and divides the
    // dimension size evenly.
    SmallVector<unsigned int> distributedDims;
    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
      if (v > 1 && vecType.getShape()[i] % v == 0)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // Lane shuffles are not needed by these patterns.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };
  // Subgroup reduction: reduce within the lane first, then combine across
  // lanes with butterfly (XOR) shuffles.
  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
                          vector::CombiningKind kind, uint32_t size) {
    Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
    for (uint64_t i = 1; i < size; i <<= 1) {
      Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
                                              /*width=*/size,
                                              /*mode=*/gpu::ShuffleMode::XOR)
                           .getShuffleResult();
      laneVal =
          vector::makeArithReduction(builder, loc, kind, laneVal, shuffled);
    }
    return laneVal;
  };
  vector::populateDistributeReduction(patterns, warpReduction,
                                      regularPatternBenefit);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn, regularPatternBenefit);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
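  // Illustrative trace of warpReduction above for size = 16: the loop
  // shuffles with XOR offsets 1, 2, 4 and 8; after four rounds every lane
  // holds the reduction of all 16 lane values.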
  // Step 4: Clean up the unrealized casts inserted to reconcile SIMT type
  // mismatches, but only when distribution fully succeeded, i.e. no live
  // gpu.warp_execute_on_lane_0 ops remain.
  bool foundWarpOp = false;
  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
    // Look for warp ops that are not trivially dead.
    if (isOpTriviallyDead(warpOp))
      return WalkResult::advance();
    foundWarpOp = true;
    return WalkResult::interrupt();
  });
  if (foundWarpOp)
    return;
  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // Only consider the casts inserted by resolveDistributedTy.
    if (!op->hasAttr(resolveSIMTTypeMismatch))
      return WalkResult::skip();
    Value input = op.getOperand(0);
    Value output = op.getResult(0);
    // Both sides must be tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");
    // Case 1: tensor_desc<..., layout> -> tensor_desc<...>. This occurs for
    // loop block arguments whose type must be rewritten to the SIMT type.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }
    // Case 2: tensor_desc<...> -> tensor_desc<..., layout>, e.g. at loop
    // yields going back to the SIMD type; simply forward the input.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);
    if (op->use_empty())
      op->erase();
    return WalkResult::advance();
  });
}