29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
36#define GEN_PASS_DEF_XEGPUSGTOLANEDISTRIBUTE
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Debug tag identifying output produced by this file.
43#define DEBUG_TYPE "xegpu-sg-to-lane-distribute"
// Debug stream helper: prefixes every message with "[<DEBUG_TYPE>]: ".
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Casts `v` to `expectedTy` for use in the lane-level IR.
// - If `v` already has `expectedTy`, the first branch is taken (its body is
//   not visible here; presumably `v` is returned unchanged).
// - If `v` is a vector with the same element count as `expectedTy`, a
//   vector.shape_cast is emitted.
// - Otherwise an unrealized_conversion_cast is created and its first result
//   is returned.
// NOTE(review): `v.getType()` is a generic Type (it must first pass
// isa<VectorType>), so calling getNumElements() on it directly looks like it
// needs `cast<VectorType>(v.getType())`. Confirm this compiles.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
52 if (v.getType() == expectedTy)
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
62 return newOp.getResult(0);
// Returns whether `op` is a multi_reduction this pass can distribute:
// - the result must carry a subgroup-level distribute layout (the rejecting
//   branch body is not visible here; presumably it returns false);
// - a scalar (int/float) result is accepted iff exactly one dim is reduced;
// - a vector result also requires that its lane-distributed type can be
//   computed from the layout (failure branch not visible; presumably false),
//   and again exactly one reduction dimension.
68static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
71 if (!resLayout || !resLayout.isForSubgroup())
74 if (op.getType().isIntOrFloat())
75 return op.getReductionDims().size() == 1;
76 VectorType resTy = dyn_cast<VectorType>(op.getType());
80 FailureOr<VectorType> resDistTypeOrFailure =
81 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
82 if (failed(resDistTypeOrFailure))
84 return op.getReductionDims().size() == 1;
// Returns true when distributing the multi_reduction's result vector by its
// lane layout changes the result type, i.e. the result is split across lanes.
// The caller uses this to pick the lane-local reduction path (presumably each
// lane then holds whole reduction slices — confirm against the helper used).
// NOTE(review): `.value()` is called without a failed() check. Success is
// only implied by the assert on isValidSubgroupMultiReductionOp, which is
// compiled out in release builds. `resTy` comes from dyn_cast and is null for
// scalar results, so this must only be called for vector-typed results.
91static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
93 assert(isValidSubgroupMultiReductionOp(op) &&
"Expecting a valid subgroup "
94 "MultiDimReductionOp");
96 VectorType resTy = dyn_cast<VectorType>(op.getType());
97 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
98 return resTy != resDistTypeOrFailure.value();
104 VectorType distributedType) {
105 assert(originalType.getRank() == distributedType.getRank() &&
106 "original and distributed vector types must have the same rank");
// getDistributedDims: collects, in ascending order, the indices of the
// dimensions whose size differs between the original and distributed vector
// types. These are the dimensions affected by lane distribution.
108 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
109 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
110 distributedDims.push_back(i);
112 return distributedDims;
// Distributes xegpu.create_nd_tdesc: the op is recreated with its layout
// removed from the result tensor-descriptor type (dropLayouts()).
// Descriptors without a layout take the early-exit branch (body not visible).
// NOTE(review): passes op.getOperands() rather than the adaptor's converted
// operands. This is fine only if those operands are never type-converted.
117struct SgToLaneCreateNdDesc
118 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
119 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
122 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter)
const override {
124 xegpu::TensorDescType resultType = op.getType();
126 if (!resultType.getLayout())
129 auto newOp = xegpu::CreateNdDescOp::create(
130 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
132 rewriter.replaceOp(op, newOp.getResult());
// Distributes xegpu.load_nd to the lane level.
// Requirements checked here:
// - the anchor layout matches the tensor descriptor's layout;
// - a target attribute is attached (needed to decide transposition).
// Two lane types are computed (the computations are not visible here): a
// "supported" lane result type and the type expected from the lane layout.
// The new load produces the supported type and is cast to the expected type
// via castValueTo. Offsets, packed/transpose attrs and L1/L2/L3 hints are
// carried over. The trailing builder argument is nullptr (presumably the
// layout is dropped — confirm).
140struct SgToLaneLoadNd :
public OpConversionPattern<xegpu::LoadNdOp> {
141 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
144 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
145 ConversionPatternRewriter &rewriter)
const override {
146 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
152 if (op.getTensorDescType().getLayout() != layout)
153 return rewriter.notifyMatchFailure(
154 op,
"conflicting layout attributes on tensor descriptor and anchor");
157 return rewriter.notifyMatchFailure(
158 op,
"xegpu::LoadNdOp require target attribute attached to "
159 "determine transpose "
161 auto supportedLaneResultTyOrFailure =
163 auto expectedLaneResultTyOrFailure =
165 if (failed(supportedLaneResultTyOrFailure))
166 return rewriter.notifyMatchFailure(
167 op,
"unable to compute the lane vector type for LoadNdOp");
168 if (failed(expectedLaneResultTyOrFailure))
169 return rewriter.notifyMatchFailure(
170 op,
"unable to compute expected lane vector type from lane layout");
171 auto newOp = xegpu::LoadNdOp::create(
172 rewriter, op.getLoc(), supportedLaneResultTyOrFailure.value(),
173 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
174 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
175 op.getL3HintAttr(),
nullptr);
181 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
182 expectedLaneResultTyOrFailure.value()));
// Distributes xegpu.store_nd to the lane level.
// Rejects the op unless the tensor-descriptor layout and the stored value's
// layout both equal the anchor layout. It computes the supported lane type
// for the stored value and emits a new store_nd whose value operand is
// converted to that type (the conversion call is not visible here). Offsets
// and L1/L2/L3 hints are carried over, the trailing builder argument is
// nullptr, and the original op is erased.
190struct SgToLaneStoreNd :
public OpConversionPattern<xegpu::StoreNdOp> {
191 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
194 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter)
const override {
196 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
202 if (op.getTensorDescType().getLayout() != layout)
203 return rewriter.notifyMatchFailure(
204 op,
"conflicting layout attributes on tensor descriptor and anchor");
206 if (valueLayout != layout)
207 return rewriter.notifyMatchFailure(
208 op,
"conflicting layout attributes on value and anchor");
209 auto supportedLaneValueTyOrFailure =
211 if (failed(supportedLaneValueTyOrFailure))
212 return rewriter.notifyMatchFailure(
214 "unable to compute lane vector type for StoreNdOp value from tensor "
217 xegpu::StoreNdOp::create(
218 rewriter, op.getLoc(),
220 supportedLaneValueTyOrFailure.value()),
221 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
222 op.getL2HintAttr(), op.getL3HintAttr(),
nullptr);
223 rewriter.eraseOp(op);
// Distributes xegpu.dpas to the lane level.
// Reads the A, B and C/D layouts, computes lane types for the result and the
// A/B operands plus the result type expected from the lane layout, and fails
// the match if any of these cannot be computed. When a
// SubgroupMatrixMultiplyAcc uArch instruction is available, it also checks
// that the packed bit width (element bits * element count) of the A and B
// lane types is a multiple of the uArch packed-format size. The new dpas
// takes operands converted to the lane types, and its result is cast to the
// expected lane type.
// NOTE(review): cast<> asserts on a null or mismatched attribute, so the
// `!layoutA || !layoutB || !layoutCd` check can never reject anything. If
// missing layouts should fail gracefully, use dyn_cast_if_present instead.
// NOTE(review): the modulo checks assume the uArch packed-format sizes are
// non-zero.
231struct SgToLaneDpas :
public OpConversionPattern<xegpu::DpasOp> {
232 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
235 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
238 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
239 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
240 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
241 if (!layoutA || !layoutB || !layoutCd)
243 auto laneResultTyOrFailure =
245 auto laneATypeOrFailure =
247 auto laneBTypeOrFailure =
249 auto expectedLaneResultTyOrFailure =
251 if (failed(laneResultTyOrFailure) || failed(laneATypeOrFailure) ||
252 failed(laneBTypeOrFailure))
253 return rewriter.notifyMatchFailure(
254 op,
"failed to calculate supported lane vector types for DpasOp "
256 if (failed(expectedLaneResultTyOrFailure))
257 return rewriter.notifyMatchFailure(
258 op,
"unable to compute expected lane vector type for DpasOp from "
264 const auto *uArchInstruction =
265 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
268 if (uArchInstruction) {
269 auto laneAType = laneATypeOrFailure.value();
270 auto laneBType = laneBTypeOrFailure.value();
272 unsigned aPackedBitWidth =
273 laneAType.getElementTypeBitWidth() * laneAType.getNumElements();
274 unsigned bPackedBitWidth =
275 laneBType.getElementTypeBitWidth() * laneBType.getNumElements();
276 unsigned expectedABitSize = uArchInstruction->getPackedFormatBitSizeA();
277 unsigned expectedBBitSize = uArchInstruction->getPackedFormatBitSizeB();
279 if (aPackedBitWidth % expectedABitSize != 0)
280 return rewriter.notifyMatchFailure(
282 "A operand packed bit width must be a multiple of uArch packed "
283 "format requirement");
284 if (bPackedBitWidth % expectedBBitSize != 0)
285 return rewriter.notifyMatchFailure(
287 "B operand packed bit width must be a multiple of uArch packed "
288 "format requirement");
292 auto newOp = xegpu::DpasOp::create(
293 rewriter, op->getLoc(), laneResultTyOrFailure.value(),
295 laneATypeOrFailure.value()),
297 laneBTypeOrFailure.value()),
299 laneResultTyOrFailure.value()),
303 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
304 expectedLaneResultTyOrFailure.value()));
// Generic distribution pattern for element-wise ops. It matches any operation
// (MatchAnyOpTypeTag, benefit 1); the element-wise filter is not visible here.
// The result must be a vector with a subgroup distribute layout. The op is
// rebuilt via OperationState with the converted operands and the lane result
// type. Attributes are copied, except those whose value is a
// DistributeLayoutAttr.
// NOTE(review): only one result type is added and only getResult(0) is used
// as the replacement, so this assumes single-result ops.
311struct SgToLaneElementWise :
public ConversionPattern {
313 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
317 ConversionPatternRewriter &rewriter)
const override {
324 return rewriter.notifyMatchFailure(
325 op,
"operation result is not a vector type");
327 xegpu::DistributeLayoutAttr layout =
329 if (!layout || !layout.isForSubgroup())
330 return rewriter.notifyMatchFailure(
331 op,
"operation result does not have subgroup distribute layout");
333 auto laneShapeOrFailure =
336 if (failed(laneShapeOrFailure))
337 return rewriter.notifyMatchFailure(
338 op,
"unable to compute lane vector type from the layout");
340 VectorType newResultType = laneShapeOrFailure.value();
342 state.addOperands(operands);
343 state.addTypes(newResultType);
346 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
347 state.addAttribute(attr.getName(), attr.getValue());
349 Operation *newOp = rewriter.create(state);
351 rewriter.replaceOp(op, newOp->
getResult(0));
// Distributes vector-typed arith.constant ops whose value is a splat.
// The result must have a subgroup distribute layout. The splat element is
// read and a new constant with the lane vector type is created (the attribute
// construction is not visible here) to replace the original.
// NOTE(review): the local `sclarValue` looks like a misspelling of
// `scalarValue`.
358struct SgToLaneArithConstant :
public OpConversionPattern<arith::ConstantOp> {
359 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
362 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 auto resultType = dyn_cast<VectorType>(op.getType());
369 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
371 return rewriter.notifyMatchFailure(
372 op,
"only dense splat vector constants are supported");
374 xegpu::DistributeLayoutAttr layout =
376 if (!layout || !layout.isForSubgroup())
377 return rewriter.notifyMatchFailure(
378 op,
"operation result does not have subgroup distribute layout");
380 auto laneShapeOrFailure =
383 if (
failed(laneShapeOrFailure))
384 return rewriter.notifyMatchFailure(
385 op,
"unable to compute lane vector type from the layout");
387 VectorType newResultType = laneShapeOrFailure.value();
388 auto sclarValue = dense.getSplatValue<Attribute>();
391 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
393 rewriter.replaceOp(op, newOp.
getResult());
// Distributes xegpu.prefetch_nd: recreates the prefetch on the converted
// tensor descriptor with the same offsets and L1/L2/L3 hints, then erases the
// original. The anchor layout is read first (the checks on it are not
// visible here).
399struct SgToLanePrefetchNd :
public OpConversionPattern<xegpu::PrefetchNdOp> {
400 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
403 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter)
const override {
405 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
410 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
411 op.getMixedOffsets(), op.getL1HintAttr(),
412 op.getL2HintAttr(), op.getL3HintAttr(),
414 rewriter.eraseOp(op);
// Distributes xegpu.load (gather) to the lane level.
// chunk_size defaults to 1. With chunk size 1 the trailing vector rank is 1,
// otherwise 2; all leading dimensions above that must be unit. The lane
// result type is flattened to 1-D, and the converted offsets and mask are
// reshaped to 1-D with castValueTo (the mask cast is not visible here). A 1-D
// gather is emitted with the original chunk size and L1/L2/L3 hints. When
// the 1-D type differs from the distributed type, the result is converted
// back (not visible here) before replacing the op.
// NOTE(review): if the rank were smaller than effectiveVecRank, the
// difference passed to take_front would be negative and convert to a huge
// size_t, so every dimension would be checked. Confirm that case cannot
// occur.
452struct SgToLaneLoadGather :
public OpConversionPattern<xegpu::LoadGatherOp> {
453 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
456 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter)
const override {
458 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
462 VectorType origResultTy = op.getValueType();
467 int chunkSize = op.getChunkSize().value_or(1);
468 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
469 ArrayRef<int64_t> shape = origResultTy.getShape();
471 shape.take_front(origResultTy.getRank() - effectiveVecRank),
472 [](int64_t d) { return d != 1; }))
473 return rewriter.notifyMatchFailure(
474 op,
"Only unit dimensions allowed for the leading "
475 "dimensions of the load vector!");
477 auto distResultTyOrFailure =
479 if (
failed(distResultTyOrFailure))
480 return rewriter.notifyMatchFailure(
481 op,
"unable to compute expected lane vector type from lane layout");
483 VectorType distResultTy = distResultTyOrFailure.value();
484 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
485 distResultTy.getElementType());
488 Value distOffsets = adaptor.getOffsets();
489 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
490 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
491 distOffsetsTy.getElementType());
492 distOffsets = castValueTo(
495 Value distMask = adaptor.getMask();
496 auto distMaskTy = cast<VectorType>(distMask.
getType());
497 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
498 distMaskTy.getElementType());
502 Value distSource = adaptor.getSource();
503 auto newOp = xegpu::LoadGatherOp::create(
504 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
505 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
506 op.getL3HintAttr(),
nullptr);
509 if (distResultTy1D != distResultTy)
512 rewriter.replaceOp(op,
result);
// Distributes a rank-1 vector.reduction across the subgroup.
// Requirements: a subgroup layout whose rank equals the source rank, a target
// attribute, a source length divisible by the subgroup size (taken from
// lane_layout[0]), and an int/float result. Each lane's converted vector is
// reduced across the subgroup by a helper (name not visible here) using the
// reduction kind and subgroup size. If an accumulator is present it is
// combined with the reduced value, and the result replaces the op.
521struct SgToLaneVectorReduction
522 :
public OpConversionPattern<vector::ReductionOp> {
523 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
526 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
527 ConversionPatternRewriter &rewriter)
const override {
531 if (!layout || !layout.isForSubgroup())
534 VectorType srcVecType = op.getSourceVectorType();
536 if (srcVecType.getRank() != 1)
537 return rewriter.notifyMatchFailure(
538 op,
"Only rank 1 reductions can be distributed.");
540 if (layout.getRank() != srcVecType.getRank())
541 return rewriter.notifyMatchFailure(
542 op,
"Layout rank does not match vector rank.");
545 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
548 return rewriter.notifyMatchFailure(
549 op,
"xegpu::ReductionOp require target attribute attached to "
550 "determine subgroup size");
554 srcVecType.getShape()[0] % sgSize != 0)
555 return rewriter.notifyMatchFailure(op,
556 "Invalid layout or reduction vector "
557 "dimension must match subgroup size.");
559 if (!op.getType().isIntOrFloat())
560 return rewriter.notifyMatchFailure(
561 op,
"Reduction distribution currently only supports floats and "
565 Value laneValVec = adaptor.getVector();
569 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
572 if (adaptor.getAcc())
574 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
576 rewriter.replaceOp(op, fullReduce);
// Distributes vector.multi_reduction with exactly one reduction dimension
// (asserted). For sources of rank > 2, all leading dimensions before the last
// two must be unit. Three paths follow (the helper calls are only partly
// visible):
// - scalar result: reduce the whole reduction dimension, then combine with
//   the accumulator if present;
// - isReductionLaneLocal(op): a lane-local reduction along reductionDim;
// - otherwise: a cross-lane reduction along reductionDim with its size.
// NOTE(review): the `sourceType` declared in the final branch shadows the
// outer `sourceType` with the same value, and `origSourceType` duplicates it
// too. Consider reusing the outer variable.
585struct SgToLaneMultiDimReduction
586 :
public OpConversionPattern<vector::MultiDimReductionOp> {
587 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
590 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter)
const override {
593 ArrayRef<int64_t> reductionDims = op.getReductionDims();
594 assert(reductionDims.size() == 1 &&
595 "Expecting single reduction dimension for subgroup multi "
598 VectorType sourceType = op.getSourceVectorType();
599 int64_t rank = sourceType.getRank();
601 ArrayRef<int64_t> shape = sourceType.getShape();
602 if (llvm::any_of(shape.take_front(rank - 2),
603 [](int64_t d) { return d != 1; }))
604 return rewriter.notifyMatchFailure(
605 op,
"only unit leading dimensions are supported for "
606 "multi_reduction with rank > 2");
610 if (op.getType().isIntOrFloat()) {
611 auto reductionDim = reductionDims[0];
612 VectorType origSourceType = op.getSourceVectorType();
613 int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
617 op.getKind(), reductionDimSize);
619 if (adaptor.getAcc())
621 result, adaptor.getAcc());
622 }
else if (isReductionLaneLocal(op)) {
626 auto reductionDim = reductionDims[0];
630 reductionDim, op.getLoc(), rewriter);
632 auto reductionDim = reductionDims[0];
633 VectorType sourceType = op.getSourceVectorType();
634 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
638 reductionDim, reductionDimSize, op.getLoc(), rewriter);
640 rewriter.replaceOp(op,
result);
649 ConversionPatternRewriter &rewriter,
Location loc,
// computeDistributedCoordsForMatrixOp: creates a gpu.lane_id and asks the
// layout for this lane's distributed coordinates over `payloadShape`. Exactly
// one coordinate set is expected (asserted). The coordinates are then
// combined into OpFoldResults (not visible here; presumably with the original
// offsets) and returned as Values. Callers treat an empty result as failure.
652 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
653 mlir::IntegerAttr());
655 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
658 assert(maybeCoords.value().size() == 1 &&
659 "Expected one set of distributed offsets");
663 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Distributes xegpu.load_matrix to the lane level.
// The payload must be a vector and the op must carry offsets. The payload is
// distributed by the layout. Unless subgroup_block_io is set, the offsets are
// rewritten into per-lane coordinates via computeDistributedCoordsForMatrixOp.
// All constant offsets become dynamic (kDynamic) so the Value coordinates are
// used, and the new op carries an empty DistributeLayoutAttr.
667struct SgToLaneLoadMatrix :
public OpConversionPattern<xegpu::LoadMatrixOp> {
668 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
671 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
672 ConversionPatternRewriter &rewriter)
const override {
673 auto layout = op.getLayoutAttr();
678 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
680 return rewriter.notifyMatchFailure(
681 op,
"the matrix op payload must be a vector type");
683 auto loc = op.getLoc();
684 auto offsets = op.getMixedOffsets();
686 return rewriter.notifyMatchFailure(op,
"the load op must have offsets");
688 FailureOr<VectorType> distPayloadTyOrFailure =
690 if (
failed(distPayloadTyOrFailure))
691 return rewriter.notifyMatchFailure(
692 op,
"Failed to distribute matrix op payload based on layout.");
694 SmallVector<Value> offsetsAsValues =
697 SmallVector<Value> newCoords = offsetsAsValues;
698 if (!op.getSubgroupBlockIoAttr()) {
699 newCoords = computeDistributedCoordsForMatrixOp(
700 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
701 if (newCoords.empty())
702 return rewriter.notifyMatchFailure(
703 op,
"Failed to compute distributed coordinates.");
706 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
707 ShapedType::kDynamic);
709 rewriter.getDenseI64ArrayAttr(newConstOffsets);
711 auto newOp = xegpu::LoadMatrixOp::create(
712 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
713 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
714 xegpu::DistributeLayoutAttr{});
715 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.transpose.
// Both the source and result must have layouts, and the result layout must be
// the lane-level transpose of the source layout under the op's permutation.
// The transpose is re-emitted on the converted source with the same
// permutation and cast to the distributed result type.
721struct SgToLaneVectorTranspose
722 :
public OpConversionPattern<vector::TransposeOp> {
723 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
726 matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter)
const override {
728 xegpu::DistributeLayoutAttr sourceLayout =
730 xegpu::DistributeLayoutAttr resultLayout =
732 if (!sourceLayout || !resultLayout)
733 return rewriter.notifyMatchFailure(
734 op,
"the source or result vector of the transpose op lacks layout "
736 ArrayRef<int64_t> perm = op.getPermutation();
738 if (!resultLayout.isTransposeOf(sourceLayout, perm,
739 xegpu::LayoutKind::Lane))
740 return rewriter.notifyMatchFailure(
741 op,
"the source or result vector layouts must be transposes of "
743 FailureOr<VectorType> distributedResultTypeOrFailure =
745 if (
failed(distributedResultTypeOrFailure))
746 return rewriter.notifyMatchFailure(
747 op,
"Failed to distribute the result vector type in "
748 "vector::Transpose op");
749 auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
750 adaptor.getVector(), perm);
751 rewriter.replaceOp(op, castValueTo(rewriter, newOp.
getResult(),
752 distributedResultTypeOrFailure.value()));
// Distributes vector.bitcast. The result layout is required. The bitcast is
// re-emitted on the converted source, directly producing the distributed
// result type.
759struct SgToLaneVectorBitcast :
public OpConversionPattern<vector::BitCastOp> {
760 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
763 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
764 ConversionPatternRewriter &rewriter)
const override {
765 xegpu::DistributeLayoutAttr resultLayout =
768 return rewriter.notifyMatchFailure(
769 op,
"result vector of the bitcast op lacks layout attribute");
770 FailureOr<VectorType> distributedResultTypeOrFailure =
772 if (
failed(distributedResultTypeOrFailure))
773 return rewriter.notifyMatchFailure(
774 op,
"Failed to distribute the result vector type in "
775 "vector::BitCast op");
776 auto newOp = vector::BitCastOp::create(
777 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
778 adaptor.getSource());
779 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.create_mask / vector.constant_mask (restricted by
// enable_if).
// The original per-dimension bounds are taken from the operands (CreateMaskOp)
// or from the constant mask_dim_sizes (ConstantMaskOp). For each element this
// lane owns (its coordinates come from the layout and gpu.lane_id), the mask
// bit is the AND over all dims of `coord[i] < bound[i]` (signed compare),
// starting from a true value defined elsewhere. A single-element result uses
// vector.broadcast; otherwise vector.from_elements builds the lane mask.
// NOTE(review): the CreateMaskOp path reads op.getOperands() rather than the
// adaptor's operands. This is fine only if the bounds are never
// type-converted.
807template <
typename OpType,
808 typename = std::enable_if_t<llvm::is_one_of<
809 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
810struct SgToLaneCreateMask :
public OpConversionPattern<OpType> {
811 using OpConversionPattern<OpType>::OpConversionPattern;
814 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
815 ConversionPatternRewriter &rewriter)
const override {
816 xegpu::DistributeLayoutAttr layout =
818 if (!layout || !layout.isForSubgroup())
819 return rewriter.notifyMatchFailure(
820 op,
"operation result does not have subgroup distribute layout");
822 VectorType origType = op.getType();
823 FailureOr<VectorType> distTypeOrFailure =
825 if (
failed(distTypeOrFailure))
826 return rewriter.notifyMatchFailure(
827 op,
"unable to compute lane vector type from the layout");
829 VectorType distType = distTypeOrFailure.value();
830 Location loc = op.getLoc();
833 SmallVector<Value> origBounds;
834 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
835 origBounds.append(op.getOperands().begin(), op.getOperands().end());
837 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
838 for (
auto dimSize : dimSizes)
839 origBounds.push_back(
843 ArrayRef<int64_t> origShape = origType.getShape();
846 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
847 mlir::IntegerAttr());
848 auto maybeCoordsVec =
849 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
850 if (
failed(maybeCoordsVec))
851 return rewriter.notifyMatchFailure(
852 op,
"failed to compute distributed coordinates from layout");
854 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
855 int64_t numElements = distType.getNumElements();
856 assert(
static_cast<int64_t
>(coordsVec.size()) == numElements &&
857 "number of coordinate sets must match number of distributed "
863 SmallVector<Value> maskBits;
864 for (
auto &coords : coordsVec) {
865 Value inBounds = trueVal;
866 for (
size_t i = 0; i < coords.size(); ++i) {
867 Value cmp = arith::CmpIOp::create(
868 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
869 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
871 maskBits.push_back(inBounds);
876 if (numElements == 1) {
878 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
881 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
883 rewriter.replaceOp(op,
result);
// Distributes xegpu.store_matrix to the lane level (mirrors the load_matrix
// pattern). The stored payload must be a vector and the op must carry
// offsets. Unless subgroup_block_io is set, per-lane coordinates are computed
// via computeDistributedCoordsForMatrixOp. The data operand is converted to
// the distributed payload type (the conversion call is not visible here).
// All constant offsets become dynamic, the new store carries an empty
// DistributeLayoutAttr, and the original op is erased.
889struct SgToLaneStoreMatrix :
public OpConversionPattern<xegpu::StoreMatrixOp> {
890 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
893 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
894 ConversionPatternRewriter &rewriter)
const override {
895 auto layout = op.getLayoutAttr();
900 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
902 return rewriter.notifyMatchFailure(
903 op,
"the matrix op payload must be a vector type");
905 auto loc = op.getLoc();
906 auto offsets = op.getMixedOffsets();
908 return rewriter.notifyMatchFailure(op,
"the store op must have offsets");
910 FailureOr<VectorType> distPayloadTyOrFailure =
912 if (
failed(distPayloadTyOrFailure))
913 return rewriter.notifyMatchFailure(
914 op,
"Failed to distribute matrix op payload based on layout.");
916 SmallVector<Value> offsetsAsValues =
919 SmallVector<Value> newCoords = offsetsAsValues;
920 if (!op.getSubgroupBlockIoAttr()) {
921 newCoords = computeDistributedCoordsForMatrixOp(
922 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
923 if (newCoords.empty())
924 return rewriter.notifyMatchFailure(
925 op,
"Failed to compute distributed coordinates.");
928 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
929 ShapedType::kDynamic);
931 rewriter.getDenseI64ArrayAttr(newConstOffsets);
933 xegpu::StoreMatrixOp::create(
936 distPayloadTyOrFailure.value()),
937 adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr,
938 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
939 rewriter.eraseOp(op);
// Distributes xegpu.store (scatter) to the lane level (mirrors the gather
// pattern).
// chunk_size defaults to 1, giving a trailing vector rank of 1, otherwise 2.
// All leading dimensions above that must be unit. The stored value is
// converted to the flattened 1-D lane type when needed (conversion not
// visible here). Offsets and mask are reshaped to 1-D via castValueTo (the
// mask cast is not visible). A 1-D scatter is emitted with the original chunk
// size and L1/L2/L3 hints, the trailing builder argument is nullptr, and the
// original op is erased.
978struct SgToLaneStoreScatter
979 :
public OpConversionPattern<xegpu::StoreScatterOp> {
980 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
983 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter)
const override {
985 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
989 VectorType origValueTy = op.getValueType();
994 int chunkSize = op.getChunkSize().value_or(1);
995 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
996 ArrayRef<int64_t> shape = origValueTy.getShape();
997 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
998 [](int64_t d) { return d != 1; }))
999 return rewriter.notifyMatchFailure(
1000 op,
"Only unit dimensions allowed for the leading "
1001 "dimensions of the store vector!");
1003 auto distValueTyOrFailure =
1005 if (
failed(distValueTyOrFailure))
1006 return rewriter.notifyMatchFailure(
1007 op,
"unable to compute expected lane vector type from lane layout");
1009 VectorType distValueTy = distValueTyOrFailure.value();
1010 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
1011 distValueTy.getElementType());
1013 Value distValue = adaptor.getValue();
1014 if (distValue.
getType() != distValueTy1D)
1019 Value distOffsets = adaptor.getOffsets();
1020 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
1021 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
1022 distOffsetsTy.getElementType());
1023 distOffsets = castValueTo(
1026 Value distMask = adaptor.getMask();
1027 auto distMaskTy = cast<VectorType>(distMask.
getType());
1028 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
1029 distMaskTy.getElementType());
1033 Value distDest = adaptor.getDest();
1034 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1035 distOffsets, distMask, op.getChunkSizeAttr(),
1036 op.getL1HintAttr(), op.getL2HintAttr(),
1037 op.getL3HintAttr(),
nullptr);
1038 rewriter.eraseOp(op);
// Distributes vector.step. Using gpu.lane_id, the layout yields the starting
// coordinate of each lane-data block owned by this lane. Each block
// contributes its start coordinate followed by start + i for
// i in [1, lane_data[0]); the offset constant is created in a line not
// visible here. The values are assembled with vector.from_elements into the
// lane vector type.
// NOTE(review): the range-for variable `laneDataBlockCoords` shadows the
// outer FailureOr with the same name. Consider renaming one of them.
1047struct SgToLaneVectorStep :
public OpConversionPattern<vector::StepOp> {
1048 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1051 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1052 ConversionPatternRewriter &rewriter)
const override {
1053 xegpu::DistributeLayoutAttr resultLayout =
1055 if (!resultLayout || !resultLayout.isForSubgroup())
1056 return rewriter.notifyMatchFailure(
1057 op,
"the result vector of the step op lacks subgroup layout");
1059 auto loc = op.getLoc();
1060 auto stepResultVecTy = op.getResult().getType();
1061 auto laneShapeOrFailure =
1063 if (
failed(laneShapeOrFailure))
1064 return rewriter.notifyMatchFailure(
1065 op,
"unable to compute lane vector type from the layout");
1066 VectorType newVecTy = laneShapeOrFailure.value();
1068 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1069 mlir::IntegerAttr());
1070 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1071 rewriter, loc, laneId, stepResultVecTy.getShape());
1072 if (
failed(laneDataBlockCoords))
1073 return rewriter.notifyMatchFailure(
1074 op,
"failed to compute lane data block coordinates");
1076 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
1077 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
1078 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
1079 newVecTy.getNumElements() / laneDataBlockLength);
1080 SmallVector<Value> stepVals;
1088 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1089 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1090 stepVals.push_back(laneDataBlockStartCoord);
1091 for (
int i = 1; i < laneDataBlockLength; ++i) {
1093 stepVals.push_back(arith::AddIOp::create(
1094 rewriter, loc, laneDataBlockStartCoord, offset));
1097 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
1098 "Expecting the number of step values to match the number of "
1099 "elements in the vector");
1101 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1102 rewriter.replaceOp(op, stepOpVal);
// Distributes vector.extract when the result is itself a vector (scalar
// extracts are rejected) carrying a subgroup layout. Only distribution along
// the innermost dimension is supported: every lane_layout entry except the
// last must be 1. The extract is re-emitted on the converted source at the
// same position.
// NOTE(review): drop_back(1) asserts on an empty lane layout. Confirm the
// layout is never rank-0 here. The position uses op.getMixedPosition()
// rather than the adaptor's converted dynamic indices.
1109struct SgToLaneVectorExtract :
public OpConversionPattern<vector::ExtractOp> {
1110 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1113 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter)
const override {
1116 auto resultType = dyn_cast<VectorType>(op.getType());
1118 return rewriter.notifyMatchFailure(op,
"scalar extract not supported");
1120 xegpu::DistributeLayoutAttr layout =
1122 if (!layout || !layout.isForSubgroup())
1127 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1128 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1129 [](int64_t v) {
return v != 1; }))
1130 return rewriter.notifyMatchFailure(
1131 op,
"only innermost dimension distribution is supported for "
1134 auto newOp = vector::ExtractOp::create(
1135 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1136 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.shape_cast. The result must have a subgroup layout. The
// shape_cast is re-emitted on the converted source, directly producing the
// distributed result type.
1142struct SgToLaneVectorShapeCast
1143 :
public OpConversionPattern<vector::ShapeCastOp> {
1144 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1147 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter)
const override {
1149 xegpu::DistributeLayoutAttr resultLayout =
1151 if (!resultLayout || !resultLayout.isForSubgroup())
1152 return rewriter.notifyMatchFailure(
1153 op,
"the result vector of the shape_cast op lacks subgroup layout");
1156 resultLayout, op.getResultVectorType());
1157 if (
failed(resultDistTypeOrFailure))
1158 return rewriter.notifyMatchFailure(
1159 op,
"failed to get distributed vector type for result");
1161 Value source = adaptor.getSource();
1162 auto newShapeCast = vector::ShapeCastOp::create(
1163 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1164 rewriter.replaceOp(op, newShapeCast);
// Distributes vector.extract_strided_slice.
// Sizes, offsets and strides are first padded to the source rank (full size,
// offset 0, stride 1). If the result is distributed along exactly one
// dimension (more than one is rejected), the following must hold:
// - a target attribute is present;
// - the source layout is non-empty;
// - the effective subgroup size along that dim divides the source size
//   (it is reduced to lane_layout[distDim] when that is a smaller divisor);
// - lane_data along that dim is 1;
// - the slice offset is a multiple of the subgroup size.
// The size along that dim then becomes the distributed size, and the offset
// becomes offset / subgroupSize. The slice is re-emitted on the converted
// source with the distributed result type.
// NOTE(review): `sourceDistrDimSize` is an int, narrowing an int64_t dim
// size. laneLayout[distDim] is indexed after only a non-empty check.
1172struct SgToLaneVectorExtractStridedSlice
1173 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
1174 using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
1177 matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
1178 ConversionPatternRewriter &rewriter)
const override {
1179 xegpu::DistributeLayoutAttr resultLayout =
1181 if (!resultLayout || !resultLayout.isForSubgroup())
1184 VectorType resultType = op.getType();
1185 auto distResultTyOrFailure =
1187 if (
failed(distResultTyOrFailure))
1188 return rewriter.notifyMatchFailure(
1189 op,
"unable to compute distributed vector type from lane layout");
1190 VectorType distResultTy = *distResultTyOrFailure;
1192 SmallVector<int64_t> distributedDims =
1193 getDistributedDims(resultType, distResultTy);
1196 int64_t sourceRank = op.getSourceVectorType().getRank();
1197 SmallVector<Attribute> updatedSizes =
1198 llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
1199 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1200 op.getOffsets(), [](Attribute attr) { return attr; });
1201 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1202 op.getStrides(), [](Attribute attr) { return attr; });
1203 for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
1204 updatedSizes.push_back(
1205 rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
1206 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1207 updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
1212 if (!distributedDims.empty()) {
1213 if (distributedDims.size() != 1)
1214 return rewriter.notifyMatchFailure(
1215 op,
"only single dimension distribution is supported");
1216 int64_t distDim = distributedDims[0];
1219 return rewriter.notifyMatchFailure(
1220 op,
"target attribute required to determine subgroup size");
1223 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1224 return rewriter.notifyMatchFailure(
1225 op,
"source of extract_strided_slice lacks distribution layout");
1226 int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
1227 auto laneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
1230 if (laneLayout[distDim] < subgroupSize &&
1231 subgroupSize % laneLayout[distDim] == 0)
1232 subgroupSize = laneLayout[distDim];
1233 if (sourceDistrDimSize % subgroupSize != 0)
1234 return rewriter.notifyMatchFailure(
1235 op,
"source size along distributed dim is not a multiple of "
1237 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1240 if (distDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1241 sourceLaneData[distDim] != 1)
1242 return rewriter.notifyMatchFailure(
1243 op,
"expecting unit lane data along the distributed dimension");
1244 int64_t distrDimOffset =
1245 cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
1246 if (distrDimOffset % subgroupSize != 0)
1247 return rewriter.notifyMatchFailure(
1248 op,
"offset along distributed dim is not a multiple of "
1251 updatedSizes[distDim] =
1252 rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
1253 updatedOffsets[distDim] =
1254 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1257 auto newOp = vector::ExtractStridedSliceOp::create(
1258 rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
1259 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1260 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1261 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1262 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.broadcast. The result must have a subgroup layout.
// Source handling by kind:
// - vector source of lower rank: its layout must be a slice of the result
//   layout;
// - vector source of equal rank: the broadcast unit dims must already have
//   unit data (asserted), and the source layout is updated via
//   setUnitDimLayout;
// - scalar source: it must not carry a layout.
// The result type is distributed. If the converted source already has that
// type, it replaces the op directly; otherwise a new broadcast to the
// distributed type is emitted.
// NOTE(review): in the equal-rank branch, `sourceLayout` is dereferenced
// inside an assert with no null check visible here. Confirm it is guarded in
// the elided lines. Whether the updated `sourceLayout` is used afterwards is
// also not visible.
1324struct SgToLaneBroadcast :
public OpConversionPattern<vector::BroadcastOp> {
1325 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
1328 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
1329 ConversionPatternRewriter &rewriter)
const override {
1330 xegpu::DistributeLayoutAttr resultLayout =
1332 if (!resultLayout || !resultLayout.isForSubgroup())
1333 return rewriter.notifyMatchFailure(
1334 op,
"result does not have subgroup distribute layout");
1336 VectorType destType = op.getResultVectorType();
1337 VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
1339 xegpu::DistributeLayoutAttr sourceLayout =
1343 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1346 if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
1348 "broadcast source layout must be a slice of result layout");
1349 }
else if (rankDiff == 0) {
1351 auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
1352 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1353 broadcastUnitDimsSet.end());
1354 assert(sourceLayout.isEqualTo(
1355 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1356 "The sg_data for unit dimensions should be set as 1");
1357 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1362 return rewriter.notifyMatchFailure(
1363 op,
"broadcast from scalar must not have a layout attribute");
1368 if (
failed(destDistType))
1369 return rewriter.notifyMatchFailure(
1370 op,
"failed to distribute the result vector type");
1372 Value source = adaptor.getSource();
1374 if (source.
getType() == destDistType.value()) {
1375 rewriter.replaceOp(op, source);
1379 auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
1380 destDistType.value(), source);
1381 rewriter.replaceOp(op, newOp);
1389struct SgToLaneVectorInsertStridedSlice
1390 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
1391 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1394 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1395 ConversionPatternRewriter &rewriter)
const override {
1396 xegpu::DistributeLayoutAttr resultLayout =
1398 if (!resultLayout || !resultLayout.isForSubgroup())
1401 VectorType destType = op.getDestVectorType();
1402 auto distDestTyOrFailure =
1404 if (
failed(distDestTyOrFailure))
1405 return rewriter.notifyMatchFailure(
1406 op,
"unable to compute distributed vector type from lane layout");
1407 VectorType distDestTy = *distDestTyOrFailure;
1409 SmallVector<int64_t> destDistributedDims =
1410 getDistributedDims(destType, distDestTy);
1412 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1413 op.getOffsets(), [](Attribute attr) { return attr; });
1415 if (!destDistributedDims.empty()) {
1416 if (destDistributedDims.size() != 1)
1417 return rewriter.notifyMatchFailure(
1418 op,
"only single dimension distribution is supported");
1419 int64_t destDistDim = destDistributedDims[0];
1423 return rewriter.notifyMatchFailure(
1424 op,
"target attribute required to determine subgroup size");
1427 VectorType srcType = op.getSourceVectorType();
1429 int64_t sourceDistDim =
1430 destDistDim - (destType.getRank() - srcType.getRank());
1431 if (sourceDistDim < 0)
1432 return rewriter.notifyMatchFailure(
1433 op,
"distributed dimension must be in the last k dims of dest");
1437 if (!destLayout || !sourceLayout ||
1438 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1439 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1440 return rewriter.notifyMatchFailure(
1441 op,
"source or dest of insert_strided_slice lacks distribution "
1444 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1445 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1448 if ((destDistDim <
static_cast<int64_t
>(destLaneData.size()) &&
1449 destLaneData[destDistDim] != 1) ||
1450 (sourceDistDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1451 sourceLaneData[sourceDistDim] != 1))
1452 return rewriter.notifyMatchFailure(
1453 op,
"expecting unit lane data along the distributed dimension");
1455 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1456 if (srcDistrDimSize % subgroupSize != 0)
1457 return rewriter.notifyMatchFailure(
1458 op,
"source distributed dim size is not a multiple of "
1461 int64_t destDistrDimOffset =
1462 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1463 if (destDistrDimOffset % subgroupSize != 0)
1464 return rewriter.notifyMatchFailure(
1465 op,
"offset along distributed dim is not a multiple of "
1468 updatedOffsets[destDistDim] =
1469 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1472 auto newOp = vector::InsertStridedSliceOp::create(
1473 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1475 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1476 rewriter.replaceOp(op, newOp.
getResult());
1483struct SgToLaneVectorInsert :
public OpConversionPattern<vector::InsertOp> {
1484 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1487 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1488 ConversionPatternRewriter &rewriter)
const override {
1490 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1492 return rewriter.notifyMatchFailure(op,
"scalar insert not supported");
1494 xegpu::DistributeLayoutAttr layout =
1496 if (!layout || !layout.isForSubgroup())
1501 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1502 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1503 [](int64_t v) {
return v != 1; }))
1504 return rewriter.notifyMatchFailure(
1505 op,
"only innermost dimension distribution is supported for "
1508 auto newOp = vector::InsertOp::create(
1509 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1510 op.getMixedPosition());
1511 rewriter.replaceOp(op, newOp.
getResult());
1529static FailureOr<Value>
1530shuffleDataAsLaneLayoutChange(ConversionPatternRewriter &rewriter,
Location loc,
1533 VectorType srcTy = dyn_cast<VectorType>(src.
getType());
1534 if (!srcTy || srcTy.getRank() != 2)
1537 if (targetLaneNum <= 0 || currentLaneNum != targetLaneNum * 2)
1541 srcTy.getNumElements() * srcTy.getElementTypeBitWidth();
1542 if (vectorBitWidth % 32 != 0)
1554 Type shuffleElemTy = rewriter.getI32Type();
1555 int64_t numShuffles = vectorBitWidth / 32;
1556 VectorType shuffleBundleTy = VectorType::get({numShuffles}, shuffleElemTy);
1558 Value temp = arith::ConstantOp::create(
1561 IntegerAttr::get(shuffleElemTy, 0)));
1562 VectorType flatSrcTy =
1563 VectorType::get({srcTy.getNumElements()}, srcTy.getElementType());
1564 Value flatSrc = vector::ShapeCastOp::create(rewriter, loc, flatSrcTy, src);
1565 Value shuffleBundle =
1566 vector::BitCastOp::create(rewriter, loc, shuffleBundleTy, flatSrc);
1567 for (
int64_t i = 0; i < numShuffles; i++) {
1569 vector::ExtractOp::create(rewriter, loc, shuffleBundle, i);
1570 shuffleElem = gpu::ShuffleOp::create(rewriter, loc, shuffleElem, 0,
1571 targetLaneNum, gpu::ShuffleMode::UP)
1573 temp = vector::InsertOp::create(rewriter, loc, shuffleElem, temp, i);
1575 temp = vector::BitCastOp::create(rewriter, loc, flatSrcTy, temp);
1576 temp = vector::ShapeCastOp::create(rewriter, loc, srcTy, temp);
1581 Value res = vector::ShuffleOp::create(rewriter, loc, src, temp,
indices);
1586struct SgToLaneConvertLayout
1587 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
1588 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1591 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1592 ConversionPatternRewriter &rewriter)
const override {
1593 auto inputLayout = op.getInputLayoutAttr();
1594 auto targetLayout = op.getTargetLayoutAttr();
1595 Type valType = op.getResult().
getType();
1598 rewriter.replaceOp(op, op.getSource());
1602 auto resShape = cast<VectorType>(valType).getShape();
1603 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
1607 if (inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1608 xegpu::LayoutKind::Lane)) {
1609 rewriter.replaceOp(op, adaptor.getSource());
1620 if (inputLayout.getEffectiveOrderAsInt() ==
1621 targetLayout.getEffectiveOrderAsInt() &&
1622 inputLayout.getRank() == 2 && targetLayout.getRank() == 2) {
1623 auto laneLayout = inputLayout.getEffectiveLaneLayoutAsInt();
1624 auto targetLaneLayout = targetLayout.getEffectiveLaneLayoutAsInt();
1625 auto laneData = inputLayout.getEffectiveLaneDataAsInt();
1626 auto targetLaneData = targetLayout.getEffectiveLaneDataAsInt();
1627 if (laneLayout.size() == 2 && targetLaneLayout.size() == 2 &&
1628 laneData == targetLaneData && laneLayout[1] == 1 &&
1629 targetLaneLayout[1] == 1 && laneLayout[0] > 1 &&
1630 laneLayout[0] != targetLaneLayout[0]) {
1631 FailureOr<Value> res = shuffleDataAsLaneLayoutChange(
1632 rewriter, op.getLoc(), adaptor.getSource(), laneLayout[0],
1633 targetLaneLayout[0]);
1634 if (succeeded(res)) {
1635 rewriter.replaceOp(op, *res);
1641 return rewriter.notifyMatchFailure(
1642 op,
"lowering incompatible convert_layout not yet supported");
1647struct SgToLaneVectorInterleave
1648 :
public OpConversionPattern<vector::InterleaveOp> {
1649 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1652 matchAndRewrite(vector::InterleaveOp op, OpAdaptor adaptor,
1653 ConversionPatternRewriter &rewriter)
const override {
1655 auto newOp = vector::InterleaveOp::create(
1656 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
1657 rewriter.replaceOp(op, newOp.
getResult());
1663struct SgToLaneVectorDeinterleave
1664 :
public OpConversionPattern<vector::DeinterleaveOp> {
1665 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1668 matchAndRewrite(vector::DeinterleaveOp op, OpAdaptor adaptor,
1669 ConversionPatternRewriter &rewriter)
const override {
1671 auto newOp = vector::DeinterleaveOp::create(rewriter, op.getLoc(),
1672 adaptor.getSource());
1678struct SgToLaneDpasMx :
public OpConversionPattern<xegpu::DpasMxOp> {
1679 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
1682 matchAndRewrite(xegpu::DpasMxOp op, OpAdaptor adaptor,
1683 ConversionPatternRewriter &rewriter)
const override {
1688 xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc))
1689 return rewriter.notifyMatchFailure(
1690 op,
"target uArch does not support scaled subgroup mma");
1692 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
1693 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
1694 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
1695 if (!layoutA || !layoutB || !layoutCd)
1696 return rewriter.notifyMatchFailure(
1697 op,
"missing required layout attributes for DpasMxOp distribution");
1700 auto expected1DTypeResult =
1702 auto expected1DTypeA =
1704 auto expected1DTypeB =
1707 VectorType expected1DTypeScaleA, expected1DTypeScaleB;
1708 if (op.getScaleA()) {
1709 auto layoutScaleA = cast<xegpu::LayoutAttr>(op.getLayoutAScaleAttr());
1711 cast<VectorType>(op.getScaleA().getType()), layoutScaleA);
1712 if (
failed(expected1DTypeScaleAOrFailure))
1713 return rewriter.notifyMatchFailure(
1714 op,
"failed to calculate expected 1D vector type for scale A");
1715 expected1DTypeScaleA = expected1DTypeScaleAOrFailure.value();
1717 if (op.getScaleB()) {
1718 auto layoutScaleB = cast<xegpu::LayoutAttr>(op.getLayoutBScaleAttr());
1720 cast<VectorType>(op.getScaleB().getType()), layoutScaleB);
1721 if (
failed(expected1DTypeScaleBOrFailure))
1722 return rewriter.notifyMatchFailure(
1723 op,
"failed to calculate expected 1D vector type for scale B");
1724 expected1DTypeScaleB = expected1DTypeScaleBOrFailure.value();
1727 auto expectedNDTypeResult =
1729 if (
failed(expected1DTypeResult) ||
failed(expected1DTypeA) ||
1731 return rewriter.notifyMatchFailure(
1733 "failed to calculate supported workitem 1D vector types for DpasOp "
1735 if (
failed(expectedNDTypeResult))
1736 return rewriter.notifyMatchFailure(
1737 op,
"unable to compute expected workitem vector type for DpasOp from "
1741 const auto *uArchInstruction = dyn_cast<
1743 xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
1744 assert(uArchInstruction);
1745 auto wiAType = expected1DTypeA.value();
1746 auto wiBType = expected1DTypeB.value();
1748 unsigned aPackedBitWidth =
1749 wiAType.getElementTypeBitWidth() * wiAType.getNumElements();
1750 unsigned bPackedBitWidth =
1751 wiBType.getElementTypeBitWidth() * wiBType.getNumElements();
1752 if (aPackedBitWidth % uArchInstruction->getPackedFormatBitSizeA())
1753 return rewriter.notifyMatchFailure(
1754 op,
"A operand packed bit width must be a multiple of uArch packed "
1755 "format requirement");
1756 if (bPackedBitWidth % uArchInstruction->getPackedFormatBitSizeB())
1757 return rewriter.notifyMatchFailure(
1758 op,
"B operand packed bit width must be a multiple of uArch packed "
1759 "format requirement");
1761 auto newOp = xegpu::DpasMxOp::create(
1762 rewriter, op->getLoc(), expected1DTypeResult.value(),
1764 expected1DTypeA.value()),
1766 expected1DTypeB.value()),
1768 ? castValueTo(rewriter,
1770 expected1DTypeResult.value())
1774 ? castValueTo(rewriter,
1776 expected1DTypeScaleA)
1779 ? castValueTo(rewriter,
1781 expected1DTypeScaleB)
1787 rewriter.replaceOp(op, castValueTo(rewriter, newOp.
getResult(),
1788 expectedNDTypeResult.value()));
1793struct XeGPUSgToLaneDistributePass
1795 XeGPUSgToLaneDistributePass> {
1796 void runOnOperation()
override;
1801void XeGPUSgToLaneDistributePass::runOnOperation() {
1804 Operation *root = getOperation();
1806 signalPassFailure();
1811 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1813 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1819 TypeConverter typeConverter;
1823 auto materializeCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
1824 Location loc) -> Value {
1825 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
1828 typeConverter.addSourceMaterialization(materializeCast);
1829 typeConverter.addTargetMaterialization(materializeCast);
1834 typeConverter, patterns,
target, root);
1835 target.addLegalOp<UnrealizedConversionCastOp>();
1836 (void)applyPartialConversion(root,
target, std::move(patterns));
1847 typeConverter.addConversion([](
Type type) ->
Type {
return type; });
1849 typeConverter.addConversion([](TensorDescType type) ->
Type {
1850 if (type.getLayoutAttr()) {
1851 return type.dropLayouts();
1859 auto getSubShapeAndCount = [](VectorType vecTy,
1860 xegpu::DistributeLayoutAttr layout)
1863 if (failed(distTyOrFailure))
1870 std::move(loopArgTypes));
1878 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1879 [&](xegpu::CreateNdDescOp op) {
return !op.getType().getLayoutAttr(); });
1881 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](
Operation *op) {
1882 if (isa<xegpu::ConvertLayoutOp>(op))
1884 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1887 return !anchorOp.getAnchorLayout();
1890 target.addDynamicallyLegalOp<arith::ConstantOp>(
1891 [=](arith::ConstantOp op) ->
bool {
1893 if (!isa<VectorType>(op.getResult().getType()))
1899 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1900 [=](
Operation *op) -> std::optional<bool> {
1905 if (op->getNumResults() != 1)
1908 VectorType resultType =
1909 dyn_cast<VectorType>(op->getResult(0).getType());
1914 for (
Value operand : op->getOperands()) {
1915 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1916 if (!operandType || operandType.getShape() != resultType.getShape()) {
1924 target.addDynamicallyLegalOp<vector::ReductionOp>(
1925 [=](vector::ReductionOp op) ->
bool {
1930 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1931 [=](vector::MultiDimReductionOp op) ->
bool {
1932 return !isValidSubgroupMultiReductionOp(op);
1934 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1935 vector::TransposeOp, vector::BitCastOp,
1936 vector::ShapeCastOp, vector::StepOp,
1937 vector::BroadcastOp>([=](
Operation *op) ->
bool {
1940 target.addDynamicallyLegalOp<vector::ExtractOp>(
1941 [=](vector::ExtractOp op) ->
bool {
1942 if (!isa<VectorType>(op.getType()))
1946 target.addDynamicallyLegalOp<vector::InsertOp>(
1947 [=](vector::InsertOp op) ->
bool {
1950 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1951 [=](vector::ExtractStridedSliceOp op) ->
bool {
1954 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1955 [=](vector::InsertStridedSliceOp op) ->
bool {
1958 target.addDynamicallyLegalOp<vector::InterleaveOp, vector::DeinterleaveOp>(
1962 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
1964 SgToLaneCreateNdDesc, SgToLaneLoadNd, SgToLaneStoreNd, SgToLaneDpas,
1965 SgToLaneElementWise, SgToLaneArithConstant, SgToLanePrefetchNd,
1966 SgToLaneLoadGather, SgToLaneStoreScatter, SgToLaneVectorReduction,
1967 SgToLaneMultiDimReduction, SgToLaneVectorExtract, SgToLaneVectorInsert,
1968 SgToLaneVectorExtractStridedSlice, SgToLaneVectorInsertStridedSlice,
1969 SgToLaneLoadMatrix, SgToLaneStoreMatrix, SgToLaneConvertLayout,
1970 SgToLaneVectorTranspose, SgToLaneVectorBitcast, SgToLaneVectorStep,
1971 SgToLaneVectorShapeCast, SgToLaneBroadcast,
1972 SgToLaneCreateMask<vector::CreateMaskOp>,
1973 SgToLaneCreateMask<vector::ConstantMaskOp>, SgToLaneVectorDeinterleave,
1974 SgToLaneVectorInterleave, SgToLaneDpasMx>(typeConverter,
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
@ SubgroupMatrixMultiplyAcc
@ SubgroupScaledMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUSgToLaneDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, Operation *topLevelOp)
Defines type conversions and legality for XeGPU subgroup to lane distribution and appends the require...
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argumments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
DenseMap< Value, SmallVector< Type > > precomputeLoopBlockArgTypes(Operation *topLevelOp, SubShapeAndCountFn getSubShapeAndCount)
Pre-computes distributed VectorType mappings for every value carried through an SCF loop under topLev...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
void addVectorTypeConversion(TypeConverter &converter, SubShapeAndCountFn getSubShapeAndCount, DenseMap< Value, SmallVector< Type > > loopArgTypes)
Adds a context-aware VectorType conversion to converter (1:1 shape-changing or 1:N,...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUSgToLaneDistributeTypeConversions(TypeConverter &typeConverter, Operation *topLevelOp)
Define only the type conversions needed for XeGPU subgroup to lane distribution.
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
void cleanupUnrealizedConversionCasts(Operation *root, const llvm::SmallSetVector< UnrealizedConversionCastOp, 8 > &existingCasts)
Cleans up UnrealizedConversionCastOps inserted during SCF structural type conversion and/or XeGPU unr...
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
bool isSupportedInstruction(InstructionKind instr) const
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const