#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  // ...
  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
          parent->getAttr("sg_id_range")))
  // ...
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  // ...
  if (layout && layout.isForWorkgroup()) {
    // ...
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    // ...
      sgShape = *maybeDerivedSgData;
    // ...
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    // ...
  }
  return std::make_pair(sgShape, count);
}
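// Worked example (shapes assumed for illustration): for a workgroup shape of
// 256x128 with sg_layout = [8, 4] and sg_data = [32, 32], the distribution
// unit (sg_layout * sg_data, clamped to the shape) is 256x128, so each
// subgroup owns a single 32x32 tile and `count` is 1. For a 256x256 workgroup
// shape with the same layout, the distribution unit stays 256x128 and `count`
// becomes 2: every subgroup handles two 32x32 tiles in round-robin fashion.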
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  // ...
  if (origOffsets.empty())
  // ...
  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
  if (!layout || !layout.isForWorkgroup())
  // ...
  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
  // ...
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  // ...
  int64_t startOfRange = sgIdRange.getStart().getInt();
  int64_t endOfRange = sgIdRange.getEnd().getInt();
  // ...
  if (layout.getNumSubgroups() != endOfRange - startOfRange)
    return rewriter.notifyMatchFailure(
        op, "sg_layout size must match the sg_id_range");
  // ...
  if (startOfRange > 0) {
    Value startOfRangeVal =
        arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
    sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
  }
  // ...
  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
  // ...
  for (const auto &sgOffsets : *maybeDescOffsets) {
    // ...
    offsetsList.push_back(std::move(newOffsets));
  }
  // ...
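// genOffsetsList (above) computes the absolute offsets of every tile a
// subgroup owns: it materializes gpu.subgroup_id, rebases it when an enclosing
// scf.if carries an "sg_id_range" attribute, asks the layout for the
// per-subgroup relative offsets, and adds those (right-aligned) to the op's
// original workgroup-level offsets. The resulting offsetsList has one entry
// per subgroup-level op that will replace the original op.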
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
    // ...
    xegpu::TensorDescType tdescTy = op.getType();
    // ...
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    // ...
    xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(
        /*...*/, layout.dropSgLayoutAndData());
    // ...
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    // ...
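// Illustrative effect of WgToSgCreateNdOp (shapes and syntax abbreviated, not
// taken from an actual test):
//   #wg = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
//   %t = xegpu.create_nd_tdesc %src[%off0, %off1]
//        : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #wg>
// becomes one descriptor per subgroup tile, with the subgroup's offsets folded
// in and sg_layout/sg_data dropped from the layout:
//   %t_sg = xegpu.create_nd_tdesc %src[%sgoff0, %sgoff1]
//           : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>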
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!op.getMixedOffsets().empty())
    // ...
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
    // ...
    Type elemTy = tdescTy.getElementType();
    // ...
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(/*...*/, layout.dropSgLayoutAndData());
    // ...
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    // ...
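// In the no-offset variant above, every subgroup receives `count` identical
// descriptors of the subgroup tile shape; the per-subgroup offsets are instead
// supplied later by the consuming load/store/prefetch ops (the *WithOffset
// patterns below).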
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
    // ...
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      // ...
      VectorType newResTy =
          VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    // ...
    return mlir::success();
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
    // ...
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());
    // ...
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
    // ...
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          nullptr, nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }
    // ...
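// The *WithOffset patterns handle the form where offsets live on the
// load/store/prefetch op itself rather than on the descriptor: each new
// subgroup-level op reuses the (offset-less) distributed descriptor and gets
// its own per-tile offsets from genOffsetsList.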
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
    // ...
    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr());
    }
    // ...
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
    // ...
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr());
    }
    // ...
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    // ...
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
    // ...
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        // ...
        tmpC = adaptor.getAcc()[i++];
        operands.push_back(tmpC);
        // ...
        auto aVecShape = llvm::cast<VectorType>(aVec.getType()).getShape();
        auto bVecShape = llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
                                       originalLayout.dropSgLayoutAndData());
        newDpasOps.push_back(tmpC);
      }
    }
    // ...
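// Each subgroup owns tiles of A (MxK) and B (KxN); the nested loops pair every
// distributed A tile with every distributed B tile, thread the accumulator
// through tmpC, and emit one subgroup-level dpas per pair whose result carries
// the original layout with sg_layout/sg_data dropped.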
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
    // ...
    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());
    // ...
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    VectorType resultType = op.getResult().getType();
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
    // ...
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    // ...
    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
    // ...
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
                                       layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    // ...
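// The next fragment is from WgToSgElementwiseOp, a ConversionPattern matching
// any op with elementwise-mappable traits: it rebuilds the op once per
// distributed operand tuple through an OperationState, copying all attributes
// and dropping sg_layout/sg_data from any layout attribute it finds.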
    // ...
    assert(resultType && "Expected result to be a VectorType");
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
    // ...
    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    // ...
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
    // ...
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    // ...
    for (size_t i = 0; i < numVariants; ++i) {
      // ...
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);
      // ...
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // ...
      for (NamedAttribute attr : op->getAttrs()) {
        if (auto layout =
                dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
          if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
              !layout.getEffectiveInstDataAsInt().empty())
            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        } else {
          state.addAttribute(attr.getName(), attr.getValue());
        }
      }
      // ...
      newResults.push_back(newOp->getResult(0));
    }
    // ...
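// Illustrative effect (assumed shapes): an arith.addf on vector<256x128xf32>
// whose result layout has sg_layout = [8, 4] and sg_data = [32, 32] is rebuilt
// as one arith.addf on vector<32x32xf32> per distributed operand tuple, since
// each subgroup owns exactly one tile.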
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
    // ...
    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");
    // ...
    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
    // ...
    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();
    // ...
    if (input && target) {
      // ...
      auto newOp = xegpu::ConvertLayoutOp::create(
          rewriter, op.getLoc(), src.getType(), src, input, target);
      // ...
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  // ...
  LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
    // ...
    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
    // ...
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      // ...
    }
    // ...
    if (op.getNumResults() == 1 &&
    // ...
    return mlir::failure();
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  // ...
  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
    // ...
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    // ...
    auto eltType = vecType.getElementType();
    // ...
    auto setLayoutIfNeeded = [&](Value val) {
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty()) {
        // ...
            layout.dropSgLayoutAndData());
      }
    };
    // ...
    if (vecAttr.isSplat()) {
      // ...
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      setLayoutIfNeeded(cstOp->getResult(0));
      // ...
    } else if (sgShape == wgShape) {
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      setLayoutIfNeeded(newConstOp->getResult(0));
      // ...
    // ...
    if (!eltType.isIndex())
      return rewriter.notifyMatchFailure(
          op, "Unsupported element type for non-splat constant op.");
    if (wgShape.size() > 2)
      return rewriter.notifyMatchFailure(
          op, "Only 1D & 2D vector constant supported");
    // ...
    int64_t rowStride = 0, colStride = 0;
    int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
    int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
    // ...
    colStride = cast<IntegerAttr>(values[1]).getInt() -
                cast<IntegerAttr>(values[0]).getInt();
    // ...
    rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                cast<IntegerAttr>(values[0]).getInt();
    // ...
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        int64_t idx = r * cols + c;
        // ...
        if (c > 0 && cols > 1) {
          int64_t prevIdx = r * cols + (c - 1);
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != colStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant column stride in constant op.");
        }
        // ...
        if (r > 0 && rows > 1) {
          int64_t prevIdx = (r - 1) * cols + c;
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != rowStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant row stride in constant op.");
        }
      }
    }
    // ...
    int baseTileCols = sgShape[sgShape.size() - 1];
    int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
    for (int64_t r = 0; r < baseTileRows; ++r) {
      for (int64_t c = 0; c < baseTileCols; ++c) {
        baseTileValues.push_back(values[r * cols + c]);
      }
    }
    // ...
    auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
    // ...
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
    // ...
    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
    // ...
    strideConsts.push_back(
        /*...*/);
    // ...
        strideConsts.begin(),
    // ...
    for (auto offsets : *sgOffsets) {
      // ...
      for (size_t i = 0; i < strideConsts.size(); ++i) {
        Value mul =
            arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                  offsets[i], strideConsts[i]);
        mulOffset = arith::AddIOp::create(
            /*...*/);
      }
      auto bcastOffset = vector::BroadcastOp::create(
          rewriter, loc, baseConstVec.getType(), mulOffset);
      auto finalConst =
          arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
      setLayoutIfNeeded(baseConstVec);
      setLayoutIfNeeded(bcastOffset);
      setLayoutIfNeeded(finalConst);
      newConstOps.push_back(finalConst);
    }
    // ...
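// Illustrative effect for a non-splat constant (values assumed): the
// workgroup-level op
//   arith.constant dense<[0, 16, 32, ..., 1008]> : vector<64xindex>
// with sg_layout = [4] and sg_data = [16] (column stride 16) is rebuilt per
// subgroup as the base tile dense<[0, 16, ..., 240]> : vector<16xindex> plus a
// broadcast of elem_offset * 16, where elem_offset comes from
// layout.getOffsets; subgroup 1, for instance, ends up with [256, ..., 496].
// Only index-typed constants with uniform row/column strides reach this point;
// everything else bails out in the checks above.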
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!op.getOffsets())
    // ...
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
    // ...
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }
    // ...
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newLoadOps.push_back(newLoadOp);
    }
    // ...
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!op.getOffsets())
    // ...
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    // ...
    xegpu::DistributeLayoutAttr layout = /*...*/;
    if (!layout || !layout.isForWorkgroup())
    // ...
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }
    // ...
    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    // ...
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      auto store = xegpu::StoreScatterOp::create(
          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
      // ...
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty()) {
        for (OpOperand &operand : store->getOpOperands()) {
          // ...
          if (operand.getOperandNumber() == 1)
          // ...
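// For scattered accesses the offsets and mask are vectors that must already
// have been distributed by the vector patterns; the gather/scatter patterns
// therefore only verify that the per-subgroup offset and mask shapes agree and
// then re-emit one op per distributed (value, offsets, mask) tuple, preserving
// the chunk size and the cache hints.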
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
    // ...
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");
    Type elemTy = valueTy.getElementType();
    // ...
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    // ...
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    // ...
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
    // ...
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    // ...
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
    // ...
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    // ...
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
    // ...
    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    // ...
    for (auto offsets : *sgOffsets) {
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty()) {
        xegpu::setDistributeLayoutAttr(steps->getResult(0),
                                       layout.dropSgLayoutAndData());
        xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
                                       layout.dropSgLayoutAndData());
        xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
                                       layout.dropSgLayoutAndData());
      }
      newOps.push_back(finalSteps);
    }
    // ...
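// Illustrative effect (assumed sizes): vector.step : vector<256xindex> with
// sg_layout = [8] and sg_data = [32] becomes, in every subgroup,
//   %s   = vector.step : vector<32xindex>
//   %off = vector.broadcast %sg_elem_offset : index to vector<32xindex>
//   %r   = arith.addi %s, %off : vector<32xindex>
// so subgroup k produces the index values 32*k .. 32*k + 31.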
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
    // ...
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    // ...
    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
    // ...
    // Only shape casts that merely add or drop unit dimensions are supported.
    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
      SmallVector<int64_t> srcNonUnit, dstNonUnit;
      for (int64_t d : src)
        if (d != 1)
          srcNonUnit.push_back(d);
      for (int64_t d : dst)
        if (d != 1)
          dstNonUnit.push_back(d);
      return srcNonUnit == dstNonUnit;
    };
    if (!onlyUnitDims(srcType.getShape(), sgShape))
    // ...
    int64_t sourceRank = srcType.getRank();
    int64_t resultRank = sgShape.size();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getSource());
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
    // ...
    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
    // ...
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
                                       layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }
    // ...
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    // ...
    auto srcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
    // ...
    auto reductionDims = llvm::to_vector(op.getReductionDims());
    SmallVector<int64_t> sgLayout =
        // ...
            .getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> sgData =
        // ...
            .getEffectiveSgDataAsInt();
    // ...
    for (int64_t dim : reductionDims) {
      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
        return rewriter.notifyMatchFailure(
            op,
            "sgLayout in each reduced dimension must be 1 and sgData in the "
            "reduced dim must match srcShape in that dim");
    }
    // ...
    VectorType newDstType = /*...*/;
    // ...
    for (auto sgSrc : adaptor.getSource()) {
      auto newOp = vector::MultiDimReductionOp::create(
          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
          adaptor.getAcc()[0], op.getReductionDims());
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                       layout.dropSgLayoutAndData());
      newReductions.push_back(newOp.getResult());
    }
    // ...
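// The check above guarantees a reduced dimension is never split across
// subgroups, so each subgroup reduces its full slice locally and no
// cross-subgroup combine step is needed. For example (assumed layout), a
// reduction over dim 1 of vector<32x64xf32> with sg_layout = [8, 1] and
// sg_data = [4, 64] yields one vector.multi_reduction per subgroup producing
// vector<4xf32>.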
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp>(patterns.getContext());
}
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
void XeGPUWgToSgDistributePass::runOnOperation() {
  // ...
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });
  // ...
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        // ...
        std::tie(subShape, count) = getSgShapeAndCount(
            type.getShape(),
            dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
        // ...
        result.append(count, newTy);
        // ...
      });
  // ...
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        // ...
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
        // ...
        layout = layout.dropSgLayoutAndData();
        // ...
        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        // ...
      });
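// The type converter expresses the same 1:N mapping at the type level: a
// workgroup-level tensor or tensor_desc type is converted into `count` copies
// of the corresponding subgroup-tile type (with sg_layout/sg_data removed),
// which is what lets the SCF structural type-conversion patterns update
// scf.for/scf.if region and block-argument types consistently.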
  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    // ...
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
      [=](Operation *op) -> bool {
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::BroadcastOp>(
      [=](vector::BroadcastOp op) -> bool {
        // ...
      });

  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [=](vector::MultiDimReductionOp op) -> bool {
        // ...
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // ...
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        // ...
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            // ...
          }
        }
        // ...
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  // ...
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    return signalPassFailure();

  // ...
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}