#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
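// XeGPU workgroup-to-subgroup distribution: each conversion pattern below
// splits a workgroup-level op (whose layout carries sg_layout / sg_data) into
// one op per subgroup tile and drops those subgroup fields from the layouts
// attached to the newly created ops.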
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  // ...
  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
          parent->getAttr("sg_id_range")))
    return attr;
  // ...
}
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  // ...
  if (layout && layout.isForWorkgroup()) {
    // ...
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    // ... (else: derived from the shape / sg_layout ratio)
      sgShape = *maybeDerivedSgData;
    // ...
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    // ...
  }
  return std::make_pair(sgShape, count);
}
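// Illustrative example (not from the source): for a workgroup shape of
// 256x128 with sg_layout = [8, 4] and sg_data = [32, 32], the per-subgroup
// shape is [32, 32]; when sg_data is absent it appears to be derived as the
// element-wise ratio shape / sg_layout, i.e. [256/8, 128/4] = [32, 32].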
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp, xegpu::LoadMatrixOp,
              xegpu::StoreMatrixOp>::value>>
// Inside genOffsetsList(rewriter, op, offsetsList):
  if (origOffsets.empty())
    return failure();
  // ...
  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
  if (!layout || !layout.isForWorkgroup())
    return failure();
  // ...
  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
  // ...
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  // ... (when a range is present:)
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    // ...
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    // ...
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  // ...
  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();
  // ...
  for (const auto &sgOffsets : *maybeDescOffsets) {
    // ... (combine sgOffsets with origOffsets into newOffsets)
    offsetsList.push_back(std::move(newOffsets));
  }
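// genOffsetsList: the layout maps the linear subgroup id to one or more tile
// coordinates via layout.getOffsets(), and each coordinate set is combined
// with the op's original offsets to form the per-subgroup offset lists
// consumed by the *WithOffset patterns below.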
  // WgToSgCreateNdOp
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();
    // ...
    xegpu::TensorDescType tdescTy = op.getType();
    // ...
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    // ...
    xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(
        /*...*/, layout.dropSgLayoutAndData());
    // ...
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      // ...
      newOps.push_back(newOp);
    }
    // ...
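// WgToSgCreateNdOp therefore emits one xegpu.create_nd_tdesc per offset set;
// the clones share the original source, sizes, and strides and differ only in
// their offsets and in the (now subgroup-level) tensor descriptor type.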
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!op.getMixedOffsets().empty())
      return failure();
    // ...
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    Type elemTy = tdescTy.getElementType();
    // ...
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(
        /*...*/, layout.dropSgLayoutAndData());
    // ...
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    // ...
  // WgToSgLoadNdOp
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();
    // ...
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      // ...
      VectorType newResTy =
          VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    // ...
    return mlir::success();
  // WgToSgStoreNdOp
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();
    // ...
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());
    // ...
  // WgToSgLoadNdOpWithOffset
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();
    // ...
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          nullptr, nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }
    // ...
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();
    // ...
    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr());
    }
    // ...
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();
    // ...
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr());
    }
    // ...
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    // ...
  // WgToSgDpasOp
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();
    // ...
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        // ...
          tmpC = adaptor.getAcc()[i++];
        // ...
        operands.push_back(tmpC);
        // ... (tile shapes: llvm::cast<VectorType>(aVec.getType()).getShape(),
        //      llvm::cast<VectorType>(bVec.getType()).getShape())
        VectorType resTy = VectorType::get(/*...*/,
                                           resultTy.getElementType());
        tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        // ... (attach originalLayout.dropSgLayoutAndData() to the new op)
        newDpasOps.push_back(tmpC);
      }
    }
    // ...
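// WgToSgDpasOp: every (A tile, B tile) pair yields a per-subgroup xegpu.dpas.
// tmpC threads the running accumulator between them, seeded from the
// adaptor's acc operands when present (per the elided guard above).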
  // WgToSgPrefetchNdOp
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();
    // ...
    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());
    // ...
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    // ...
    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();
    // ...
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newBroadcast.getResult(),
                                       layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    // ...
    // WgToSgElementwiseOp (excerpt from matchAndRewrite):
    assert(resultType && "Expected result to be a VectorType");
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    // ...
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();
    // ...
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    // ...
    for (size_t i = 0; i < numVariants; ++i) {
      // ...
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);
      // ...
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // ...
      if (auto layout =
              dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
        if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
            !layout.getEffectiveInstDataAsInt().empty())
          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        else
          state.addAttribute(attr.getName(), attr.getValue());
      }
      // ...
      newResults.push_back(newOp->getResult(0));
    }
    // ...
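// WgToSgElementwiseOp rebuilds the elementwise op once per distributed
// operand slice using OperationState: the i-th value of every operand range
// is gathered, the per-subgroup result type is attached, and layout-bearing
// attributes are copied with their sg_layout / sg_data dropped.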
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
    // ...
    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");
    // ...
    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
      return failure();
    // ...
    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();
    // ...
    if (input && target) {
      // ...
      auto newOp = xegpu::ConvertLayoutOp::create(
          rewriter, op.getLoc(), src.getType(), src, input, target);
      // ...
    }
    // ...
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  // ...
  LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
    // ...
    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();
    // ...
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      // ...
    }
    // ...
    if (op.getNumResults() == 1 /* && ... */) {
      // ...
    }
    return mlir::failure();
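// This cast pattern only touches unrealized_conversion_cast ops whose inputs
// and results are vectors of a single matching type; everything else falls
// through to `return mlir::failure()` and is governed by the legality rule
// registered for UnrealizedConversionCastOp in runOnOperation() below.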
  // WgToSgArithConstantOp
  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecAttr.isSplat() || !vecType)
      return failure();
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    // ...
        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
        !layout.getEffectiveInstDataAsInt().empty())
      // ... (attach layout.dropSgLayoutAndData() to the new constant)
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!op.getOffsets())
      return failure();
    // ...
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }
    // ...
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
      xegpu::setDistributeLayoutAttr(newLoadOp.getResult(),
                                     layout.dropSgLayoutAndData());
      newLoadOps.push_back(newLoadOp);
    }
    // ...
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  // ...
  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!op.getOffsets())
      return failure();
    // ...
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    // ...
    xegpu::DistributeLayoutAttr layout = /*...*/;
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }
    // ...
    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    // ...
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
                                    mask, chunkSizeAttr, op.getL1HintAttr(),
                                    op.getL2HintAttr(), op.getL3HintAttr());
      // ...
      if (auto newLayout = layout.dropSgLayoutAndData())
        op->setAttr("layout", newLayout);
    }
    // ...
  // WgToSgLoadMatrixOp
  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();
    // ...
    VectorType valueTy = op.getRes().getType();
    Type elemTy = valueTy.getElementType();
    // ...
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    // ...
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    // ...
  // WgToSgStoreMatrixOp
  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();
    // ...
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    // ...
  // WgToSgVectorStepOp
  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    // ...
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
    // ...
    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    // ...
    for (auto offsets : *sgOffsets) {
      // ...
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty()) {
        // ... (attach layout.dropSgLayoutAndData() to the new ops, presumably
        //      steps, bcastOffset, and finalSteps)
      }
      newOps.push_back(finalSteps);
    }
    // ...
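// WgToSgVectorStepOp: each per-subgroup result is a vector.step of the
// subgroup shape plus a broadcast of that subgroup's starting offset, so the
// distributed vectors still enumerate the original workgroup-level indices.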
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    // ...
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    // ...
    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
    // ...
    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
      // ... (collect the non-unit dims of src and dst)
      for (int64_t d : src)
        if (d != 1)
          srcNonUnit.push_back(d);
      for (int64_t d : dst)
        if (d != 1)
          dstNonUnit.push_back(d);
      return srcNonUnit == dstNonUnit;
    };
    if (!onlyUnitDims(srcType.getShape(), sgShape))
      return failure();
    // ...
    int64_t sourceRank = srcType.getRank();
    int64_t resultRank = sgShape.size();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getSource());
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
      return failure();
    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
      return failure();
    // ...
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = rewriter.create<vector::ShapeCastOp>(
          op.getLoc(), newResultType, src);
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newShapeCast.getResult(),
                                       layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }
    // ...
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  // ...
  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    // ...
    auto srcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();
    // ...
    auto reductionDims = llvm::to_vector(op.getReductionDims());
    // ...
    auto sgLayout = layout.getEffectiveSgLayoutAsInt();
    auto sgData = layout.getEffectiveSgDataAsInt();
    // ...
    for (int64_t dim : reductionDims) {
      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
        return rewriter.notifyMatchFailure(
            op,
            "sgLayout in each reduced dimension must be 1 and sgData in the "
            "reduced dim must match srcShape in that dim");
    }
    // ...
    VectorType newDstType = /*...*/;
    // ...
    for (auto sgSrc : adaptor.getSource()) {
      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
          op.getReductionDims());
      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
          !layout.getEffectiveInstDataAsInt().empty())
        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                       layout.dropSgLayoutAndData());
      newReductions.push_back(newOp.getResult());
    }
    // ...
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp>(patterns.getContext());
}
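// The pass wrapping these patterns is defined next. Judging from
// GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE above, it is likely exposed as
// `-xegpu-wg-to-sg-distribute` in mlir-opt (flag name assumed, not taken from
// this listing).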
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
void XeGPUWgToSgDistributePass::runOnOperation() {
  // ...
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });
  converter.addConversion([&](RankedTensorType type,
                              SmallVectorImpl<Type> &result)
                              -> std::optional<LogicalResult> {
    Type elemTy = type.getElementType();
    // ...
    std::tie(subShape, count) = getSgShapeAndCount(
        type.getShape(),
        dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
    // ...
    result.append(count, newTy);
  });
  converter.addConversion([&](xegpu::TensorDescType type,
                              SmallVectorImpl<Type> &result)
                              -> std::optional<LogicalResult> {
    Type elemTy = type.getElementType();
    // ...
    xegpu::LayoutAttr layout = type.getLayoutAttr();
    std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
    // ...
    layout = layout.dropSgLayoutAndData();
    // ...
    auto newTy = xegpu::TensorDescType::get(
        type.getContext(), subShape, elemTy, type.getEncoding(), layout);
    result.append(count, newTy);
  });
  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };
  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });
  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    // ...
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool { return isLegal(op.getLayoutAttr()); });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool { return isLegal(op.getLayoutAttr()); });
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
      [=](Operation *op) -> bool {
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        // ...
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        // ...
        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
        // ...
        return isLegal(layout);
      });
  target.addDynamicallyLegalOp<vector::BroadcastOp>(
      [=](vector::BroadcastOp op) -> bool {
        // ...
      });

  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
      [=](vector::MultiDimReductionOp op) -> bool {
        // ...
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // ...
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        // ...
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            // ...
          }
        }
        // ...
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });
  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  // ...
    return signalPassFailure();
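// The final walk below strips the per-result layout attributes that guided
// the conversion: on SCF control-flow ops the attribute is removed entirely,
// while on other ops it is re-attached only if dropSgLayoutAndData() leaves a
// lane/inst-level layout behind.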
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}