25#include "llvm/ADT/SetVector.h"
30#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
31#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Looks for an xegpu::RangeAttr stored under the "sg_id_range" attribute
// name (read from `parent`). genOffsetsList uses the result to restrict
// distribution to a sub-range of subgroup ids.
// NOTE(review): this chunk is missing the lines that derive `parent` from
// `op`, the branch body and the return statements. Confirm the lookup order
// and the not-found result against the full source.
40static xegpu::RangeAttr getRangeSpecAttr(
Operation *op) {
43 if (
auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
44 parent->
getAttr(
"sg_id_range")))
// Returns the per-subgroup tile shape for a workgroup-level `layout`, paired
// with an integer count. Callers use the count as the number of per-subgroup
// ops to create (e.g. `SmallVector<Value> newCreateNdOps(count)`).
// If layout.computeDistributedShape(...) fails, the pair (sgShape, count)
// from lines not shown here is returned. Otherwise the layout's effective
// sg_data is returned as the shape.
// NOTE(review): the line with the function name and leading parameters is
// missing; callers refer to it as getSgShapeAndCount(wgShape, layout). The
// lines computing sgShape/count are also missing. On the visible success
// path the value held by `distributedShape` is never read. Confirm this is
// intended.
51static std::pair<SmallVector<int64_t>,
int>
53 xegpu::DistributeLayoutAttr layout) {
56 auto distributedShape = layout.computeDistributedShape(
58 if (
failed(distributedShape))
59 return std::make_pair(sgShape, count);
60 auto sgData = layout.getEffectiveSgDataAsInt();
62 return std::make_pair(sgData, count);
// Computes the offsets a workgroup-level LoadNd / StoreNd / PrefetchNd /
// LoadMatrix / StoreMatrix op should use once distributed to the current
// subgroup. The enable_if below restricts OpType to these ops. One entry is
// appended to `offsetsList` per coordinate set produced for this subgroup.
//
// Visible steps:
//  * the layout comes from getLayoutAttr() for the matrix ops and from
//    getDescLayoutAttr() for the others, and is checked to be a workgroup
//    layout;
//  * the subgroup id comes from gpu::SubgroupIdOp;
//  * if an "sg_id_range" is found via getRangeSpecAttr, the layout's subgroup
//    count must equal (end - start), otherwise the match fails with
//    "sg_layout size must match the sg_id_range". A positive start is then
//    subtracted from the id;
//  * per-subgroup coordinates come from layout.computeDistributedCoords.
// NOTE(review): the rest of the parameter list, the early-return bodies, and
// the code that combines each coordinate set with the original offsets are
// on lines missing from this chunk.
69template <
typename OpType,
70 typename = std::enable_if_t<llvm::is_one_of<
71 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
72 xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
74genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Early-out for ops without explicit offsets; the branch body is on a
// missing line.
79 if (origOffsets.empty())
83 xegpu::DistributeLayoutAttr layout;
// Matrix ops carry the layout directly; the nd ops carry it on the
// tensor descriptor.
84 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
85 std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
86 layout = op.getLayoutAttr();
88 layout = op.getDescLayoutAttr();
// Only workgroup-level layouts are distributed here.
92 if (!layout || !layout.isForWorkgroup())
96 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
// NOTE(review): sgIdRange is dereferenced below. The null check, if any,
// sits on a line missing from this chunk.
99 xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
101 int64_t startOfRange = sgIdRange.getStart().getInt();
102 int64_t endOfRange = sgIdRange.getEnd().getInt();
// The range size (end - start) must match the layout's subgroup count.
104 if (layout.getNumSubgroups() != endOfRange - startOfRange)
105 return rewriter.notifyMatchFailure(
106 op,
"sg_layout size must match the sg_id_range");
// Rebase the id so the first subgroup of the range maps to 0.
108 if (startOfRange > 0) {
109 Value startOfRangeVal =
111 sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
// One coordinate set per tile assigned to this subgroup.
118 auto maybeDescOffsets =
119 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
120 if (
failed(maybeDescOffsets))
125 for (
const auto &sgOffsets : *maybeDescOffsets) {
128 offsetsList.push_back(std::move(newOffsets));
// Distributes a workgroup-level xegpu.create_nd_tdesc.
// It applies only when the tensor descriptor type carries an xegpu::LayoutAttr
// for a workgroup. The op is then replaced by `count` create_nd_tdesc ops
// whose descriptor type has the per-subgroup shape, the original element type
// and encoding, and the layout with sg_layout/sg_data dropped.
// Source, sizes and strides are forwarded unchanged. No offsets are applied
// here; the load/store/prefetch patterns compute them via genOffsetsList.
182struct WgToSgCreateNdOp :
public OpConversionPattern<xegpu::CreateNdDescOp> {
183 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
186 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
187 ConversionPatternRewriter &rewriter)
const override {
189 Location loc = op.getLoc();
191 xegpu::TensorDescType tdescTy = op.getType();
192 auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
193 if (!layout || !layout.isForWorkgroup())
196 Type elemTy = tdescTy.getElementType();
197 ArrayRef<int64_t> wgShape = tdescTy.getShape();
199 SmallVector<int64_t> sgShape;
201 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
202 xegpu::TensorDescType newTdescTy =
203 xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
204 layout.dropSgLayoutAndData());
// Each generator call builds a separate (identical) create_nd_tdesc op.
206 SmallVector<Value> newCreateNdOps(count);
207 std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
208 return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
209 op.getSource(), op.getMixedSizes(),
210 op.getMixedStrides());
213 rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
// Distributes a workgroup-level xegpu.load_nd into one load_nd per pair of
// (converted tensor descriptor, offsets from genOffsetsList).
// Each new load's vector result takes the shape and element type of its
// per-subgroup tensor descriptor. The L1/L2/L3 cache hints are forwarded, and
// the layout is passed with sg_layout/sg_data dropped.
219struct WgToSgLoadNdOp :
public OpConversionPattern<xegpu::LoadNdOp> {
220 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
222 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
223 ConversionPatternRewriter &rewriter)
const override {
225 SmallVector<SmallVector<OpFoldResult>> offsetsList;
226 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// NOTE(review): the statement between these two lines (original line 230)
// is missing. Presumably it guards against a null `layout` before
// dropSgLayoutAndData() is called.
229 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
231 layout = layout.dropSgLayoutAndData();
232 SmallVector<Value> newOps;
233 for (
auto [tdesc, offsets] :
234 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
// NOTE(review): the dyn_cast result is dereferenced without a null check.
// cast<> would state that expectation explicitly.
235 auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
236 VectorType newResTy =
237 VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
238 auto newOp = xegpu::LoadNdOp::create(
239 rewriter, op.getLoc(), newResTy, tdesc, offsets,
240 nullptr,
nullptr, op.getL1HintAttr(),
241 op.getL2HintAttr(), op.getL3HintAttr(), layout);
242 newOps.push_back(newOp);
244 rewriter.replaceOpWithMultiple(op, {newOps});
// Distributes a workgroup-level xegpu.store_nd into one store_nd per
// (converted value, converted tensor descriptor, offsets) triple.
// Offsets come from genOffsetsList. Cache hints are forwarded, and the layout
// is passed with sg_layout/sg_data dropped. The original op is then erased.
251struct WgToSgStoreNdOp :
public OpConversionPattern<xegpu::StoreNdOp> {
252 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
254 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
255 ConversionPatternRewriter &rewriter)
const override {
256 SmallVector<SmallVector<OpFoldResult>> offsetsList;
257 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// NOTE(review): original line 261, presumably a null guard on `layout`,
// is missing from this chunk.
260 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
262 layout = layout.dropSgLayoutAndData();
263 for (
auto [v, tdesc, offsets] :
264 llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
265 xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
266 op.getL1HintAttr(), op.getL2HintAttr(),
267 op.getL3HintAttr(), layout);
269 rewriter.eraseOp(op);
// Distributes a workgroup-level xegpu.prefetch_nd into one prefetch_nd per
// (converted tensor descriptor, offsets) pair.
// Offsets come from genOffsetsList. Cache hints are forwarded, and the layout
// is passed with sg_layout/sg_data dropped. The original op is then erased.
276struct WgToSgPrefetchNdOp :
public OpConversionPattern<xegpu::PrefetchNdOp> {
277 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
279 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter)
const override {
281 SmallVector<SmallVector<OpFoldResult>> offsetsList;
282 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
// NOTE(review): original line 286, presumably a null guard on `layout`,
// is missing from this chunk.
285 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
287 layout = layout.dropSgLayoutAndData();
288 for (
auto [tdesc, offsets] :
289 llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
290 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
291 op.getL1HintAttr(), op.getL2HintAttr(),
292 op.getL3HintAttr(), layout);
294 rewriter.eraseOp(op);
// Distributes a workgroup-level xegpu.dpas.
// One dpas is emitted for every (lhs tile, rhs tile) pair of the converted
// operands, iterating lhs in the outer loop. An accumulator tile
// adaptor.getAcc()[i++] is appended to the operands; the condition guarding
// this is on a missing line. Each result shape is
// lhs.shape[:-2] ++ [lhs.shape[-2], rhs.shape[-1]], and the A/B/CD layouts
// are copied with sg_layout/sg_data dropped.
// Results of rank < 2 and ops missing any of the three layouts are rejected.
// The rejection bodies are on missing lines.
301struct WgToSgDpasOp :
public OpConversionPattern<xegpu::DpasOp> {
302 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
304 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter)
const override {
306 Location loc = op.getLoc();
307 VectorType resultTy = op.getResult().getType();
308 if (resultTy.getRank() < 2)
311 auto layoutCd = op.getLayoutCdAttr();
312 auto layoutA = op.getLayoutAAttr();
313 auto layoutB = op.getLayoutBAttr();
314 if (!layoutCd || !layoutA || !layoutB)
317 SmallVector<Value> newDpasOps;
318 for (
auto aVec : adaptor.getLhs()) {
319 for (
auto bVec : adaptor.getRhs()) {
321 llvm::SmallVector<Value> operands({aVec, bVec});
// The i-th accumulator tile pairs with the i-th (lhs, rhs) combination.
324 tmpC = adaptor.getAcc()[i++];
325 operands.push_back(tmpC);
328 ArrayRef<int64_t> aVecShape =
329 cast<VectorType>(aVec.getType()).getShape();
330 ArrayRef<int64_t> bVecShape =
331 cast<VectorType>(bVec.getType()).getShape();
// Result: lhs batch dims, then lhs rows x rhs columns.
334 SmallVector<int64_t> resShape(aVecShape.drop_back(2));
335 resShape.push_back(aVecShape[aVecShape.size() - 2]);
336 resShape.push_back(bVecShape[bVecShape.size() - 1]);
337 VectorType resTy = VectorType::get(resShape, resultTy.getElementType());
338 auto newDpasOp = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
339 newDpasOp.setLayoutCdAttr(layoutCd.dropSgLayoutAndData());
340 newDpasOp.setLayoutAAttr(layoutA.dropSgLayoutAndData());
341 newDpasOp.setLayoutBAttr(layoutB.dropSgLayoutAndData());
343 newDpasOps.push_back(newDpasOp);
346 rewriter.replaceOpWithMultiple(op, {newDpasOps});
// Distributes a workgroup-level xegpu.dpas_mx (dpas with A/B scale operands).
// One op is emitted per (A tile, B tile) pair:
//  * the accumulator tile (if the op has an acc) is taken by a running
//    counter index_c;
//  * scale_a is indexed by the A tile index and scale_b by the B tile index;
//  * the result shape is A.shape[:-2] ++ [A.shape[-2], B.shape[-1]];
//  * all five layouts are passed with sg_layout/sg_data dropped.
// Results of rank < 2 and ops missing any layout are rejected. The rejection
// bodies and the index_c declaration are on missing lines.
352struct WgToSgDpasMxOp :
public OpConversionPattern<xegpu::DpasMxOp> {
353 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
355 matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
356 ConversionPatternRewriter &rewriter)
const override {
358 Location loc = op.getLoc();
359 VectorType resultTy = op.getResult().getType();
361 if (resultTy.getRank() < 2)
364 auto layoutCd = op.getLayoutCdAttr();
365 auto layoutA = op.getLayoutAAttr();
366 auto layoutB = op.getLayoutBAttr();
367 auto layoutAScale = op.getLayoutAScaleAttr();
368 auto layoutBScale = op.getLayoutBScaleAttr();
370 if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
374 SmallVector<Value> newDpasMxOps;
375 for (
auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
376 for (
auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
// Optional operands: an empty Value is passed when the original op
// lacks them.
377 Value accVal = (op.getAcc()) ? adaptor.getAcc()[index_c++] : Value();
379 (op.getScaleA()) ? adaptor.getScaleA()[index_a] : Value();
381 (op.getScaleB()) ? adaptor.getScaleB()[index_b] : Value();
383 ArrayRef<int64_t> aVecShape =
384 cast<VectorType>(aVec.getType()).getShape();
385 ArrayRef<int64_t> bVecShape =
386 cast<VectorType>(bVec.getType()).getShape();
388 SmallVector<int64_t> resShape(aVecShape.drop_back(2));
389 resShape.push_back(aVecShape[aVecShape.size() - 2]);
390 resShape.push_back(bVecShape[bVecShape.size() - 1]);
391 VectorType resTy = VectorType::get(resShape, resultTy.getElementType());
392 auto newDpasMxOp = xegpu::DpasMxOp::create(
393 rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
394 layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
395 layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
396 layoutBScale.dropSgLayoutAndData());
398 newDpasMxOps.push_back(newDpasMxOp);
401 rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
// Distributes a workgroup-level vector.broadcast whose result has a workgroup
// layout. Each converted source value is broadcast to the per-subgroup result
// type. The whole list of source values is repeated
// count / (number of source values) times.
// NOTE(review): the division is integer division, so if `count` is not a
// multiple of the number of source values, fewer than `count` results are
// produced. An empty source list would divide by zero. Confirm that neither
// case can occur.
407struct WgToSgVectorBroadcastOp
408 :
public OpConversionPattern<vector::BroadcastOp> {
409 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
412 matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter)
const override {
415 VectorType resultType = op.getResult().getType();
416 ArrayRef<int64_t> wgShape = resultType.getShape();
418 xegpu::DistributeLayoutAttr layout =
420 if (!layout || !layout.isForWorkgroup())
423 SmallVector<int64_t> sgShape;
425 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
426 VectorType newResultType =
427 VectorType::get(sgShape, resultType.getElementType());
429 SmallVector<Value> newBroadcastOps;
// Converted values of the first operand (the broadcast source).
430 auto distSource = adaptor.getOperands().front();
431 int numDistributions = count / distSource.size();
432 for (
int i = 0; i < numDistributions; ++i) {
433 for (
auto operand : distSource) {
434 auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
435 newResultType, operand);
437 newBroadcastOps.push_back(newBroadcast.getResult());
440 rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
// Generic distribution pattern, registered with MatchAnyOpTypeTag and
// benefit 1. For an op whose result is asserted to be a VectorType with a
// workgroup layout:
//  * every converted operand must expand to the same number of values
//    (numVariants, taken from the first operand);
//  * the op is then re-created numVariants times. The i-th copy takes the
//    i-th value of each operand, produces a vector of the per-subgroup
//    shape, and copies all of the original op's attributes verbatim.
// NOTE(review): the filter that limits this pattern to elementwise ops and
// the setup of sgShape/state are on missing lines. The attribute copy does
// not visibly drop sg_layout/sg_data from any layout attribute; confirm this
// is handled elsewhere. With zero operands numVariants is 0 and no
// replacement values are produced.
446struct WgToSgElementwiseOp :
public ConversionPattern {
448 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
452 ConversionPatternRewriter &rewriter)
const override {
458 assert(resultType &&
"Expected result to be a VectorType");
462 xegpu::DistributeLayoutAttr layout =
464 if (!layout || !layout.isForWorkgroup())
469 size_t numVariants = operands.empty() ? 0 : operands.front().size();
// All operands must have been split into the same number of pieces.
471 if (llvm::any_of(operands, [&](
const ValueRange &operandVec) {
472 return operandVec.size() != numVariants;
477 VectorType newResultType =
478 VectorType::get(sgShape, resultType.getElementType());
480 for (
size_t i = 0; i < numVariants; ++i) {
482 for (
auto &operandVec : operands)
483 opOperands.push_back(operandVec[i]);
486 state.addOperands(opOperands);
487 state.addTypes(newResultType);
488 state.addAttributes(op->
getAttrs());
489 Operation *newOp = rewriter.create(state);
491 newResults.push_back(newOp->
getResult(0));
494 rewriter.replaceOpWithMultiple(op, {newResults});
// Distributes a workgroup-level xegpu.convert_layout.
//  * Both the input and target layouts must be workgroup layouts. Otherwise
//    the match fails with "Input and target layouts must have subgroup
//    layout".
//  * For a scalar result (per the assertion message), the op is replaced by
//    its source. Both layouts are asserted to become null after dropping
//    sg_layout/sg_data.
//  * If the layouts are compatible for the workgroup shape, both drop
//    sg_layout/sg_data. When the dropped layouts are both non-null, one
//    convert_layout is emitted per converted source tile.
//  * On the remaining path, data is exchanged through an i8 memref created
//    with memref.alloca in memory space 3 (named `slm`; presumably shared
//    local memory) and viewed through a mem desc:
//    1. each subgroup stores its tiles with store_matrix at coordinates from
//       the input layout;
//    2. a gpu.barrier is inserted;
//    3. each subgroup loads tiles of shape targetSgData at coordinates from
//       the target layout.
// NOTE(review): the scalar-type check, the wgShape/sg_data setup, the slmSize
// and slmShape computation, and several return statements are on lines
// missing from this chunk.
525struct WgToSgConvertLayoutOp
526 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
527 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
531 ConversionPatternRewriter &rewriter)
const override {
533 auto inputLayout = op.getInputLayout();
534 auto targetLayout = op.getTargetLayout();
536 if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
537 !targetLayout.isForWorkgroup())
538 return rewriter.notifyMatchFailure(
539 op,
"Input and target layouts must have subgroup layout");
541 Type resultType = op.getResult().getType();
// Scalar case: no distribution is needed; forward the source.
543 rewriter.replaceOp(op, op.getSource());
544 assert(!inputLayout.dropSgLayoutAndData() &&
545 !targetLayout.dropSgLayoutAndData() &&
546 "unexpected layout attributes for scalar type");
552 inputLayout.getEffectiveSgLayoutAsInt();
555 targetLayout.getEffectiveSgLayoutAsInt();
// Compatible layouts need no data movement between subgroups; only the
// remaining (lane-level) layout parts are converted per tile.
560 if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
562 inputLayout = inputLayout.dropSgLayoutAndData();
563 targetLayout = targetLayout.dropSgLayoutAndData();
566 if (inputLayout && targetLayout) {
567 for (
auto [i, src] : llvm::enumerate(adaptor.getSource())) {
568 auto newOp = xegpu::ConvertLayoutOp::create(
569 rewriter, loc, src.getType(), src, inputLayout, targetLayout);
573 rewriter.replaceOpWithMultiple(op, {newOps});
578 Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
// NOTE(review): integer division yields 0 for element types narrower
// than 8 bits (e.g. i1, i4). Check how slmSize handles that case.
584 auto bytesPerElement = bitWidth / 8;
588 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
589 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
591 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
594 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
596 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
597 rewriter.getIndexType(),
nullptr);
// Store phase: each tile goes to its input-layout coordinates.
600 auto storeCoords = inputLayout.computeDistributedCoords(
601 rewriter, loc, sgId.getResult(), wgShape);
602 if (failed(storeCoords))
606 for (
auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
608 for (
Value coord : coords) {
609 storeMatrixOffsets.push_back(coord);
611 xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
612 storeMatrixOffsets,
nullptr );
// All stores must complete before any subgroup reads back.
615 gpu::BarrierOp::create(rewriter, loc);
// Load phase: read tiles at target-layout coordinates.
618 auto loadCoords = targetLayout.computeDistributedCoords(
619 rewriter, loc, sgId.getResult(), wgShape);
620 if (failed(loadCoords))
623 VectorType loadType = VectorType::get(targetSgData, elemTy);
627 for (
auto coords : *loadCoords) {
629 for (
Value coord : coords) {
630 loadMatrixOffsets.push_back(coord);
632 auto loadOp = xegpu::LoadMatrixOp::create(
633 rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
634 targetLayout.dropSgLayoutAndData());
636 finalResults.push_back(loadOp.getResult());
639 rewriter.replaceOpWithMultiple(op, {finalResults});
// Distributes a workgroup-level arith.constant whose value is a
// DenseElementsAttr of vector type with a workgroup layout.
//  * Splat: `count` constants of the per-subgroup type are created (from
//    sgAttr, built on missing lines).
//  * sgShape == wgShape: the constant is re-created unchanged.
//  * Otherwise only index-typed constants of rank <= 2 are supported, and
//    they must have a constant column stride and, for 2-D, a constant row
//    stride. A base tile is built from the top-left sgShape block of values
//    (row-major in the workgroup shape). Each subgroup's tile is then
//      base + broadcast(sum_i offsets[i] * strideConsts[i])
//    using coordinates from layout.computeDistributedCoords.
// NOTE(review): the guards around the stride computation, the construction
// of tileAttr and strideConsts, and the initial mulOffset are on missing
// lines.
645struct WgToSgArithConstantOp :
public OpConversionPattern<arith::ConstantOp> {
646 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
650 ConversionPatternRewriter &rewriter)
const override {
651 auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
652 auto vecType = dyn_cast<VectorType>(op.getType());
653 if (!vecAttr || !vecType)
656 xegpu::DistributeLayoutAttr layout =
658 if (!layout || !layout.isForWorkgroup())
664 std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
666 auto newType = VectorType::get(sgShape, vecType.getElementType());
668 auto eltType = vecType.getElementType();
// Splat: every subgroup tile holds the same value.
670 if (vecAttr.isSplat()) {
675 for (
int i = 0; i < count; ++i) {
676 auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
677 newConstOps.push_back(cstOp);
679 rewriter.replaceOpWithMultiple(op, {newConstOps});
681 }
else if (sgShape == wgShape) {
684 arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
685 rewriter.replaceOp(op, newConstOp);
// Non-splat path: only strided index constants can be rebuilt per subgroup.
691 if (!eltType.isIndex())
692 return rewriter.notifyMatchFailure(
693 op,
"Unsupported element type for non-splat constant op.");
695 if (wgShape.size() > 2)
696 return rewriter.notifyMatchFailure(
697 op,
"Only 1D & 2D vector constant supported");
699 SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
700 int64_t rowStride = 0, colStride = 0;
701 int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
702 int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
// Column stride: difference between the first two elements.
706 colStride = cast<IntegerAttr>(values[1]).getInt() -
707 cast<IntegerAttr>(values[0]).getInt();
// Row stride: difference between the first elements of rows 1 and 0.
710 rowStride = cast<IntegerAttr>(values[cols]).getInt() -
711 cast<IntegerAttr>(values[0]).getInt();
// Verify that every adjacent pair along each axis has that same stride.
714 for (int64_t r = 0; r < rows; ++r) {
715 for (int64_t c = 0; c < cols; ++c) {
716 int64_t idx = r * cols + c;
718 if (c > 0 && cols > 1) {
719 int64_t prevIdx = r * cols + (c - 1);
720 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
721 cast<IntegerAttr>(values[prevIdx]).getInt();
722 if (diff != colStride)
723 return rewriter.notifyMatchFailure(
724 op,
"Non-constant column stride in constant op.");
727 if (r > 0 && rows > 1) {
728 int64_t prevIdx = (r - 1) * cols + c;
729 int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
730 cast<IntegerAttr>(values[prevIdx]).getInt();
731 if (diff != rowStride)
732 return rewriter.notifyMatchFailure(
733 op,
"Non-constant row stride in constant op.");
// Base tile: the top-left sgShape block, read row-major using the
// workgroup column count.
741 SmallVector<Attribute> baseTileValues;
// NOTE(review): sgShape holds int64_t values; storing one in `int` narrows.
742 int baseTileCols = sgShape[sgShape.size() - 1];
743 int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
744 for (int64_t r = 0; r < baseTileRows; ++r) {
745 for (int64_t c = 0; c < baseTileCols; ++c) {
746 baseTileValues.push_back(values[r * cols + c]);
752 auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
756 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
758 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
762 SmallVector<Value, 2> strideConsts;
763 strideConsts.push_back(
767 strideConsts.begin(),
770 SmallVector<Value> newConstOps;
// Per tile: linear offset = sum of coordinate * stride, added to the base.
771 for (
auto offsets : *sgOffsets) {
774 for (
size_t i = 0; i < strideConsts.size(); ++i) {
776 arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
777 offsets[i], strideConsts[i]);
778 mulOffset = arith::AddIOp::create(
779 rewriter, loc, rewriter.getIndexType(), mulOffset,
mul);
782 auto bcastOffset = vector::BroadcastOp::create(
783 rewriter, loc, baseConstVec.getType(), mulOffset);
785 arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
786 newConstOps.push_back(finalConst);
788 rewriter.replaceOpWithMultiple(op, {newConstOps});
// Distributes a workgroup-level xegpu.load (gather) whose layout is a
// workgroup layout. One load is emitted per (offsets tile, mask tile) pair.
// Each load reads from the original source and produces a vector of the
// per-subgroup shape. It uses the op's chunk size (default 1), forwards the
// L1/L2/L3 hints, and passes the layout with sg_layout/sg_data dropped.
// The match fails with "offsets have not been distributed" unless the
// converted offsets and mask are vectors of the same shape.
796struct WgToSgLoadGatherOp :
public OpConversionPattern<xegpu::LoadGatherOp> {
797 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
799 matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter)
const override {
802 Location loc = op.getLoc();
803 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
806 ArrayRef<int64_t> wgShape = resultType.getShape();
808 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
810 if (!layout || !layout.isForWorkgroup())
813 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// Only the first converted offsets/mask values are inspected.
816 auto offsetsVecType =
817 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
819 dyn_cast<VectorType>(adaptor.getMask().front().getType());
820 if (!offsetsVecType || !maskVecType ||
821 offsetsVecType.getShape() != maskVecType.getShape()) {
822 return rewriter.notifyMatchFailure(op,
823 "offsets have not been distributed");
826 SmallVector<Value> newLoadOps;
828 rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
829 VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
830 for (
auto [offsets, mask] :
831 llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
// NOTE(review): newLayout does not depend on the loop variables and could
// be computed once before the loop.
832 auto newLayout = layout.dropSgLayoutAndData();
833 auto newLoadOp = xegpu::LoadGatherOp::create(
834 rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
835 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
837 newLoadOps.push_back(newLoadOp);
839 rewriter.replaceOpWithMultiple(op, {newLoadOps});
// Distributes a workgroup-level xegpu.store (scatter) whose layout is a
// workgroup layout. One store is emitted per (value, offsets, mask) triple of
// converted operands, writing to the original destination. Each store uses
// the op's chunk size (default 1), forwards the L1/L2/L3 hints, and passes
// the layout with sg_layout/sg_data dropped. The original op is then erased.
// The match fails with "offsets have not been distributed" unless the
// converted offsets and mask are vectors of the same shape.
846struct WgToSgStoreScatterOp
847 :
public OpConversionPattern<xegpu::StoreScatterOp> {
848 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
850 matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
851 ConversionPatternRewriter &rewriter)
const override {
853 Location loc = op.getLoc();
854 VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
858 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
860 if (!layout || !layout.isForWorkgroup())
864 auto offsetsVecType =
865 dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
867 dyn_cast<VectorType>(adaptor.getMask().front().getType());
868 if (!offsetsVecType || !maskVecType ||
869 offsetsVecType.getShape() != maskVecType.getShape()) {
870 return rewriter.notifyMatchFailure(op,
871 "offsets have not been distributed");
// Missing chunk size defaults to 1.
874 auto chunkSizeOpt = op.getChunkSize();
875 int64_t chunkSize = chunkSizeOpt ?
static_cast<int64_t
>(*chunkSizeOpt) : 1;
876 auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
877 for (
auto [val, offs, mask] : llvm::zip(
878 adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
879 xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
880 mask, chunkSizeAttr, op.getL1HintAttr(),
881 op.getL2HintAttr(), op.getL3HintAttr(),
882 layout.dropSgLayoutAndData());
884 rewriter.eraseOp(op);
// Distributes a workgroup-level xegpu.load_matrix. One load_matrix is emitted
// per offsets entry from genOffsetsList. Each reads from the same mem desc a
// vector whose shape is the per-subgroup shape derived from op.getDataShape(),
// and passes the layout with sg_layout/sg_data dropped.
// `layout` is used without a local null check. genOffsetsList tests
// `!layout || !layout.isForWorkgroup()` before this point (presumably
// returning failure).
889struct WgToSgLoadMatrixOp :
public OpConversionPattern<xegpu::LoadMatrixOp> {
890 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
892 matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter)
const override {
895 SmallVector<SmallVector<OpFoldResult>> offsetsList;
896 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
899 ArrayRef<int64_t> wgShape = op.getDataShape();
900 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
901 assert(valueTy &&
"the value type must be vector type!");
902 Type elemTy = valueTy.getElementType();
904 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
905 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
906 VectorType newResTy = VectorType::get(sgShape, elemTy);
907 SmallVector<Value> newOps;
908 for (
auto offsets : offsetsList) {
909 auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
910 op.getMemDesc(), offsets,
911 layout.dropSgLayoutAndData());
912 newOps.push_back(newOp);
914 rewriter.replaceOpWithMultiple(op, {newOps});
// Distributes a workgroup-level xegpu.store_matrix. One store_matrix is
// emitted per (converted data value, offsets from genOffsetsList) pair. Each
// targets the same mem desc and passes the layout with sg_layout/sg_data
// dropped. The original op is then erased.
920struct WgToSgStoreMatrixOp :
public OpConversionPattern<xegpu::StoreMatrixOp> {
921 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
923 matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
924 ConversionPatternRewriter &rewriter)
const override {
926 SmallVector<SmallVector<OpFoldResult>> offsetsList;
927 if (
failed(genOffsetsList(rewriter, op, offsetsList)))
930 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
931 for (
auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
932 xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
933 offsets, layout.dropSgLayoutAndData());
934 rewriter.eraseOp(op);
// Distributes a workgroup-level vector.step. A single vector.step of the
// per-subgroup type is built. Then, for each coordinate set from
// layout.computeDistributedCoords, broadcast(offsets[0]) is added to it,
// giving one result per tile.
// NOTE(review): only offsets[0] is used, presumably because the step result
// is 1-D. sgShape is wrapped in std::optional but is assigned
// unconditionally on the visible lines.
940struct WgToSgVectorStepOp :
public OpConversionPattern<vector::StepOp> {
941 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
943 matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
944 ConversionPatternRewriter &rewriter)
const override {
945 xegpu::DistributeLayoutAttr layout =
947 if (!layout || !layout.isForWorkgroup())
950 Location loc = op.getLoc();
951 VectorType type = op.getResult().getType();
952 auto wgShape = type.getShape();
953 std::optional<SmallVector<int64_t>> sgShape =
954 getSgShapeAndCount(wgShape, layout).first;
959 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
961 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
// The per-subgroup step sequence is shared; each tile adds its own base.
965 VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
966 auto steps = vector::StepOp::create(rewriter, loc, newTy);
967 SmallVector<Value> newOps;
968 for (
auto offsets : *sgOffsets) {
971 vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
973 arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
974 newOps.push_back(finalSteps);
977 rewriter.replaceOpWithMultiple(op, {newOps});
// Distributes a workgroup-level vector.shape_cast. One shape_cast is emitted
// per converted source tile, producing the per-subgroup shape computed from
// `layoutToDistribute` (the result layout).
// On the visible checking path, the source layout must be a slice of the
// result layout. Otherwise the match fails with "The ShapeCast op only
// expands dimensions, ...". That path also asserts sg_data is 1 on the
// expanded unit dimensions.
// NOTE(review): the condition selecting that path (presumably the
// unit-dimension expansion case) and how expandedUnitDims is filled are on
// missing lines.
983struct WgToSgVectorShapeCastOp
984 :
public OpConversionPattern<vector::ShapeCastOp> {
985 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
988 matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
989 ConversionPatternRewriter &rewriter)
const override {
991 VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
995 ArrayRef<int64_t> wgShape = resultType.getShape();
996 xegpu::DistributeLayoutAttr layout =
998 if (!layout || !layout.isForWorkgroup())
1003 auto srcType = dyn_cast<VectorType>(op.getSource().getType());
1007 ArrayRef<int64_t> srcShape = srcType.getShape();
1009 xegpu::DistributeLayoutAttr layoutToDistribute = layout;
1010 SmallVector<int64_t> expandedUnitDims;
1012 xegpu::DistributeLayoutAttr sourceLayout =
1015 if (!sourceLayout.isSliceOf(layout))
1016 return rewriter.notifyMatchFailure(
1017 op,
"The ShapeCast op only expands dimensions, the input layout "
1018 "must be a slice of the result layout.");
1020 assert(layoutToDistribute.isEqualTo(
1021 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
1022 "The sg_data for unit dimensions should be set as 1");
1025 SmallVector<int64_t> sgShape =
1026 getSgShapeAndCount(wgShape, layoutToDistribute).first;
1027 VectorType newResultType =
1028 VectorType::get(sgShape, resultType.getElementType());
1030 SmallVector<Value> newShapeCastOps;
1031 for (
auto src : adaptor.getSource()) {
1032 auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1033 newResultType, src);
1034 newShapeCastOps.push_back(newShapeCast.getResult());
1037 rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
// Distributes a workgroup-level vector.multi_reduction whose result layout
// is an xegpu::SliceAttr. Other layouts are rejected with "Reduction should
// have SliceAttr layout". sg_layout/sg_data are read from the slice's parent
// layout.
//  1. Each converted source tile is reduced locally, starting from a
//     neutral accumulator.
//  2. A reduction dim needs a cross-subgroup step when
//     sg_layout[dim] > 1 and sg_data[dim] < source extent[dim]. If no dim
//     needs one, each local result is combined with the accumulator and
//     returned.
//  3. Otherwise the local results are reshaped (or broadcast, for scalar
//     results) so that reduction dims have size 1. They are stored into a
//     memref in memory space 3 whose extent along each reduction dim is
//     sg_layout[dim]. A gpu.barrier follows. Each subgroup then loads the
//     full reduction-dim extent, reduces it again from a neutral
//     accumulator, and combines the result with the i-th accumulator tile.
// NOTE(review): the scalar-result destination type, slmSize, the store
// layout's reduction-dim adjustments and the neutral-value construction are
// on lines missing from this chunk.
1074struct WgToSgMultiDimReductionOp
1075 :
public OpConversionPattern<vector::MultiDimReductionOp> {
1076 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1079 matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1080 ConversionPatternRewriter &rewriter)
const override {
1081 Location loc = op.getLoc();
1083 VectorType srcType = op.getSourceVectorType();
1084 Type resultTy = op.getResult().getType();
1085 VectorType dstVecType = dyn_cast<VectorType>(resultTy);
1086 bool isScalarResult = !dstVecType;
1088 auto originalSrcShape = srcType.getShape();
1089 Type elemTy = srcType.getElementType();
1091 xegpu::DistributeLayoutAttr layout =
1093 if (!layout || !layout.isForWorkgroup())
1096 auto reductionDims = llvm::to_vector(op.getReductionDims());
1099 SmallVector<int64_t> sgLayout;
1100 SmallVector<int64_t> sgData;
1101 xegpu::DistributeLayoutAttr parentLayout;
// The result layout is a slice; the parent describes the source
// distribution.
1102 if (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1103 parentLayout = sliceAttr.getParent();
1104 sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
1105 sgData = parentLayout.getEffectiveSgDataAsInt();
1107 return rewriter.notifyMatchFailure(
1108 op,
"Reduction should have SliceAttr layout");
// Step 1: local (intra-subgroup) reductions.
1111 SmallVector<Value> localReductions;
1112 auto sgSrcs = adaptor.getSource();
1113 auto sgSrcType = dyn_cast<VectorType>(sgSrcs.front().getType());
1114 SmallVector<int64_t> sgSrcShape(sgSrcType.getShape().begin(),
1115 sgSrcType.getShape().end());
1122 auto originalDstShape = dstVecType.getShape();
1123 SmallVector<int64_t> sgDstShape =
1124 getSgShapeAndCount(originalDstShape, layout).first;
1125 sgDstType = VectorType::get(sgDstShape, elemTy);
1130 for (
auto sgSrc : sgSrcs) {
1133 rewriter, loc, sgDstType, op.getKind());
1135 auto localReduce = vector::MultiDimReductionOp::create(
1136 rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc,
1138 localReductions.push_back(localReduce.getResult());
// Step 2: find reduction dims that span more than one subgroup.
1142 SmallVector<int64_t> crossSgReductionDims;
1143 for (int64_t reductionDim : reductionDims) {
1144 bool needsCrossSubgroupReduction =
1145 (sgLayout[reductionDim] > 1) &&
1146 (sgData[reductionDim] < originalSrcShape[reductionDim]);
1148 if (needsCrossSubgroupReduction) {
1149 crossSgReductionDims.push_back(reductionDim);
1154 if (crossSgReductionDims.empty()) {
1155 SmallVector<Value> results;
// NOTE(review): every local result is combined with adaptor.getAcc()[0],
// while the cross-subgroup path below uses adaptor.getAcc()[i]. With more
// than one tile this looks inconsistent. Confirm it is intended.
1156 for (
auto localResult : localReductions) {
1158 rewriter, loc, op.getKind(), localResult, adaptor.getAcc()[0]);
1159 results.push_back(finalResult);
1161 rewriter.replaceOpWithMultiple(op, {results});
// Step 3: cross-subgroup reduction through a memory-space-3 buffer.
1166 auto slmStoreDataShape = sgSrcShape;
1167 for (int64_t dim : reductionDims)
1168 slmStoreDataShape[dim] = 1;
1169 VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
1170 SmallVector<Value> slmStoreData;
1171 for (
auto localResult : localReductions) {
1172 if (isScalarResult) {
1174 slmStoreData.push_back(vector::BroadcastOp::create(
1175 rewriter, loc, slmStoreDataType, localResult));
1177 slmStoreData.push_back(vector::ShapeCastOp::create(
1178 rewriter, loc, slmStoreDataType, localResult));
// Buffer extent along each reduction dim = number of subgroups along it.
1182 SmallVector<int64_t> slmShape(originalSrcShape.begin(),
1183 originalSrcShape.end());
1184 SmallVector<int> slmSgData(sgData.begin(), sgData.end());
1185 SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
1186 for (
int dim : reductionDims) {
1187 slmShape[dim] = sgLayout[dim];
1190 xegpu::LayoutAttr slmStoreLayout =
1191 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1195 auto bytesPerElement = bitWidth / 8;
1197 auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
1198 auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
1200 auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
1203 xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
1206 auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
1207 rewriter.getIndexType(),
nullptr);
1209 auto slmStoreCoords =
1210 slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1211 if (
failed(slmStoreCoords))
1213 for (
auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
1214 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1215 xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
// All partial results must be visible before they are read back.
1220 gpu::BarrierOp::create(rewriter, loc);
// Each subgroup reads the whole reduction-dim extent of its slice.
1223 SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
1224 for (int64_t dim : reductionDims) {
1225 slmLoadDataShape[dim] = slmShape[dim];
1226 slmSgData[dim] = slmShape[dim];
1228 xegpu::LayoutAttr slmLoadLayout =
1229 xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
1230 auto slmLoadCoords =
1231 slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
1232 if (
failed(slmLoadCoords))
1235 VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
1236 SmallVector<Value> slmLoadData;
1237 for (
auto coord : *slmLoadCoords) {
1238 SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
1239 slmLoadData.push_back(xegpu::LoadMatrixOp::create(
1240 rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
1247 rewriter, loc, sgDstType, op.getKind());
1249 SmallVector<Value> finalResults;
1250 for (
size_t i = 0; i < slmLoadData.size(); ++i) {
1251 auto loaded = slmLoadData[i];
1252 auto finalReduce = vector::MultiDimReductionOp::create(
1253 rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
1256 rewriter, loc, op.getKind(), finalReduce.getResult(),
1257 adaptor.getAcc()[i]));
1259 rewriter.replaceOpWithMultiple(op, {finalResults});
// Distributes a workgroup-level vector.transpose.
// Both the result and source layouts must be workgroup layouts whose
// effective sg_layout rank equals the permutation size. The result layout
// must be the transpose of the source layout under the permutation
// (checked at xegpu::LayoutKind::Subgroup). Each converted source tile is
// then transposed with the same permutation into the per-subgroup result
// shape.
1265struct WgToSgVectorTransposeOp
1266 :
public OpConversionPattern<vector::TransposeOp> {
1267 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
1270 matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
1271 ConversionPatternRewriter &rewriter)
const override {
1272 VectorType resultType = op.getResultVectorType();
1274 ArrayRef<int64_t> wgShape = resultType.getShape();
1275 xegpu::DistributeLayoutAttr layout =
1277 if (!layout || !layout.isForWorkgroup())
1279 xegpu::DistributeLayoutAttr sourceLayout =
1281 if (!sourceLayout || !sourceLayout.isForWorkgroup())
1284 SmallVector<int64_t> sourceSgLayout =
1285 sourceLayout.getEffectiveSgLayoutAsInt();
1286 SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1288 ArrayRef<int64_t> permutation = op.getPermutation();
1289 size_t permutationSize = permutation.size();
1290 if (sourceSgLayout.size() != permutationSize ||
1291 resultSgLayout.size() != permutationSize) {
1292 return rewriter.notifyMatchFailure(
1293 op,
"Layouts and permutation must have the same rank");
// Per-tile transposes are only valid when the subgroup grid is permuted
// consistently with the data.
1298 if (!layout.isTransposeOf(sourceLayout, permutation,
1299 xegpu::LayoutKind::Subgroup))
1300 return rewriter.notifyMatchFailure(
1301 op,
"Result layout is not a valid transpose of source layout "
1302 "according to permutation");
1304 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1305 VectorType newResultType =
1306 VectorType::get(sgShape, resultType.getElementType());
1308 SmallVector<Value> newTransposeOps;
1309 for (
auto src : adaptor.getVector()) {
1310 auto newTranspose = vector::TransposeOp::create(
1311 rewriter, op.getLoc(), newResultType, src, permutation);
1312 newTransposeOps.push_back(newTranspose.getResult());
1314 rewriter.replaceOpWithMultiple(op, {newTransposeOps});
1320template <
typename MaskOpType>
1321struct WgToSgVectorMaskOp :
public OpConversionPattern<MaskOpType> {
1322 using OpConversionPattern<MaskOpType>::OpConversionPattern;
1324 LogicalResult matchAndRewrite(
1326 typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1327 ConversionPatternRewriter &rewriter)
const override {
1328 xegpu::DistributeLayoutAttr layout =
1330 if (!layout || !layout.isForWorkgroup())
1333 Location loc = op.getLoc();
1334 VectorType type = op.getResult().getType();
1335 auto wgShape = type.getShape();
1337 SmallVector<Value> wgMaskDimSizes;
1338 if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1339 for (int64_t maskSize : op.getMaskDimSizes()) {
1340 wgMaskDimSizes.push_back(
1343 }
else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1344 wgMaskDimSizes = llvm::to_vector(op.getOperands());
1348 gpu::SubgroupIdOp::create(rewriter, loc,
nullptr);
1350 layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1354 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1355 VectorType resultType = VectorType::get(sgShape, type.getElementType());
1359 SmallVector<Value> newCreateMaskOps;
1360 for (
auto offsetSet : *sgOffsets) {
1361 SmallVector<Value> maskOperands;
1363 for (
auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
1366 Value offset = offsetSet[i];
1367 Value adjustedMaskSize =
1368 arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
1371 arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1373 arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1374 maskOperands.push_back(sgMaskSize);
1377 auto newCreateMaskOp =
1378 vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1379 newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1382 rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1387using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1388using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1391struct WgToSgVectorBitCastOp :
public OpConversionPattern<vector::BitCastOp> {
1392 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
1395 matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter)
const override {
1397 VectorType resultType = op.getResultVectorType();
1399 ArrayRef<int64_t> wgShape = resultType.getShape();
1400 xegpu::DistributeLayoutAttr layout =
1402 if (!layout || !layout.isForWorkgroup())
1405 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1406 VectorType newResultType =
1407 VectorType::get(sgShape, resultType.getElementType());
1409 SmallVector<Value> newBitCastOps;
1410 for (
auto src : adaptor.getSource()) {
1412 vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
1413 newBitCastOps.push_back(newBitCast.getResult());
1416 rewriter.replaceOpWithMultiple(op, {newBitCastOps});
1422struct WgToSgVectorInterleaveOp
1423 :
public OpConversionPattern<vector::InterleaveOp> {
1424 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1427 matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
1428 ConversionPatternRewriter &rewriter)
const override {
1429 VectorType resultType = op.getResultVectorType();
1431 ArrayRef<int64_t> wgShape = resultType.getShape();
1432 xegpu::DistributeLayoutAttr layout =
1434 if (!layout || !layout.isForWorkgroup())
1437 SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1438 VectorType newResultType =
1439 VectorType::get(sgShape, resultType.getElementType());
1441 SmallVector<Value> newInterleaveOps;
1444 for (
auto [
lhs,
rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
1445 auto newInterleave = vector::InterleaveOp::create(
1446 rewriter, op.getLoc(), newResultType,
lhs,
rhs);
1447 newInterleaveOps.push_back(newInterleave.getResult());
1450 rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
1456struct WgToSgVectorDeinterleaveOp
1457 :
public OpConversionPattern<vector::DeinterleaveOp> {
1458 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1461 matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter)
const override {
1463 SmallVector<Value> newRes1Ops;
1464 SmallVector<Value> newRes2Ops;
1466 for (
auto src : adaptor.getSource()) {
1467 auto newDeinterleave =
1468 vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
1469 newRes1Ops.push_back(newDeinterleave.getRes1());
1470 newRes2Ops.push_back(newDeinterleave.getRes2());
1473 SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
1474 rewriter.replaceOpWithMultiple(op, results);
1486 converter.addConversion([](
Type type) ->
Type {
return type; });
1489 converter.addConversion(
1490 [](xegpu::TensorDescType type,
1492 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
1493 if (!layout || !layout.isForWorkgroup())
1494 return std::nullopt;
1496 Type elemTy = type.getElementType();
1501 std::tie(subShape, count) = getSgShapeAndCount(
shape, layout);
1503 layout = layout.dropSgLayoutAndData();
1505 auto newTy = xegpu::TensorDescType::get(
1506 type.
getContext(), subShape, elemTy, type.getEncoding(), layout);
1507 result.append(count, newTy);
1513 auto getSubShapeAndCount = [](VectorType vecTy,
1514 xegpu::DistributeLayoutAttr layout)
1516 if (!layout.isForWorkgroup())
1518 return getSgShapeAndCount(vecTy.getShape(), layout);
1523 std::move(loopArgTypes));
1527 patterns.
add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, WgToSgDpasOp,
1528 WgToSgDpasMxOp, WgToSgPrefetchNdOp, WgToSgElementwiseOp,
1529 WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1530 WgToSgArithConstantOp, WgToSgLoadGatherOp, WgToSgStoreScatterOp,
1531 WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp,
1532 WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
1533 WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
1534 WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp,
1535 WgToSgVectorInterleaveOp, WgToSgVectorDeinterleaveOp>(
1542struct XeGPUWgToSgDistributePass
1544 void runOnOperation()
override;
1548void XeGPUWgToSgDistributePass::runOnOperation() {
1550 Operation *op = getOperation();
1552 signalPassFailure();
1557 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1558 getOperation()->walk(
1559 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1566 RewritePatternSet patterns(ctx);
1567 ConversionTarget
target(*ctx);
1568 TypeConverter converter;
1571 auto materializeCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
1572 Location loc) -> Value {
1573 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
1576 converter.addSourceMaterialization(materializeCast);
1577 converter.addTargetMaterialization(materializeCast);
1581 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1582 if (
auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1583 return createOp.getType();
1584 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1585 return loadOp.getTensorDescType();
1586 if (
auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1587 return storeOp.getTensorDescType();
1588 if (
auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1589 return prefetchOp.getTensorDescType();
1590 return xegpu::TensorDescType();
1593 auto isLegal = [&](xegpu::DistributeLayoutAttr layout) ->
bool {
1594 return !layout || !layout.isForWorkgroup();
1597 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1598 xegpu::StoreNdOp, xegpu::PrefetchNdOp>(
1599 [=](Operation *op) ->
bool {
1600 auto tdescTy = getTensorDescType(op);
1602 dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1603 return isLegal(layout);
1606 target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) ->
bool {
1607 auto layout = op.getLayoutCdAttr();
1608 return isLegal(layout);
1611 target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
1612 [=](xegpu::DpasMxOp op) ->
bool {
1613 auto layout = op.getLayoutCdAttr();
1614 return isLegal(layout);
1617 target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1618 [=](xegpu::LoadMatrixOp op) ->
bool {
1619 return isLegal(op.getLayoutAttr());
1622 target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1623 [=](xegpu::StoreMatrixOp op) ->
bool {
1624 return isLegal(op.getLayoutAttr());
1627 target.addDynamicallyLegalOp<arith::ConstantOp>(
1628 [=](arith::ConstantOp op) ->
bool {
1629 auto vecType = dyn_cast<VectorType>(op.getType());
1635 return isLegal(layout);
1638 target.addDynamicallyLegalOp<
1639 vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1640 vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
1641 vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp,
1642 vector::DeinterleaveOp>([=](Operation *op) ->
bool {
1646 return isLegal(layout);
1649 target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1650 [=](xegpu::LoadGatherOp op) ->
bool {
1651 auto layout = op.getLayoutAttr();
1652 return isLegal(layout);
1655 target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1656 [=](xegpu::StoreScatterOp op) ->
bool {
1657 auto layout = op.getLayoutAttr();
1658 return isLegal(layout);
1661 target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1662 [=](xegpu::ConvertLayoutOp op) ->
bool {
1663 return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1666 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1667 [=](Operation *op) -> std::optional<bool> {
1672 VectorType resultType =
1680 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1681 if (!operandType || operandType.getShape() != resultType.getShape()) {
1686 xegpu::DistributeLayoutAttr layout =
1688 return isLegal(layout);
1691 target.addLegalOp<UnrealizedConversionCastOp>();
1693 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
1699 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1700 return signalPassFailure();
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void populateXeGPUWgToSgDistributeTypeConversions(TypeConverter &converter, Operation *topLevelOp)
Define the type conversions needed for XeGPU workgroup to subgroup distribution.
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DenseMap< Value, SmallVector< Type > > precomputeLoopBlockArgTypes(Operation *topLevelOp, SubShapeAndCountFn getSubShapeAndCount)
Pre-computes distributed VectorType mappings for every value carried through an SCF loop under topLev...
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns)
Appends patterns for XeGPU workgroup to subgroup distribution into patterns.
void addVectorTypeConversion(TypeConverter &converter, SubShapeAndCountFn getSubShapeAndCount, DenseMap< Value, SmallVector< Type > > loopArgTypes)
Adds a context-aware VectorType conversion to converter (1:1 shape-changing or 1:N,...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
void cleanupUnrealizedConversionCasts(Operation *root, const llvm::SmallSetVector< UnrealizedConversionCastOp, 8 > &existingCasts)
Cleans up UnrealizedConversionCastOps inserted during SCF structural type conversion and/or XeGPU unr...
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
Include the generated interface declarations.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.