29#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Finds the nearest enclosing "sg_id_range" attribute (an xegpu::RangeAttr)
// for `op`. NOTE(review): the lines binding `parent` and the loop/return are
// missing from this excerpt (original lines 40-49) — presumably this walks up
// the parent-op chain; confirm against the full file.
39static xegpu::RangeAttr getRangeSpecAttr(
Operation *op) {
42 if (
// dyn_cast_if_present tolerates a null attribute from getAttr.
auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
43 parent->
getAttr(
"sg_id_range")))
// Returns the per-subgroup shape and the number of values each subgroup
// handles, derived from `layout`. Falls back to the pre-initialized
// (sgShape, count) pair when the layout cannot compute a distributed shape.
// NOTE(review): the declarations of sgShape/count and the wgShape parameter
// are on lines missing from this excerpt — verify against the full file.
50static std::pair<SmallVector<int64_t>,
int>
52 xegpu::DistributeLayoutAttr layout) {
55 auto distributedShape = layout.computeDistributedShape(
57 if (
failed(distributedShape))
// Distribution failed: return the defaults unchanged.
58 return std::make_pair(sgShape, count);
// Effective sg_data wins when a distributed shape was computable.
59 auto sgData = layout.getEffectiveSgDataAsInt();
61 return std::make_pair(sgData, count);
// Computes, for each subgroup, the list of offsets at which the distributed
// (subgroup-level) op should access data. Shared by the create-nd-desc,
// load/store/prefetch-nd, and load/store-matrix patterns below (constrained
// via enable_if to exactly those op types).
70 typename = std::enable_if_t<llvm::is_one_of<
71 OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
72 xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
74genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// No offsets on the op: nothing to distribute here.
79 if (origOffsets.empty())
83 xegpu::DistributeLayoutAttr layout;
// Matrix ops carry the layout directly; nd ops carry it on the descriptor.
84 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
85 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
86 layout = op.getLayoutAttr();
88 layout = op.getDescLayoutAttr();
// Only workgroup-scoped layouts are distributed by this pass.
92 if (!layout || !layout.isForWorkgroup())
// Query the linear subgroup id (no upper-bound operand).
96 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
// Honor an optional sg_id_range restriction placed by an enclosing op.
99 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
101 int64_t startOfRange = sgIdRange.getStart().getInt();
102 int64_t endOfRange = sgIdRange.getEnd().getInt();
// The layout must cover exactly the subgroups in the range.
104 if (layout.getNumSubgroups() != endOfRange - startOfRange)
105 return rewriter.notifyMatchFailure(
106 op,
"sg_layout size must match the sg_id_range");
// Re-base the subgroup id so it is zero-based within the range.
108 if (startOfRange > 0) {
109 Value startOfRangeVal =
111 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
// Ask the layout for each subgroup's coordinates in the wg-shaped space.
118 auto maybeDescOffsets =
119 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
120 if (
failed(maybeDescOffsets))
// One offsets vector per round-robin assignment of this subgroup.
125 for (
const auto &sgOffsets : *maybeDescOffsets) {
128 offsetsList.push_back(std::move(newOffsets));
// Rewrites a workgroup-level CreateNdDescOp (with offsets) into one
// subgroup-level CreateNdDescOp per offsets entry produced by
// genOffsetsList, dropping the sg_layout/sg_data from the descriptor type.
180struct WgToSgCreateNdOp :
public OpConversionPattern<xegpu::CreateNdDescOp> {
181 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
184 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
186 SmallVector<SmallVector<OpFoldResult>> offsetsList;
187 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
190 MLIRContext *ctx = op.getContext();
191 xegpu::TensorDescType tdescTy = op.getType();
192 ArrayRef<int64_t> wgShape = tdescTy.getShape();
193 Type elemTy = tdescTy.getElementType();
194 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
// Shrink the descriptor shape to the per-subgroup tile.
195 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
197 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
198 layout.dropSgLayoutAndData());
200 SmallVector<Value> newOps;
// One new descriptor per computed subgroup offset set.
201 for (
auto offsets : offsetsList) {
202 auto newOp = xegpu::CreateNdDescOp::create(
203 rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
204 op.getMixedSizes(), op.getMixedStrides());
206 newOps.push_back(newOp);
// 1:N replacement: all subgroup descriptors replace the single wg op.
208 rewriter.replaceOpWithMultiple(op, {newOps});
// Variant of WgToSgCreateNdOp for descriptors created WITHOUT offsets:
// simply replicates the descriptor `count` times with the subgroup-level
// type (offsets are supplied later by the load/store/prefetch patterns).
216struct WgToSgCreateNdOpNoOffset
217 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
218 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
221 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
222 ConversionPatternRewriter &rewriter)
const override {
// Guard: this pattern only handles the no-offsets form.
225 if (!op.getMixedOffsets().empty())
228 Location loc = op.getLoc();
229 MLIRContext *ctx = op.getContext();
230 xegpu::TensorDescType tdescTy = op.getType();
231 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
232 if (!layout || !layout.isForWorkgroup())
235 Type elemTy = tdescTy.getElementType();
236 ArrayRef<int64_t> wgShape = tdescTy.getShape();
238 SmallVector<int64_t> sgShape;
240 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
241 xegpu::TensorDescType newTdescTy =
242 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
243 layout.dropSgLayoutAndData());
// Create `count` identical subgroup descriptors.
245 SmallVector<Value> newCreateNdOps(count);
246 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
247 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
248 op.getSource(), op.getMixedSizes(),
249 op.getMixedStrides());
252 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
// Rewrites a no-offset workgroup LoadNdOp: issues one subgroup load per
// already-distributed tensor descriptor from the adaptor.
258struct WgToSgLoadNdOp :
public OpConversionPattern<xegpu::LoadNdOp> {
259 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
261 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
// Loads carrying offsets are handled by WgToSgLoadNdOpWithOffset.
263 if (!op.getMixedOffsets().empty())
266 SmallVector<Value> newLoadOps;
267 for (
auto src : adaptor.getTensorDesc()) {
268 xegpu::TensorDescType tdescTy =
269 dyn_cast<xegpu::TensorDescType>(src.getType());
270 ArrayRef<int64_t> srcShape = tdescTy.getShape();
// Result vector shape follows the subgroup descriptor shape.
271 VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
272 auto newLoadOp = xegpu::LoadNdOp::create(
273 rewriter, op.getLoc(), newResTy, src,
275 newLoadOps.push_back(newLoadOp);
277 rewriter.replaceOpWithMultiple(op, {newLoadOps});
278 return mlir::success();
// Rewrites a no-offset workgroup StoreNdOp: one subgroup store per
// (value, tensor-desc) pair, preserving the cache hints.
285struct WgToSgStoreNdOp :
public OpConversionPattern<xegpu::StoreNdOp> {
286 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
288 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter)
const override {
// Stores carrying offsets are handled by WgToSgStoreNdOpWithOffset.
290 if (!op.getMixedOffsets().empty())
293 for (
auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
294 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
295 op.getL2HintAttr(), op.getL3HintAttr());
// Store has no results, so erase rather than replace.
297 rewriter.eraseOp(op);
// Rewrites a workgroup LoadNdOp WITH offsets: pairs each distributed
// descriptor with its per-subgroup offsets from genOffsetsList.
304struct WgToSgLoadNdOpWithOffset :
public OpConversionPattern<xegpu::LoadNdOp> {
305 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
307 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter)
const override {
310 SmallVector<SmallVector<OpFoldResult>> offsetsList;
311 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// Drop sg_layout/sg_data so the new op carries a subgroup-level layout.
314 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
316 layout = layout.dropSgLayoutAndData();
317 SmallVector<Value> newOps;
318 for (
auto [tdesc, offsets] :
319 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
320 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
321 VectorType newResTy =
322 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
// The two nullptrs fill optional operands/attrs before the cache hints.
323 auto newOp = xegpu::LoadNdOp::create(
324 rewriter, op.getLoc(), newResTy, tdesc, offsets,
325 nullptr,
nullptr, op.getL1HintAttr(),
326 op.getL2HintAttr(), op.getL3HintAttr(), layout);
327 newOps.push_back(newOp);
329 rewriter.replaceOpWithMultiple(op, {newOps});
// Rewrites a workgroup StoreNdOp WITH offsets: one subgroup store per
// (value, tensor-desc, offsets) triple.
337struct WgToSgStoreNdOpWithOffset
338 :
public OpConversionPattern<xegpu::StoreNdOp> {
339 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
341 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter)
const override {
343 SmallVector<SmallVector<OpFoldResult>> offsetsList;
344 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// Demote the layout to subgroup level for the rewritten stores.
347 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
349 layout = layout.dropSgLayoutAndData();
350 for (
auto [v, tdesc, offsets] :
351 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
352 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
353 op.getL1HintAttr(), op.getL2HintAttr(),
354 op.getL3HintAttr(), layout);
// No results to forward; the original wg store is simply erased.
356 rewriter.eraseOp(op);
// Rewrites a workgroup PrefetchNdOp WITH offsets: one subgroup prefetch per
// (tensor-desc, offsets) pair, preserving the cache hints.
364struct WgToSgPrefetchNdOpWithOffset
365 :
public OpConversionPattern<xegpu::PrefetchNdOp> {
366 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
368 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter)
const override {
370 SmallVector<SmallVector<OpFoldResult>> offsetsList;
371 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// Demote the layout to subgroup level for the rewritten prefetches.
374 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
376 layout = layout.dropSgLayoutAndData();
377 for (
auto [tdesc, offsets] :
378 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
379 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
380 op.getL1HintAttr(), op.getL2HintAttr(),
381 op.getL3HintAttr(), layout);
383 rewriter.eraseOp(op);
// Rewrites a workgroup UpdateNdOffsetOp: applies the same offset update to
// every distributed subgroup descriptor (the update is uniform across sgs).
392struct WgToSgUpdateNdOffsetOp
393 :
public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
394 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
396 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
397 ConversionPatternRewriter &rewriter)
const override {
398 llvm::SmallVector<Value> newUpdateTileOffsetOps;
399 for (
auto tDesc : adaptor.getTensorDesc()) {
// Result type matches each individual subgroup descriptor.
400 auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
401 rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
402 op.getConstOffsets());
403 newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
406 rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
// Rewrites a workgroup DpasOp into a grid of subgroup DPAS ops: the cross
// product of distributed A fragments and B fragments, with an optional
// accumulator fragment threaded through, and sg layout dropped on results.
412struct WgToSgDpasOp :
public OpConversionPattern<xegpu::DpasOp> {
413 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
415 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
416 ConversionPatternRewriter &rewriter)
const override {
417 Location loc = op.getLoc();
418 VectorType resultTy = op.getResult().getType();
// Only 2-D DPAS results are supported by this distribution.
419 if (resultTy.getRank() != 2)
422 auto layoutCd = op.getLayoutCdAttr();
423 auto layoutA = op.getLayoutAAttr();
424 auto layoutB = op.getLayoutBAttr();
425 if (!layoutCd || !layoutA || !layoutB)
428 SmallVector<Value> newDpasOps;
// Nested loop over A and B fragments forms the subgroup tile grid.
429 for (
auto aVec : adaptor.getLhs()) {
430 for (
auto bVec : adaptor.getRhs()) {
432 llvm::SmallVector<Value> operands({aVec, bVec});
// Accumulator fragment (if present) is consumed in order via `i`.
435 tmpC = adaptor.getAcc()[i++];
436 operands.push_back(tmpC);
// Result tile shape: rows of A fragment x cols of B fragment.
439 ArrayRef<int64_t> aVecShape =
440 llvm::cast<VectorType>(aVec.getType()).getShape();
441 ArrayRef<int64_t> bVecShape =
442 llvm::cast<VectorType>(bVec.getType()).getShape();
443 VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
444 resultTy.getElementType());
445 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
446 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
447 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
448 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
450 newDpasOps.push_back(newDpasOp);
453 rewriter.replaceOpWithMultiple(op, {newDpasOps});
// Rewrites a no-offset workgroup PrefetchNdOp: one subgroup prefetch per
// distributed descriptor.
459struct WgToSgPrefetchNdOp :
public OpConversionPattern<xegpu::PrefetchNdOp> {
460 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
462 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
463 ConversionPatternRewriter &rewriter)
const override {
// Reject ops with dynamic or constant offsets; the with-offset pattern
// above handles those.
465 int64_t offsetSize =
static_cast<int64_t
>(op.getOffsets().size());
466 if ((offsetSize != 0) || op.getConstOffsetsAttr())
469 for (
auto src : adaptor.getTensorDesc())
470 xegpu::PrefetchNdOp::create(
473 rewriter.eraseOp(op);
// Rewrites a workgroup vector.broadcast: re-broadcasts each distributed
// source fragment to the per-subgroup result shape. Requires the wg shape
// to be evenly distributable under the layout.
479struct WgToSgVectorBroadcastOp
480 :
public OpConversionPattern<vector::BroadcastOp> {
481 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
484 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter)
const override {
487 VectorType resultType = op.getResult().getType();
488 ArrayRef<int64_t> wgShape = resultType.getShape();
490 xegpu::DistributeLayoutAttr layout =
492 if (!layout || !layout.isForWorkgroup())
495 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
496 VectorType newResultType =
497 VectorType::get(sgShape, resultType.getElementType());
// Uneven distribution would produce ragged tiles; bail out.
499 if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
502 SmallVector<Value> newBroadcastOps;
503 for (
auto operand : adaptor.getOperands().front()) {
504 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
505 newResultType, operand);
507 newBroadcastOps.push_back(newBroadcast.getResult());
509 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
// Generic pattern (MatchAnyOpTypeTag) distributing elementwise ops: rebuilds
// the op once per distributed operand variant with a subgroup-shaped result.
// NOTE(review): the struct header and result-type extraction lines are
// missing from this excerpt — confirm the match preconditions in the full
// file.
516 WgToSgElementwiseOp(MLIRContext *ctx)
517 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
520 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
521 ConversionPatternRewriter &rewriter)
const override {
527 assert(resultType &&
"Expected result to be a VectorType");
529 ArrayRef<int64_t> wgShape = resultType.getShape();
531 xegpu::DistributeLayoutAttr layout =
533 if (!layout || !layout.isForWorkgroup())
536 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// All operands must have been distributed into the same number of
// per-subgroup variants.
538 size_t numVariants = operands.empty() ? 0 : operands.front().size();
540 if (llvm::any_of(operands, [&](
const ValueRange &operandVec) {
541 return operandVec.size() != numVariants;
545 SmallVector<Value> newResults;
546 VectorType newResultType =
547 VectorType::get(sgShape, resultType.getElementType());
// Rebuild the original op per variant via OperationState, cloning its
// attributes and swapping in the i-th operand of each range.
549 for (
size_t i = 0; i < numVariants; ++i) {
550 SmallVector<Value> opOperands;
551 for (
auto &operandVec : operands)
552 opOperands.push_back(operandVec[i]);
555 state.addOperands(opOperands);
556 state.addTypes(newResultType);
557 state.addAttributes(op->
getAttrs());
558 Operation *newOp = rewriter.create(state);
560 newResults.push_back(newOp->
getResult(0));
563 rewriter.replaceOpWithMultiple(op, {newResults});
// Rewrites a workgroup ConvertLayoutOp. Two paths:
//  1) If input and target layouts are subgroup-compatible, lower to
//     (possibly) per-fragment subgroup-level convert_layout ops.
//  2) Otherwise, round-trip the data through shared local memory (SLM):
//     store_matrix with input-layout coords, barrier, load_matrix with
//     target-layout coords.
594struct WgToSgConvertLayoutOp
595 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
596 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
599 matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
600 ConversionPatternRewriter &rewriter)
const override {
601 Location loc = op.getLoc();
602 auto inputLayout = op.getInputLayout();
603 auto targetLayout = op.getTargetLayout();
605 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
606 !targetLayout.isForWorkgroup())
607 return rewriter.notifyMatchFailure(
608 op,
"Input and target layouts must have subgroup layout");
610 Type resultType = op.getResult().getType();
// Scalar results: the convert is a no-op; forward the source.
612 rewriter.replaceOp(op, op.getSource());
613 assert(!inputLayout.dropSgLayoutAndData() &&
614 !targetLayout.dropSgLayoutAndData() &&
615 "unexpected layout attributes for scalar type");
619 ArrayRef<int64_t> wgShape = cast<VectorType>(resultType).getShape();
620 SmallVector<int64_t> inputSgLayout =
621 inputLayout.getEffectiveSgLayoutAsInt();
622 SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
623 SmallVector<int64_t> targetSgLayout =
624 targetLayout.getEffectiveSgLayoutAsInt();
625 SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
628 SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
// Path 1: layouts agree at subgroup granularity — no data movement needed.
629 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
630 xegpu::LayoutKind::Subgroup)) {
631 inputLayout = inputLayout.dropSgLayoutAndData();
632 targetLayout = targetLayout.dropSgLayoutAndData();
634 SmallVector<Value> newOps(adaptor.getSource());
// Residual lane-level layouts still differ: keep per-fragment converts.
635 if (inputLayout && targetLayout) {
636 for (
auto [i, src] : llvm::enumerate(adaptor.getSource())) {
637 auto newOp = xegpu::ConvertLayoutOp::create(
638 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
642 rewriter.replaceOpWithMultiple(op, {newOps});
// Path 2: incompatible layouts — exchange data through SLM.
647 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
649 SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
// NOTE(review): assumes bitWidth is a multiple of 8 — sub-byte element
// types would truncate here; confirm upstream guards.
653 auto bytesPerElement = bitWidth / 8;
// Memory space 3 = workgroup-shared memory.
657 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
658 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
660 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
663 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
665 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
666 rewriter.getIndexType(),
nullptr);
// Each subgroup stores its fragments at its input-layout coordinates.
669 auto storeCoords = inputLayout.computeDistributedCoords(
670 rewriter, loc, sgId.getResult(), wgShape);
675 for (
auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
676 SmallVector<OpFoldResult> storeMatrixOffsets;
677 for (Value coord : coords) {
678 storeMatrixOffsets.push_back(coord);
680 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
681 storeMatrixOffsets,
nullptr );
// All stores must land before any subgroup reads back.
684 gpu::BarrierOp::create(rewriter, loc);
687 auto loadCoords = targetLayout.computeDistributedCoords(
688 rewriter, loc, sgId.getResult(), wgShape);
692 VectorType loadType = VectorType::get(targetSgData, elemTy);
695 SmallVector<Value> finalResults;
696 for (
auto coords : *loadCoords) {
697 SmallVector<OpFoldResult> loadMatrixOffsets;
698 for (Value coord : coords) {
699 loadMatrixOffsets.push_back(coord);
701 auto loadOp = xegpu::LoadMatrixOp::create(
702 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
703 targetLayout.dropSgLayoutAndData());
705 finalResults.push_back(loadOp.getResult());
708 rewriter.replaceOpWithMultiple(op, {finalResults});
// Cleans up unrealized_conversion_cast ops introduced by 1:N type
// conversion: folds the N->1 and 1->N vector cast shims when input and
// result types already line up.
744struct UnrealizedConversionCastOpPattern
745 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
746 using OpConversionPattern<
747 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
750 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
751 ConversionPatternRewriter &rewriter)
const override {
754 auto inputTy = dyn_cast<VectorType>(inputs[0].
getType());
755 auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
// Only uniform vector casts (all inputs alike, all results alike) fold.
757 if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
758 !llvm::all_equal(
ValueRange(inputs).getTypes()))
// Case: N values cast "down" to N matching results — drop the cast.
766 if (op.getNumOperands() == 1 &&
767 llvm::equal(
ValueRange(inputs).getTypes(), op->getResultTypes())) {
768 rewriter.replaceOp(op, inputs);
// Case: N values cast to a single wg-level result — forward all N.
779 if (op.getNumResults() == 1 &&
781 rewriter.replaceOpWithMultiple(op, {inputs});
785 return mlir::failure();
// Rewrites a workgroup arith.constant vector:
//  - splat constants are simply re-created at the subgroup shape;
//  - non-splat index constants with constant row/column strides are
//    decomposed into a base tile plus a per-subgroup linearized offset.
790struct WgToSgArithConstantOp :
public OpConversionPattern<arith::ConstantOp> {
791 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
794 matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
795 ConversionPatternRewriter &rewriter)
const override {
796 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
797 auto vecType = dyn_cast<VectorType>(op.getType());
798 if (!vecAttr || !vecType)
801 xegpu::DistributeLayoutAttr layout =
803 if (!layout || !layout.isForWorkgroup())
806 ArrayRef<int64_t> wgShape = vecType.getShape();
807 SmallVector<int64_t> sgShape;
809 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
811 auto newType = VectorType::get(sgShape, vecType.getElementType());
812 Location loc = op.getLoc();
813 auto eltType = vecType.getElementType();
// Easy case: splat — every subgroup tile is identical.
815 if (vecAttr.isSplat()) {
817 Attribute singleVal = vecAttr.getSplatValue<Attribute>();
819 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
820 rewriter.replaceOp(op, cstOp);
822 }
// sg tile == wg shape: a single subgroup holds everything; reuse as-is.
else if (sgShape == wgShape) {
825 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
826 rewriter.replaceOp(op, newConstOp);
// Non-splat path below only supports index-typed 1-D/2-D constants.
832 if (!eltType.isIndex())
833 return rewriter.notifyMatchFailure(
834 op,
"Unsupported element type for non-splat constant op.");
836 if (wgShape.size() > 2)
837 return rewriter.notifyMatchFailure(
838 op,
"Only 1D & 2D vector constant supported");
840 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
// Derive candidate strides from the first two elements of row 0 / col 0.
841 int64_t rowStride = 0, colStride = 0;
842 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
843 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
847 colStride = cast<IntegerAttr>(values[1]).getInt() -
848 cast<IntegerAttr>(values[0]).getInt();
851 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
852 cast<IntegerAttr>(values[0]).getInt();
// Verify every adjacent pair obeys the two strides; otherwise the
// constant is not an affine ramp and cannot be decomposed.
855 for (int64_t r = 0; r < rows; ++r) {
856 for (int64_t c = 0; c < cols; ++c) {
857 int64_t idx = r * cols + c;
859 if (c > 0 && cols > 1) {
860 int64_t prevIdx = r * cols + (c - 1);
861 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
862 cast<IntegerAttr>(values[prevIdx]).getInt();
863 if (diff != colStride)
864 return rewriter.notifyMatchFailure(
865 op,
"Non-constant column stride in constant op.");
868 if (r > 0 && rows > 1) {
869 int64_t prevIdx = (r - 1) * cols + c;
870 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
871 cast<IntegerAttr>(values[prevIdx]).getInt();
872 if (diff != rowStride)
873 return rewriter.notifyMatchFailure(
874 op,
"Non-constant row stride in constant op.");
// Extract the subgroup-shaped base tile from the wg constant's corner.
882 SmallVector<Attribute> baseTileValues;
883 int baseTileCols = sgShape[sgShape.size() - 1];
884 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
885 for (int64_t r = 0; r < baseTileRows; ++r) {
886 for (int64_t c = 0; c < baseTileCols; ++c) {
887 baseTileValues.push_back(values[r * cols + c]);
893 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
897 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
899 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
903 SmallVector<Value, 2> strideConsts;
904 strideConsts.push_back(
908 strideConsts.begin(),
// For each subgroup's offsets, compute the scalar linearized offset
// (sum of coord*stride), broadcast it, and add it to the base tile.
911 SmallVector<Value> newConstOps;
912 for (
auto offsets : *sgOffsets) {
915 for (
size_t i = 0; i < strideConsts.size(); ++i) {
917 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
918 offsets[i], strideConsts[i]);
919 mulOffset = arith::AddIOp::create(
920 rewriter, loc, rewriter.getIndexType(), mulOffset,
mul);
923 auto bcastOffset = vector::BroadcastOp::create(
924 rewriter, loc, baseConstVec.getType(), mulOffset);
926 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
927 newConstOps.push_back(finalConst);
929 rewriter.replaceOpWithMultiple(op, {newConstOps});
// Rewrites a workgroup LoadGatherOp with explicit offsets: one subgroup
// gather per (offsets, mask) pair, after checking both were distributed to
// matching shapes.
937struct WgToSgLoadGatherOpWithOffset
938 :
public OpConversionPattern<xegpu::LoadGatherOp> {
939 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
941 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter)
const override {
// Only the explicit-offsets form is handled here.
944 if (!op.getOffsets())
947 Location loc = op.getLoc();
948 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
951 ArrayRef<int64_t> wgShape = resultType.getShape();
953 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
955 if (!layout || !layout.isForWorkgroup())
958 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// Offsets and mask must already be subgroup-distributed and congruent.
961 auto offsetsVecType =
962 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
964 dyn_cast<VectorType>(adaptor.getMask().front().getType());
965 if (!offsetsVecType || !maskVecType ||
966 offsetsVecType.getShape() != maskVecType.getShape()) {
967 return rewriter.notifyMatchFailure(op,
968 "offsets have not been distributed");
971 SmallVector<Value> newLoadOps;
// Default chunk size is 1 when the op does not specify one.
973 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
974 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
975 for (
auto [offsets, mask] :
976 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
977 auto newLayout = layout.dropSgLayoutAndData();
978 auto newLoadOp = xegpu::LoadGatherOp::create(
979 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
980 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
982 newLoadOps.push_back(newLoadOp);
984 rewriter.replaceOpWithMultiple(op, {newLoadOps});
// Rewrites a workgroup StoreScatterOp with explicit offsets: one subgroup
// scatter per (value, offsets, mask) triple.
991struct WgToSgStoreScatterOpWithOffset
992 :
public OpConversionPattern<xegpu::StoreScatterOp> {
993 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
995 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
996 ConversionPatternRewriter &rewriter)
const override {
// Only the explicit-offsets form is handled here.
998 if (!op.getOffsets())
1001 Location loc = op.getLoc();
1002 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
1006 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1008 if (!layout || !layout.isForWorkgroup())
// Offsets and mask must already be subgroup-distributed and congruent.
1012 auto offsetsVecType =
1013 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
1015 dyn_cast<VectorType>(adaptor.getMask().front().getType());
1016 if (!offsetsVecType || !maskVecType ||
1017 offsetsVecType.getShape() != maskVecType.getShape()) {
1018 return rewriter.notifyMatchFailure(op,
1019 "offsets have not been distributed");
// Default chunk size is 1 when unspecified.
1022 auto chunkSizeOpt = op.getChunkSize();
1023 int64_t chunkSize = chunkSizeOpt ?
static_cast<int64_t
>(*chunkSizeOpt) : 1;
1024 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
1025 for (
auto [val, offs, mask] : llvm::zip(
1026 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
1027 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
1028 mask, chunkSizeAttr, op.getL1HintAttr(),
1029 op.getL2HintAttr(), op.getL3HintAttr(),
1030 layout.dropSgLayoutAndData());
// Store has no results; erase the original wg op.
1032 rewriter.eraseOp(op);
// Rewrites a workgroup LoadMatrixOp: one subgroup-shaped matrix load per
// offsets entry produced by genOffsetsList.
1037struct WgToSgLoadMatrixOp :
public OpConversionPattern<xegpu::LoadMatrixOp> {
1038 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
1040 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
1041 ConversionPatternRewriter &rewriter)
const override {
1043 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1044 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
1047 ArrayRef<int64_t> wgShape = op.getDataShape();
1048 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
1049 assert(valueTy &&
"the value type must be vector type!");
1050 Type elemTy = valueTy.getElementType();
1052 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
// Result shape shrinks to the per-subgroup tile.
1053 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1054 VectorType newResTy = VectorType::get(sgShape, elemTy);
1055 SmallVector<Value> newOps;
1056 for (
auto offsets : offsetsList) {
1057 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1058 op.getMemDesc(), offsets,
1059 layout.dropSgLayoutAndData());
1060 newOps.push_back(newOp);
1062 rewriter.replaceOpWithMultiple(op, {newOps});
// Rewrites a workgroup StoreMatrixOp: one subgroup matrix store per
// (data fragment, offsets) pair from genOffsetsList.
1068struct WgToSgStoreMatrixOp :
public OpConversionPattern<xegpu::StoreMatrixOp> {
1069 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1071 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter)
const override {
1074 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1075 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
1078 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1079 for (
auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1080 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1081 offsets, layout.dropSgLayoutAndData());
// Store has no results; erase the original wg op.
1082 rewriter.eraseOp(op);
// Rewrites a workgroup vector.step: creates a subgroup-sized step vector
// and adds each subgroup's (broadcast) starting offset to it.
1088struct WgToSgVectorStepOp :
public OpConversionPattern<vector::StepOp> {
1089 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1091 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter)
const override {
1093 xegpu::DistributeLayoutAttr layout =
1095 if (!layout || !layout.isForWorkgroup())
1098 Location loc = op.getLoc();
1099 VectorType type = op.getResult().getType();
1100 auto wgShape = type.getShape();
1101 std::optional<SmallVector<int64_t>> sgShape =
1102 getSgShapeAndCount(wgShape, layout).first;
// Linear subgroup id (no upper-bound operand).
1107 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1109 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1113 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
// Base 0..sgShape-1 ramp shared by all assignments of this subgroup.
1114 auto steps = vector::StepOp::create(rewriter, loc, newTy);
1115 SmallVector<Value> newOps;
1116 for (
auto offsets : *sgOffsets) {
// step is 1-D, so only the first coordinate contributes.
1119 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1121 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1122 newOps.push_back(finalSteps);
1125 rewriter.replaceOpWithMultiple(op, {newOps});
// Rewrites a workgroup vector.shape_cast: re-applies the cast per
// distributed source fragment at the subgroup shape. Shape casts that
// expand unit dimensions are only supported when all users are broadcasts
// and the source layout is a slice of the result layout.
1131struct WgToSgVectorShapeCastOp
1132 :
public OpConversionPattern<vector::ShapeCastOp> {
1133 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1136 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1137 ConversionPatternRewriter &rewriter)
const override {
1139 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1143 ArrayRef<int64_t> wgShape = resultType.getShape();
1144 xegpu::DistributeLayoutAttr layout =
1146 if (!layout || !layout.isForWorkgroup())
1151 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1155 ArrayRef<int64_t> srcShape = srcType.getShape();
1157 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1158 SmallVector<int64_t> expandedUnitDims;
1160 xegpu::DistributeLayoutAttr sourceLayout =
// Helper: true when every user of the result is a vector.broadcast.
1163 auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
1164 return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
1165 return isa<vector::BroadcastOp>(user);
1169 if (!usedByBroadcastOp(op))
1170 return rewriter.notifyMatchFailure(
1171 op,
"ShapeCast ops that expand unit dimensions and are used by "
1172 "non-broadcast operations are not supported.");
1174 if (!sourceLayout.isSliceOf(layout))
1175 return rewriter.notifyMatchFailure(
1176 op,
"The ShapeCast op only expands dimensions, the input layout "
1177 "must be a slice of the result layout.");
// Sanity: expanded unit dims must already carry sg_data == 1.
1179 assert(layoutToDistribute.isEqualTo(
1180 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1181 "The sg_data for unit dimensions should be set as 1");
1184 SmallVector<int64_t> sgShape =
1185 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1186 VectorType newResultType =
1187 VectorType::get(sgShape, resultType.getElementType());
1189 SmallVector<Value> newShapeCastOps;
1190 for (
auto src : adaptor.getSource()) {
1191 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1192 newResultType, src);
1193 newShapeCastOps.push_back(newShapeCast.getResult());
1196 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1233struct WgToSgMultiDimReductionOp
1234 :
public OpConversionPattern<vector::MultiDimReductionOp> {
1235 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1238 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1239 ConversionPatternRewriter &rewriter)
const override {
1240 Location loc = op.getLoc();
1242 VectorType srcType = op.getSourceVectorType();
1243 Type resultTy = op.getResult().getType();
1244 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
1245 bool isScalarResult = !dstVecType;
1247 auto originalSrcShape = srcType.getShape();
1248 int srcVecRank = originalSrcShape.size();
1249 Type elemTy = srcType.getElementType();
1251 xegpu::DistributeLayoutAttr layout =
1253 if (!layout || !layout.isForWorkgroup())
1256 auto reductionDims = llvm::to_vector(op.getReductionDims());
1259 SmallVector<int64_t> sgLayout;
1260 SmallVector<int64_t> sgData;
1261 if (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1262 sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
1263 sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
1265 return rewriter.notifyMatchFailure(
1266 op,
"Reduction should have SliceAttr layout");
1269 SmallVector<Value> localReductions;
1270 auto sgSrcs = adaptor.getSource();
1271 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1272 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1273 sgSrcType.getShape().end());
1280 auto originalDstShape = dstVecType.getShape();
1281 SmallVector<int64_t> sgDstShape =
1282 getSgShapeAndCount(originalDstShape, layout).first;
1283 sgDstType = VectorType::get(sgDstShape, elemTy);
1288 for (
auto sgSrc : sgSrcs) {
1291 rewriter, loc, sgDstType, op.getKind());
1293 auto localReduce = vector::MultiDimReductionOp::create(
1294 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1296 localReductions.push_back(localReduce.getResult());
1300 SmallVector<int64_t> crossSgReductionDims;
1301 for (int64_t reductionDim : reductionDims) {
1302 bool needsCrossSubgroupReduction =
1303 (sgLayout[reductionDim] > 1) &&
1304 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1306 if (needsCrossSubgroupReduction) {
1307 crossSgReductionDims.push_back(reductionDim);
1312 if (crossSgReductionDims.empty()) {
1313 SmallVector<Value> results;
1314 for (
auto localResult : localReductions) {
1316 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1317 results.push_back(finalResult);
1319 rewriter.replaceOpWithMultiple(op, {results});
1324 auto slmStoreDataShape = sgSrcShape;
1325 for (int64_t dim : reductionDims)
1326 slmStoreDataShape[dim] = 1;
1327 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1329 if (isScalarResult) {
1331 slmStoreData = vector::BroadcastOp::create(
1332 rewriter, loc, slmStoreDataType, localReductions[0]);
1334 slmStoreData = vector::ShapeCastOp::create(
1335 rewriter, loc, slmStoreDataType, localReductions[0]);
1338 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1339 originalSrcShape.end());
1341 for (int64_t dim : reductionDims)
1342 slmShape[dim] = sgLayout[dim];
1346 auto bytesPerElement = bitWidth / 8;
1348 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1349 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1351 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1354 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1357 if (localReductions.size() > 1) {
1358 return rewriter.notifyMatchFailure(
1360 "Multiple local reductions not supported in current implementation.");
1364 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1365 rewriter.getIndexType(),
nullptr);
1368 SmallVector<Value> sgLayoutValues;
1369 for (int64_t dim : sgLayout)
1370 sgLayoutValues.push_back(
1373 auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
1377 SmallVector<Value> sgIds = *sgIdsResult;
1379 auto getSlmOffsets = [&](int64_t reductionDimStride) {
1380 SmallVector<OpFoldResult> offsets;
1381 offsets.reserve(srcVecRank);
1382 for (
int i = 0; i < srcVecRank; ++i) {
1383 Value dimVal = sgIds[i];
1384 int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
1385 ? reductionDimStride
1390 arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
1391 offsets.push_back(offsetVal);
1396 SmallVector<OpFoldResult> slmStoreOffsets =
1399 xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
1400 memDesc.getResult(), slmStoreOffsets,
1403 gpu::BarrierOp::create(rewriter, loc);
1406 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1407 for (int64_t dim : reductionDims)
1408 slmLoadDataShape[dim] = slmShape[dim];
1410 SmallVector<OpFoldResult> slmLoadOffsets =
1413 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1414 auto slmLoadOp = xegpu::LoadMatrixOp::create(
1415 rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
1420 rewriter, loc, sgDstType, op.getKind());
1422 auto finalReduce = vector::MultiDimReductionOp::create(
1423 rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
1424 neutralFinalAcc, reductionDims);
1428 finalReduce.getResult(),
1429 adaptor.getAcc()[0]);
1431 rewriter.replaceOp(op, finalResult);
1437struct WgToSgVectorTransposeOp
1438 :
public OpConversionPattern<vector::TransposeOp> {
1439 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1442 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1443 ConversionPatternRewriter &rewriter)
const override {
1444 VectorType resultType = op.getResultVectorType();
1446 ArrayRef<int64_t> wgShape = resultType.getShape();
1447 xegpu::DistributeLayoutAttr layout =
1449 if (!layout || !layout.isForWorkgroup())
1452 xegpu::DistributeLayoutAttr sourceLayout =
1454 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1457 SmallVector<int64_t> sourceSgLayout =
1458 sourceLayout.getEffectiveSgLayoutAsInt();
1459 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1461 ArrayRef<int64_t> permutation = op.getPermutation();
1462 size_t permutationSize = permutation.size();
1463 if (sourceSgLayout.size() != permutationSize ||
1464 resultSgLayout.size() != permutationSize) {
1465 return rewriter.notifyMatchFailure(
1466 op,
"Layouts and permutation must have the same rank");
1471 if (!layout.isTransposeOf(sourceLayout, permutation,
1472 xegpu::LayoutKind::Subgroup))
1473 return rewriter.notifyMatchFailure(
1474 op,
"Result layout is not a valid transpose of source layout "
1475 "according to permutation");
1477 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1478 VectorType newResultType =
1479 VectorType::get(sgShape, resultType.getElementType());
1481 SmallVector<Value> newTransposeOps;
1482 for (
auto src : adaptor.getVector()) {
1483 auto newTranspose = vector::TransposeOp::create(
1484 rewriter, op.getLoc(), newResultType, src, permutation);
1485 newTransposeOps.push_back(newTranspose.getResult());
1487 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1493template <
typename MaskOpType>
1494struct WgToSgVectorMaskOp :
public OpConversionPattern<MaskOpType> {
1495 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1497 LogicalResult matchAndRewrite(
1499 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1500 ConversionPatternRewriter &rewriter)
const override {
1501 xegpu::DistributeLayoutAttr layout =
1503 if (!layout || !layout.isForWorkgroup())
1506 Location loc = op.getLoc();
1507 VectorType type = op.getResult().getType();
1508 auto wgShape = type.getShape();
1510 SmallVector<Value> wgMaskDimSizes;
1511 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1512 for (int64_t maskSize : op.getMaskDimSizes()) {
1513 wgMaskDimSizes.push_back(
1516 }
else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1517 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1521 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1523 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1527 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1528 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1532 SmallVector<Value> newCreateMaskOps;
1533 for (
auto offsetSet : *sgOffsets) {
1534 SmallVector<Value> maskOperands;
1536 for (
auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1539 Value offset = offsetSet[i];
1540 Value adjustedMaskSize =
1541 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1544 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1546 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1547 maskOperands.push_back(sgMaskSize);
1550 auto newCreateMaskOp =
1551 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1552 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1555 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1560using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1561using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1568 .
add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1569 WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1570 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1571 WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1572 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1573 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1574 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1575 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1576 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1577 WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1584struct XeGPUWgToSgDistributePass
1585 :
public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1586 void runOnOperation()
override;
1590void XeGPUWgToSgDistributePass::runOnOperation() {
1592 Operation *op = getOperation();
1594 signalPassFailure();
1599 SmallVector<Operation *> existingCastOps;
1600 getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1601 existingCastOps.push_back(castOp.getOperation());
1611 TypeConverter converter;
1612 converter.addConversion([&](Type type) -> Type {
return type; });
1613 converter.addConversion(
1614 [&](RankedTensorType type,
1615 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1619 auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
1620 type.getEncoding());
1622 return std::nullopt;
1624 Type elemTy = type.getElementType();
1625 ArrayRef<int64_t> shape = type.getShape();
1628 SmallVector<int64_t> subShape;
1629 std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
1631 auto newTy = VectorType::get(subShape, elemTy);
1632 result.append(count, newTy);
1643 RewritePatternSet patterns(ctx);
1644 ConversionTarget
target(*ctx);
1645 TypeConverter converter;
1646 converter.addConversion([&](Type type) -> Type {
return type; });
1647 converter.addConversion(
1648 [&](xegpu::TensorDescType type,
1649 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
1650 xegpu::LayoutAttr layout = type.getLayoutAttr();
1653 if (!layout || !layout.isForWorkgroup())
1654 return std::nullopt;
1656 Type elemTy = type.getElementType();
1657 ArrayRef<int64_t> shape = type.getShape();
1660 SmallVector<int64_t> subShape;
1661 std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1663 layout = layout.dropSgLayoutAndData();
1665 auto newTy = xegpu::TensorDescType::get(
1666 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1667 result.append(count, newTy);
1671 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1672 if (
auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1673 return createOp.getType();
1674 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1675 return loadOp.getTensorDescType();
1676 if (
auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1677 return storeOp.getTensorDescType();
1678 if (
auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1680 if (
auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1681 return prefetchOp.getTensorDescType();
1682 return xegpu::TensorDescType();
1685 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) ->
bool {
1686 return !layout || !layout.isForWorkgroup();
1689 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1690 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1691 xegpu::PrefetchNdOp>([=](Operation *op) ->
bool {
1692 auto tdescTy = getTensorDescType(op);
1693 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1694 return isLegal(layout);
1697 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) ->
bool {
1698 auto layout = op.getLayoutCdAttr();
1699 return isLegal(layout);
1702 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1703 [=](xegpu::LoadMatrixOp op) ->
bool {
1704 return isLegal(op.getLayoutAttr());
1707 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1708 [=](xegpu::StoreMatrixOp op) ->
bool {
1709 return isLegal(op.getLayoutAttr());
1712 target.addDynamicallyLegalOp<arith::ConstantOp>(
1713 [=](arith::ConstantOp op) ->
bool {
1714 auto vecType = dyn_cast<VectorType>(op.getType());
1720 return isLegal(layout);
1723 target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1724 vector::TransposeOp, vector::BroadcastOp,
1725 vector::MultiDimReductionOp,
1726 vector::ConstantMaskOp, vector::CreateMaskOp>(
1727 [=](Operation *op) ->
bool {
1731 return isLegal(layout);
1734 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1735 [=](xegpu::LoadGatherOp op) ->
bool {
1736 auto layout = op.getLayoutAttr();
1737 return isLegal(layout);
1740 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1741 [=](xegpu::StoreScatterOp op) ->
bool {
1742 auto layout = op.getLayoutAttr();
1743 return isLegal(layout);
1746 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1747 [=](xegpu::ConvertLayoutOp op) ->
bool {
1748 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1751 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1752 [=](Operation *op) -> std::optional<bool> {
1757 VectorType resultType =
1765 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1766 if (!operandType || operandType.getShape() != resultType.getShape()) {
1771 xegpu::DistributeLayoutAttr layout =
1773 return isLegal(layout);
1776 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1777 [=](UnrealizedConversionCastOp op) {
1778 return llvm::is_contained(existingCastOps, op.getOperation());
1781 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
1787 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1788 return signalPassFailure();
1791 getOperation()->walk([](Operation *op) {
1792 if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
1795 SmallVector<StringAttr> attrsToRemove;
1797 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1798 attrsToRemove.push_back(namedAttr.getName());
1800 for (
auto attrName : attrsToRemove)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's nested regions.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structural type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.