29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
36#define GEN_PASS_DEF_XEGPUSGTOLANEDISTRIBUTE
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
43#define DEBUG_TYPE "xegpu-sg-to-lane-distribute"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
/// Returns a value of type `expectedTy` derived from `v`, materializing a cast
/// op through `rewriter` when the types differ.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
// Nothing to do when the value already has the requested type.
52 if (v.getType() == expectedTy)
// A vector with the same element count only needs a vector.shape_cast.
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
// Otherwise bridge the two types with an unrealized_conversion_cast and hand
// back its first result.
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
62 return newOp.getResult(0);
/// Checks whether a vector.multi_reduction op can be handled by the
/// subgroup-to-lane distribution: its result layout must exist and be a
/// subgroup layout, and it must reduce over exactly one dimension.
68static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
// A missing or non-subgroup result layout disqualifies the op.
71 if (!resLayout || !resLayout.isForSubgroup())
// Scalar (int/float) result: only the single-reduction-dim requirement applies.
74 if (op.getType().isIntOrFloat())
75 return op.getReductionDims().size() == 1;
// Vector result: the distributed result type must be computable from the
// layout before the reduction-dim count is checked.
76 VectorType resTy = dyn_cast<VectorType>(op.getType());
80 FailureOr<VectorType> resDistTypeOrFailure =
81 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
82 if (failed(resDistTypeOrFailure))
84 return op.getReductionDims().size() == 1;
/// Returns true when the result vector type of `op` differs from the type
/// obtained by distributing it according to the result layout.
/// Precondition (asserted): `op` passes isValidSubgroupMultiReductionOp.
91static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
93 assert(isValidSubgroupMultiReductionOp(op) &&
"Expecting a valid subgroup "
94 "MultiDimReductionOp");
// NOTE(review): `resTy` is null when the op result is a scalar, and
// `.value()` below is taken without a failed() check; both rely on the
// asserted precondition and on callers filtering scalar results first —
// confirm.
96 VectorType resTy = dyn_cast<VectorType>(op.getType());
97 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
98 return resTy != resDistTypeOrFailure.value();
104 VectorType distributedType) {
// Both types must have the same rank so that dimensions can be compared
// index by index.
105 assert(originalType.getRank() == distributedType.getRank() &&
106 "original and distributed vector types must have the same rank");
// Collect every dimension index whose size differs between the original and
// the distributed type; the collected indices are returned in increasing
// order.
108 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
109 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
110 distributedDims.push_back(i);
112 return distributedDims;
/// Conversion pattern for xegpu.create_nd_desc: recreates the op with a
/// result tensor descriptor type from which the layouts have been dropped.
117struct SgToLaneCreateNdDesc
118 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
119 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
122 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter)
const override {
124 xegpu::TensorDescType resultType = op.getType();
// Descriptors without a layout are handled separately (presumably rejected as
// already converted — confirm).
126 if (!resultType.getLayout())
// Same location and original operands; only the result type changes
// (layouts removed via dropLayouts()).
129 auto newOp = xegpu::CreateNdDescOp::create(
130 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
132 rewriter.replaceOp(op, newOp.getResult());
/// Conversion pattern for xegpu.load_nd: recreates the load with a lane-level
/// result vector type and casts the loaded value to the vector type expected
/// from the lane layout.
140struct SgToLaneLoadNd :
public OpConversionPattern<xegpu::LoadNdOp> {
141 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
144 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
145 ConversionPatternRewriter &rewriter)
const override {
146 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// The layout on the tensor descriptor type must agree with the anchor layout
// of the op.
152 if (op.getTensorDescType().getLayout() != layout)
153 return rewriter.notifyMatchFailure(
154 op,
"conflicting layout attributes on tensor descriptor and anchor");
// Per the message, a target attribute has to be attached for this pattern to
// proceed.
157 return rewriter.notifyMatchFailure(
158 op,
"xegpu::LoadNdOp require target attribute attached to "
159 "determine transpose "
// Two lane-level types are computed: the one the new load is created with
// ("supported") and the one users of the result expect ("expected").
161 auto supportedLaneResultTyOrFailure =
163 auto expectedLaneResultTyOrFailure =
165 if (failed(supportedLaneResultTyOrFailure))
166 return rewriter.notifyMatchFailure(
167 op,
"unable to compute the lane vector type for LoadNdOp");
168 if (failed(expectedLaneResultTyOrFailure))
169 return rewriter.notifyMatchFailure(
170 op,
"unable to compute expected lane vector type from lane layout");
// Recreate the load on the converted tensor descriptor, forwarding offsets,
// packed/transpose attributes and cache hints. The trailing nullptr is
// presumably the (now dropped) layout attribute — confirm against the
// builder signature.
171 auto newOp = xegpu::LoadNdOp::create(
172 rewriter, op.getLoc(), supportedLaneResultTyOrFailure.value(),
173 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
174 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
175 op.getL3HintAttr(),
nullptr);
// castValueTo reconciles the supported type with the expected type.
181 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
182 expectedLaneResultTyOrFailure.value()));
/// Conversion pattern for xegpu.store_nd: recreates the store with a value of
/// the lane-level vector type and erases the original op.
190struct SgToLaneStoreNd :
public OpConversionPattern<xegpu::StoreNdOp> {
191 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
194 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter)
const override {
196 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// Both the tensor descriptor layout and the layout of the stored value must
// agree with the anchor layout of the op.
202 if (op.getTensorDescType().getLayout() != layout)
203 return rewriter.notifyMatchFailure(
204 op,
"conflicting layout attributes on tensor descriptor and anchor");
206 if (valueLayout != layout)
207 return rewriter.notifyMatchFailure(
208 op,
"conflicting layout attributes on value and anchor");
209 auto supportedLaneValueTyOrFailure =
211 if (failed(supportedLaneValueTyOrFailure))
212 return rewriter.notifyMatchFailure(
214 "unable to compute lane vector type for StoreNdOp value from tensor "
// Create the lane-level store; offsets and cache hints are forwarded from the
// original op. The trailing nullptr is presumably the dropped layout
// attribute — confirm against the builder signature.
217 xegpu::StoreNdOp::create(
218 rewriter, op.getLoc(),
220 supportedLaneValueTyOrFailure.value()),
221 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
222 op.getL2HintAttr(), op.getL3HintAttr(),
nullptr);
// The store has no results, so the original is erased rather than replaced.
223 rewriter.eraseOp(op);
/// Conversion pattern for xegpu.dpas: computes lane-level vector types for the
/// result and the A/B operands, optionally validates the operand bit widths
/// against the uArch instruction description, recreates the op with those
/// types and casts the result to the type expected from the layout.
231struct SgToLaneDpas :
public OpConversionPattern<xegpu::DpasOp> {
232 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
235 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): `cast<>` asserts on a null or mismatching attribute and never
// returns null, so the null check below looks unreachable. If absent layouts
// are meant to produce a match failure, `dyn_cast_if_present<>` is probably
// what was intended — confirm.
238 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
239 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
240 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
241 if (!layoutA || !layoutB || !layoutCd)
243 auto laneResultTyOrFailure =
245 auto laneATypeOrFailure =
247 auto laneBTypeOrFailure =
249 auto expectedLaneResultTyOrFailure =
251 if (failed(laneResultTyOrFailure) || failed(laneATypeOrFailure) ||
252 failed(laneBTypeOrFailure))
253 return rewriter.notifyMatchFailure(
254 op,
"failed to calculate supported lane vector types for DpasOp "
256 if (failed(expectedLaneResultTyOrFailure))
257 return rewriter.notifyMatchFailure(
258 op,
"unable to compute expected lane vector type for DpasOp from "
// The packed-format validation below only runs when a matrix-multiply
// instruction description could be obtained.
264 const auto *uArchInstruction =
265 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
268 if (uArchInstruction) {
269 auto laneAType = laneATypeOrFailure.value();
270 auto laneBType = laneBTypeOrFailure.value();
// Total number of bits held per lane for each operand
// (element bit width * element count).
272 unsigned aPackedBitWidth =
273 laneAType.getElementTypeBitWidth() * laneAType.getNumElements();
274 unsigned bPackedBitWidth =
275 laneBType.getElementTypeBitWidth() * laneBType.getNumElements();
276 unsigned expectedABitSize = uArchInstruction->getPackedFormatBitSizeA();
277 unsigned expectedBBitSize = uArchInstruction->getPackedFormatBitSizeB();
// Each operand's per-lane bit count must be a whole multiple of the packed
// format size reported by the uArch description.
279 if (aPackedBitWidth % expectedABitSize != 0)
280 return rewriter.notifyMatchFailure(
282 "A operand packed bit width must be a multiple of uArch packed "
283 "format requirement");
284 if (bPackedBitWidth % expectedBBitSize != 0)
285 return rewriter.notifyMatchFailure(
287 "B operand packed bit width must be a multiple of uArch packed "
288 "format requirement");
// Recreate dpas at lane level, then cast its result to the expected lane type.
292 auto newOp = xegpu::DpasOp::create(
293 rewriter, op->getLoc(), laneResultTyOrFailure.value(),
295 laneATypeOrFailure.value()),
297 laneBTypeOrFailure.value()),
299 laneResultTyOrFailure.value()),
303 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
304 expectedLaneResultTyOrFailure.value()));
/// Generic conversion pattern (registered for any op type, benefit 1) that
/// rebuilds an op with a lane-level vector result type derived from the
/// subgroup layout of its result.
311struct SgToLaneElementWise :
public ConversionPattern {
313 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
317 ConversionPatternRewriter &rewriter)
const override {
324 return rewriter.notifyMatchFailure(
325 op,
"operation result is not a vector type");
327 xegpu::DistributeLayoutAttr layout =
// Only results carrying a subgroup-level distribute layout are rewritten.
329 if (!layout || !layout.isForSubgroup())
330 return rewriter.notifyMatchFailure(
331 op,
"operation result does not have subgroup distribute layout");
333 auto laneShapeOrFailure =
336 if (failed(laneShapeOrFailure))
337 return rewriter.notifyMatchFailure(
338 op,
"unable to compute lane vector type from the layout");
340 VectorType newResultType = laneShapeOrFailure.value();
// Build the replacement through an OperationState: the given operands, the
// lane-level result type, and every attribute whose value is not a
// DistributeLayoutAttr (layout attributes are dropped).
342 state.addOperands(operands);
343 state.addTypes(newResultType);
346 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
347 state.addAttribute(attr.getName(), attr.getValue());
349 Operation *newOp = rewriter.create(state);
// Only result #0 of the new op is used as the replacement.
351 rewriter.replaceOp(op, newOp->
getResult(0));
/// Conversion pattern for arith.constant with a splat vector value: recreates
/// the constant with the lane-level vector type computed from the subgroup
/// layout of the result.
358struct SgToLaneArithConstant :
public OpConversionPattern<arith::ConstantOp> {
359 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
362 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 auto resultType = dyn_cast<VectorType>(op.getType());
// Non-splat constants are rejected (see the failure message).
369 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
371 return rewriter.notifyMatchFailure(
372 op,
"only dense splat vector constants are supported");
374 xegpu::DistributeLayoutAttr layout =
376 if (!layout || !layout.isForSubgroup())
377 return rewriter.notifyMatchFailure(
378 op,
"operation result does not have subgroup distribute layout");
380 auto laneShapeOrFailure =
383 if (
failed(laneShapeOrFailure))
384 return rewriter.notifyMatchFailure(
385 op,
"unable to compute lane vector type from the layout");
387 VectorType newResultType = laneShapeOrFailure.value();
// The single splat element is read from the original constant so that a new
// constant of the lane-level type can be created from it.
388 auto sclarValue = dense.getSplatValue<Attribute>();
391 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
393 rewriter.replaceOp(op, newOp.
getResult());
/// Conversion pattern for xegpu.prefetch_nd: recreates the prefetch on the
/// converted tensor descriptor and erases the original op.
399struct SgToLanePrefetchNd :
public OpConversionPattern<xegpu::PrefetchNdOp> {
400 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
403 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter)
const override {
405 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// Offsets and L1/L2/L3 cache hints are forwarded unchanged from the original.
410 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
411 op.getMixedOffsets(), op.getL1HintAttr(),
412 op.getL2HintAttr(), op.getL3HintAttr(),
// Prefetch has no results, so the original is erased rather than replaced.
414 rewriter.eraseOp(op);
/// Conversion pattern for xegpu.load (gather): recreates the op on 1-D
/// lane-level vectors (result, offsets, mask) and, when the distributed result
/// type is not 1-D, converts the loaded value back to it.
452struct SgToLaneLoadGather :
public OpConversionPattern<xegpu::LoadGatherOp> {
453 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
456 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter)
const override {
458 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
462 VectorType origResultTy = op.getValueType();
// With a chunk size of 1 (also the default when none is set) the payload is
// treated as rank 1, otherwise as rank 2. Every dimension in front of that
// effective payload must be a unit dimension.
467 int chunkSize = op.getChunkSize().value_or(1);
468 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
469 ArrayRef<int64_t> shape = origResultTy.getShape();
471 shape.take_front(origResultTy.getRank() - effectiveVecRank),
472 [](int64_t d) { return d != 1; }))
473 return rewriter.notifyMatchFailure(
474 op,
"Only unit dimensions allowed for the leading "
475 "dimensions of the load vector!");
477 auto distResultTyOrFailure =
479 if (
failed(distResultTyOrFailure))
480 return rewriter.notifyMatchFailure(
481 op,
"unable to compute expected lane vector type from lane layout");
// Flattened (1-D) variant of the distributed result type with the same
// element count and element type.
483 VectorType distResultTy = distResultTyOrFailure.value();
484 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
485 distResultTy.getElementType());
// Offsets and mask are flattened to 1-D vectors in the same way.
488 Value distOffsets = adaptor.getOffsets();
489 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
490 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
491 distOffsetsTy.getElementType());
492 distOffsets = castValueTo(
495 Value distMask = adaptor.getMask();
496 auto distMaskTy = cast<VectorType>(distMask.
getType());
497 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
498 distMaskTy.getElementType());
// Chunk size and cache hints are forwarded; the trailing nullptr is
// presumably the dropped layout attribute — confirm against the builder.
502 Value distSource = adaptor.getSource();
503 auto newOp = xegpu::LoadGatherOp::create(
504 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
505 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
506 op.getL3HintAttr(),
nullptr);
// Extra handling is only needed when the distributed type is not already 1-D.
509 if (distResultTy1D != distResultTy)
512 rewriter.replaceOp(op,
result);
/// Conversion pattern for vector.reduction on a rank-1 source vector with a
/// subgroup layout: reduces the lane-level value to a single result and
/// folds in the accumulator operand when one is present.
521struct SgToLaneVectorReduction
522 :
public OpConversionPattern<vector::ReductionOp> {
523 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
526 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
527 ConversionPatternRewriter &rewriter)
const override {
531 if (!layout || !layout.isForSubgroup())
534 VectorType srcVecType = op.getSourceVectorType();
536 if (srcVecType.getRank() != 1)
537 return rewriter.notifyMatchFailure(
538 op,
"Only rank 1 reductions can be distributed.");
540 if (layout.getRank() != srcVecType.getRank())
541 return rewriter.notifyMatchFailure(
542 op,
"Layout rank does not match vector rank.");
// The subgroup size used below is the first (and, given the rank checks
// above, only) entry of the effective lane layout.
545 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
// NOTE(review): the message names xegpu::ReductionOp although this pattern
// matches vector::ReductionOp — confirm the wording.
548 return rewriter.notifyMatchFailure(
549 op,
"xegpu::ReductionOp require target attribute attached to "
550 "determine subgroup size");
// The source length must be divisible by the subgroup size.
554 srcVecType.getShape()[0] % sgSize != 0)
555 return rewriter.notifyMatchFailure(op,
556 "Invalid layout or reduction vector "
557 "dimension must match subgroup size.");
559 if (!op.getType().isIntOrFloat())
560 return rewriter.notifyMatchFailure(
561 op,
"Reduction distribution currently only supports floats and "
// Reduce the converted (lane-level) vector operand using the op's combining
// kind and the subgroup size.
565 Value laneValVec = adaptor.getVector();
569 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
// The accumulator is optional; when present it is combined using the same
// combining kind.
572 if (adaptor.getAcc())
574 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
576 rewriter.replaceOp(op, fullReduce);
/// Conversion pattern for vector.multi_reduction with a single reduction
/// dimension. Three cases are distinguished: a scalar result, a reduction for
/// which isReductionLaneLocal() holds, and the remaining case.
585struct SgToLaneMultiDimReduction
586 :
public OpConversionPattern<vector::MultiDimReductionOp> {
587 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
590 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter)
const override {
593 ArrayRef<int64_t> reductionDims = op.getReductionDims();
594 assert(reductionDims.size() == 1 &&
595 "Expecting single reduction dimension for subgroup multi "
598 VectorType sourceType = op.getSourceVectorType();
599 int64_t rank = sourceType.getRank();
// All dimensions in front of the last two must be unit dimensions.
601 ArrayRef<int64_t> shape = sourceType.getShape();
602 if (llvm::any_of(shape.take_front(rank - 2),
603 [](int64_t d) { return d != 1; }))
604 return rewriter.notifyMatchFailure(
605 op,
"only unit leading dimensions are supported for "
606 "multi_reduction with rank > 2");
// Case 1: scalar result. The reduction uses the op's combining kind and the
// size of the reduced dimension; the optional accumulator is folded in.
610 if (op.getType().isIntOrFloat()) {
611 auto reductionDim = reductionDims[0];
612 VectorType origSourceType = op.getSourceVectorType();
613 int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
617 op.getKind(), reductionDimSize);
619 if (adaptor.getAcc())
621 result, adaptor.getAcc());
622 }
// Case 2: vector result for which isReductionLaneLocal() returns true.
else if (isReductionLaneLocal(op)) {
626 auto reductionDim = reductionDims[0];
630 reductionDim, op.getLoc(), rewriter);
// Remaining case (presumably the final else branch — confirm): the size of
// the reduced dimension is passed along in addition to its index.
632 auto reductionDim = reductionDims[0];
633 VectorType sourceType = op.getSourceVectorType();
634 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
638 reductionDim, reductionDimSize, op.getLoc(), rewriter);
640 rewriter.replaceOp(op,
result);
649 ConversionPatternRewriter &rewriter,
Location loc,
// The lane id (index type, no upper bound attribute) is the input from which
// the layout derives the per-lane coordinates for the payload shape.
652 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
653 mlir::IntegerAttr());
655 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
// Exactly one coordinate set is expected for a matrix op payload.
658 assert(maybeCoords.value().size() == 1 &&
659 "Expected one set of distributed offsets");
// The coordinates are returned as plain Values.
663 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
/// Conversion pattern for xegpu.load_matrix: recreates the op with the
/// distributed payload type and, unless the subgroup-block-io attribute is
/// set, with per-lane coordinates computed from the layout.
667struct SgToLaneLoadMatrix :
public OpConversionPattern<xegpu::LoadMatrixOp> {
668 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
671 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
672 ConversionPatternRewriter &rewriter)
const override {
673 auto layout = op.getLayoutAttr();
678 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
680 return rewriter.notifyMatchFailure(
681 op,
"the matrix op payload must be a vector type");
683 auto loc = op.getLoc();
684 auto offsets = op.getMixedOffsets();
686 return rewriter.notifyMatchFailure(op,
"the load op must have offsets");
688 FailureOr<VectorType> distPayloadTyOrFailure =
690 if (
failed(distPayloadTyOrFailure))
691 return rewriter.notifyMatchFailure(
692 op,
"Failed to distribute matrix op payload based on layout.");
694 SmallVector<Value> offsetsAsValues =
// With subgroup block IO the original offsets are kept as they are; otherwise
// they are replaced by coordinates computed from the layout.
697 SmallVector<Value> newCoords = offsetsAsValues;
698 if (!op.getSubgroupBlockIoAttr()) {
699 newCoords = computeDistributedCoordsForMatrixOp(
700 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
701 if (newCoords.empty())
702 return rewriter.notifyMatchFailure(
703 op,
"Failed to compute distributed coordinates.");
// All offsets are passed as SSA values, so every constant-offset slot is
// marked dynamic.
706 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
707 ShapedType::kDynamic);
709 rewriter.getDenseI64ArrayAttr(newConstOffsets);
// The new op is created with an empty layout attribute.
711 auto newOp = xegpu::LoadMatrixOp::create(
712 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
713 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
714 xegpu::DistributeLayoutAttr{});
715 rewriter.replaceOp(op, newOp.
getResult());
/// Conversion pattern for vector.transpose: applies the same permutation to
/// the converted source and casts the result to the distributed result type.
721struct SgToLaneVectorTranspose
722 :
public OpConversionPattern<vector::TransposeOp> {
723 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
726 matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter)
const override {
728 xegpu::DistributeLayoutAttr sourceLayout =
730 xegpu::DistributeLayoutAttr resultLayout =
732 if (!sourceLayout || !resultLayout)
733 return rewriter.notifyMatchFailure(
734 op,
"the source or result vector of the transpose op lacks layout "
736 ArrayRef<int64_t> perm = op.getPermutation();
// At lane level the result layout must be the transpose of the source layout
// under the op's permutation.
738 if (!resultLayout.isTransposeOf(sourceLayout, perm,
739 xegpu::LayoutKind::Lane))
740 return rewriter.notifyMatchFailure(
741 op,
"the source or result vector layouts must be transposes of "
743 FailureOr<VectorType> distributedResultTypeOrFailure =
745 if (
failed(distributedResultTypeOrFailure))
746 return rewriter.notifyMatchFailure(
747 op,
"Failed to distribute the result vector type in "
748 "vector::Transpose op");
// Transpose the converted source; castValueTo reconciles the resulting type
// with the distributed result type.
749 auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
750 adaptor.getVector(), perm);
751 rewriter.replaceOp(op, castValueTo(rewriter, newOp.
getResult(),
752 distributedResultTypeOrFailure.value()));
/// Conversion pattern for vector.bitcast: recreates the bitcast on the
/// converted source with the distributed result vector type.
759struct SgToLaneVectorBitcast :
public OpConversionPattern<vector::BitCastOp> {
760 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
763 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
764 ConversionPatternRewriter &rewriter)
const override {
765 xegpu::DistributeLayoutAttr resultLayout =
768 return rewriter.notifyMatchFailure(
769 op,
"result vector of the bitcast op lacks layout attribute");
770 FailureOr<VectorType> distributedResultTypeOrFailure =
772 if (
failed(distributedResultTypeOrFailure))
773 return rewriter.notifyMatchFailure(
774 op,
"Failed to distribute the result vector type in "
775 "vector::BitCast op");
// The source operand comes from the adaptor, i.e. it is the already
// converted value.
776 auto newOp = vector::BitCastOp::create(
777 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
778 adaptor.getSource());
779 rewriter.replaceOp(op, newOp.
getResult());
/// Conversion pattern for vector.create_mask and vector.constant_mask (the
/// template is restricted to these two op types). Each element of the
/// lane-level mask is computed by comparing that element's distributed
/// coordinates against the mask bounds.
807template <
typename OpType,
808 typename = std::enable_if_t<llvm::is_one_of<
809 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
810struct SgToLaneCreateMask :
public OpConversionPattern<OpType> {
811 using OpConversionPattern<OpType>::OpConversionPattern;
814 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
815 ConversionPatternRewriter &rewriter)
const override {
816 xegpu::DistributeLayoutAttr layout =
818 if (!layout || !layout.isForSubgroup())
819 return rewriter.notifyMatchFailure(
820 op,
"operation result does not have subgroup distribute layout");
822 VectorType origType = op.getType();
823 FailureOr<VectorType> distTypeOrFailure =
825 if (
failed(distTypeOrFailure))
826 return rewriter.notifyMatchFailure(
827 op,
"unable to compute lane vector type from the layout");
829 VectorType distType = distTypeOrFailure.value();
830 Location loc = op.getLoc();
// Gather one bound per dimension: the operands of create_mask, or values
// derived from the static mask dim sizes of constant_mask.
833 SmallVector<Value> origBounds;
834 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
835 origBounds.append(op.getOperands().begin(), op.getOperands().end());
837 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
838 for (
auto dimSize : dimSizes)
839 origBounds.push_back(
843 ArrayRef<int64_t> origShape = origType.getShape();
// Coordinates are computed from the lane id and the original (subgroup-level)
// mask shape.
846 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
847 mlir::IntegerAttr());
848 auto maybeCoordsVec =
849 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
850 if (
failed(maybeCoordsVec))
851 return rewriter.notifyMatchFailure(
852 op,
"failed to compute distributed coordinates from layout");
// One coordinate set is expected per element of the distributed mask.
854 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
855 int64_t numElements = distType.getNumElements();
856 assert(
static_cast<int64_t
>(coordsVec.size()) == numElements &&
857 "number of coordinate sets must match number of distributed "
// An element is set only if, in every dimension, its coordinate is strictly
// less than the bound (signed comparison); the per-dimension results are
// AND-ed together starting from `trueVal`.
863 SmallVector<Value> maskBits;
864 for (
auto &coords : coordsVec) {
865 Value inBounds = trueVal;
866 for (
size_t i = 0; i < coords.size(); ++i) {
867 Value cmp = arith::CmpIOp::create(
868 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
869 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
871 maskBits.push_back(inBounds);
// A single-element mask is built with vector.broadcast of the one bit;
// otherwise the bits are assembled with vector.from_elements.
876 if (numElements == 1) {
878 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
881 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
883 rewriter.replaceOp(op,
result);
/// Conversion pattern for xegpu.store_matrix: recreates the op with the
/// distributed payload type and, unless the subgroup-block-io attribute is
/// set, with per-lane coordinates computed from the layout; the original op
/// is erased.
889struct SgToLaneStoreMatrix :
public OpConversionPattern<xegpu::StoreMatrixOp> {
890 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
893 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
894 ConversionPatternRewriter &rewriter)
const override {
895 auto layout = op.getLayoutAttr();
900 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
902 return rewriter.notifyMatchFailure(
903 op,
"the matrix op payload must be a vector type");
905 auto loc = op.getLoc();
906 auto offsets = op.getMixedOffsets();
908 return rewriter.notifyMatchFailure(op,
"the store op must have offsets");
910 FailureOr<VectorType> distPayloadTyOrFailure =
912 if (
failed(distPayloadTyOrFailure))
913 return rewriter.notifyMatchFailure(
914 op,
"Failed to distribute matrix op payload based on layout.");
916 SmallVector<Value> offsetsAsValues =
// With subgroup block IO the original offsets are kept as they are; otherwise
// they are replaced by coordinates computed from the layout.
919 SmallVector<Value> newCoords = offsetsAsValues;
920 if (!op.getSubgroupBlockIoAttr()) {
921 newCoords = computeDistributedCoordsForMatrixOp(
922 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
923 if (newCoords.empty())
924 return rewriter.notifyMatchFailure(
925 op,
"Failed to compute distributed coordinates.");
// All offsets are passed as SSA values, so every constant-offset slot is
// marked dynamic.
928 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
929 ShapedType::kDynamic);
931 rewriter.getDenseI64ArrayAttr(newConstOffsets);
// The new op is created with an empty layout attribute.
933 xegpu::StoreMatrixOp::create(
936 distPayloadTyOrFailure.value()),
937 adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr,
938 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
// The store has no results, so the original is erased rather than replaced.
939 rewriter.eraseOp(op);
/// Conversion pattern for xegpu.store (scatter): recreates the op on 1-D
/// lane-level vectors (value, offsets, mask) and erases the original op.
978struct SgToLaneStoreScatter
979 :
public OpConversionPattern<xegpu::StoreScatterOp> {
980 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
983 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter)
const override {
985 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
989 VectorType origValueTy = op.getValueType();
// With a chunk size of 1 (also the default when none is set) the payload is
// treated as rank 1, otherwise as rank 2. Every dimension in front of that
// effective payload must be a unit dimension.
994 int chunkSize = op.getChunkSize().value_or(1);
995 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
996 ArrayRef<int64_t> shape = origValueTy.getShape();
997 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
998 [](int64_t d) { return d != 1; }))
999 return rewriter.notifyMatchFailure(
1000 op,
"Only unit dimensions allowed for the leading "
1001 "dimensions of the store vector!");
1003 auto distValueTyOrFailure =
1005 if (
failed(distValueTyOrFailure))
1006 return rewriter.notifyMatchFailure(
1007 op,
"unable to compute expected lane vector type from lane layout");
// Flattened (1-D) variant of the distributed value type with the same element
// count and element type.
1009 VectorType distValueTy = distValueTyOrFailure.value();
1010 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
1011 distValueTy.getElementType());
// The converted value only needs extra handling when it is not already of
// the 1-D type.
1013 Value distValue = adaptor.getValue();
1014 if (distValue.
getType() != distValueTy1D)
// Offsets and mask are flattened to 1-D vectors in the same way.
1019 Value distOffsets = adaptor.getOffsets();
1020 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
1021 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
1022 distOffsetsTy.getElementType());
1023 distOffsets = castValueTo(
1026 Value distMask = adaptor.getMask();
1027 auto distMaskTy = cast<VectorType>(distMask.
getType());
1028 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
1029 distMaskTy.getElementType());
// Chunk size and cache hints are forwarded; the trailing nullptr is
// presumably the dropped layout attribute — confirm against the builder.
1033 Value distDest = adaptor.getDest();
1034 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1035 distOffsets, distMask, op.getChunkSizeAttr(),
1036 op.getL1HintAttr(), op.getL2HintAttr(),
1037 op.getL3HintAttr(),
nullptr);
// The store has no results, so the original is erased rather than replaced.
1038 rewriter.eraseOp(op);
/// Conversion pattern for vector.step: builds the lane-level step vector
/// explicitly from the coordinates that the result layout assigns to the
/// current lane, assembling it with vector.from_elements.
1047struct SgToLaneVectorStep :
public OpConversionPattern<vector::StepOp> {
1048 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1051 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1052 ConversionPatternRewriter &rewriter)
const override {
1053 xegpu::DistributeLayoutAttr resultLayout =
1055 if (!resultLayout || !resultLayout.isForSubgroup())
1056 return rewriter.notifyMatchFailure(
1057 op,
"the result vector of the step op lacks subgroup layout");
1059 auto loc = op.getLoc();
1060 auto stepResultVecTy = op.getResult().getType();
1061 auto laneShapeOrFailure =
1063 if (
failed(laneShapeOrFailure))
1064 return rewriter.notifyMatchFailure(
1065 op,
"unable to compute lane vector type from the layout");
1066 VectorType newVecTy = laneShapeOrFailure.value();
// Coordinates are computed from the lane id and the original result shape;
// each returned coordinate set is treated as the start of one lane-data
// block.
1068 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1069 mlir::IntegerAttr());
1070 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1071 rewriter, loc, laneId, stepResultVecTy.getShape());
1072 if (
failed(laneDataBlockCoords))
1073 return rewriter.notifyMatchFailure(
1074 op,
"failed to compute lane data block coordinates");
// The number of blocks must equal the lane-level element count divided by
// the block length (the first entry of the effective lane data).
1076 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
1077 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
1078 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
1079 newVecTy.getNumElements() / laneDataBlockLength);
1080 SmallVector<Value> stepVals;
// For every block: emit the start coordinate, followed by one value per
// remaining position in the block, each formed by adding an offset to the
// start coordinate.
1088 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1089 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1090 stepVals.push_back(laneDataBlockStartCoord);
1091 for (
int i = 1; i < laneDataBlockLength; ++i) {
1093 stepVals.push_back(arith::AddIOp::create(
1094 rewriter, loc, laneDataBlockStartCoord, offset));
1097 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
1098 "Expecting the number of step values to match the number of "
1099 "elements in the vector");
1101 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1102 rewriter.replaceOp(op, stepOpVal);
/// Conversion pattern for vector.extract with a vector result: recreates the
/// extract on the converted source at the same position. Only layouts that
/// distribute along the innermost dimension are accepted.
1109struct SgToLaneVectorExtract :
public OpConversionPattern<vector::ExtractOp> {
1110 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1113 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter)
const override {
1116 auto resultType = dyn_cast<VectorType>(op.getType());
1118 return rewriter.notifyMatchFailure(op,
"scalar extract not supported");
1120 xegpu::DistributeLayoutAttr layout =
1122 if (!layout || !layout.isForSubgroup())
// Every lane-layout entry except the last one must be 1, i.e. no dimension
// other than the innermost may be split across lanes.
1127 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1128 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1129 [](int64_t v) {
return v != 1; }))
1130 return rewriter.notifyMatchFailure(
1131 op,
"only innermost dimension distribution is supported for "
// The extraction position is reused unchanged on the converted source.
1134 auto newOp = vector::ExtractOp::create(
1135 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1136 rewriter.replaceOp(op, newOp.
getResult());
/// Conversion pattern for vector.shape_cast: recreates the shape cast from
/// the converted source to the distributed result vector type.
1142struct SgToLaneVectorShapeCast
1143 :
public OpConversionPattern<vector::ShapeCastOp> {
1144 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1147 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter)
const override {
1149 xegpu::DistributeLayoutAttr resultLayout =
1151 if (!resultLayout || !resultLayout.isForSubgroup())
1152 return rewriter.notifyMatchFailure(
1153 op,
"the result vector of the shape_cast op lacks subgroup layout");
// The distributed result type is derived from the result layout and the
// original result vector type.
1156 resultLayout, op.getResultVectorType());
1157 if (
failed(resultDistTypeOrFailure))
1158 return rewriter.notifyMatchFailure(
1159 op,
"failed to get distributed vector type for result");
1161 Value source = adaptor.getSource();
1162 auto newShapeCast = vector::ShapeCastOp::create(
1163 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1164 rewriter.replaceOp(op, newShapeCast);
/// Conversion pattern for vector.extract_strided_slice: recreates the op on
/// the converted source with the distributed result type, adjusting the size
/// and offset along the (at most one) distributed dimension.
1172struct SgToLaneVectorExtractStridedSlice
1173 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
1174 using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
1177 matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
1178 ConversionPatternRewriter &rewriter)
const override {
1179 xegpu::DistributeLayoutAttr resultLayout =
1181 if (!resultLayout || !resultLayout.isForSubgroup())
1184 VectorType resultType = op.getType();
1185 auto distResultTyOrFailure =
1187 if (
failed(distResultTyOrFailure))
1188 return rewriter.notifyMatchFailure(
1189 op,
"unable to compute distributed vector type from lane layout");
1190 VectorType distResultTy = *distResultTyOrFailure;
// Dimensions whose size changes between the original and the distributed
// result type.
1192 SmallVector<int64_t> distributedDims =
1193 getDistributedDims(resultType, distResultTy);
// The op may specify sizes/offsets/strides for only a leading subset of the
// source dimensions. Pad all three lists up to the source rank with the full
// dimension size, offset 0 and stride 1 so they can be indexed by dimension.
1196 int64_t sourceRank = op.getSourceVectorType().getRank();
1197 SmallVector<Attribute> updatedSizes =
1198 llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
1199 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1200 op.getOffsets(), [](Attribute attr) { return attr; });
1201 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1202 op.getStrides(), [](Attribute attr) { return attr; });
1203 for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
1204 updatedSizes.push_back(
1205 rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
1206 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1207 updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
// When nothing is distributed the padded attributes are used as they are.
1212 if (!distributedDims.empty()) {
1213 if (distributedDims.size() != 1)
1214 return rewriter.notifyMatchFailure(
1215 op,
"only single dimension distribution is supported");
1216 int64_t distDim = distributedDims[0];
1219 return rewriter.notifyMatchFailure(
1220 op,
"target attribute required to determine subgroup size");
1223 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1224 return rewriter.notifyMatchFailure(
1225 op,
"source of extract_strided_slice lacks distribution layout");
// NOTE(review): the int64_t dimension size is narrowed to int here — confirm
// that this is intended.
1226 int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
1227 if (sourceDistrDimSize % subgroupSize != 0)
1228 return rewriter.notifyMatchFailure(
1229 op,
"source size along distributed dim is not a multiple of "
// Lane data along the distributed dimension has to be 1.
1231 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1234 if (distDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1235 sourceLaneData[distDim] != 1)
1236 return rewriter.notifyMatchFailure(
1237 op,
"expecting unit lane data along the distributed dimension");
// The offset along the distributed dimension must be a multiple of the
// subgroup size so that it can be divided exactly below.
1238 int64_t distrDimOffset =
1239 cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
1240 if (distrDimOffset % subgroupSize != 0)
1241 return rewriter.notifyMatchFailure(
1242 op,
"offset along distributed dim is not a multiple of "
// Per-lane view: the size becomes the distributed result size and the offset
// is scaled down by the subgroup size.
1245 updatedSizes[distDim] =
1246 rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
1247 updatedOffsets[distDim] =
1248 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1251 auto newOp = vector::ExtractStridedSliceOp::create(
1252 rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
1253 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1254 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1255 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1256 rewriter.replaceOp(op, newOp.
getResult());
/// Conversion pattern for vector.broadcast: validates the source layout
/// against the result layout and recreates the broadcast with the distributed
/// result type. When the converted source already has that type it is used
/// directly.
1318struct SgToLaneBroadcast :
public OpConversionPattern<vector::BroadcastOp> {
1319 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
1322 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
1323 ConversionPatternRewriter &rewriter)
const override {
1324 xegpu::DistributeLayoutAttr resultLayout =
1326 if (!resultLayout || !resultLayout.isForSubgroup())
1327 return rewriter.notifyMatchFailure(
1328 op,
"result does not have subgroup distribute layout");
1330 VectorType destType = op.getResultVectorType();
// Null when the broadcast source is not a vector (e.g. a scalar).
1331 VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
1333 xegpu::DistributeLayoutAttr sourceLayout =
1337 int64_t rankDiff = destType.getRank() - sourceType.getRank();
// Here the source layout has to be a slice of the result layout (presumably
// the rank-increasing case — confirm).
1340 if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
1342 "broadcast source layout must be a slice of result layout");
1343 }
// Same-rank broadcast: only unit dimensions are expanded. The source layout
// is rewritten for those dimensions via setUnitDimLayout.
else if (rankDiff == 0) {
1345 auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
1346 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1347 broadcastUnitDimsSet.end());
1348 assert(sourceLayout.isEqualTo(
1349 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1350 "The sg_data for unit dimensions should be set as 1");
1351 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
// A scalar source is expected to carry no layout (see the failure message).
1356 return rewriter.notifyMatchFailure(
1357 op,
"broadcast from scalar must not have a layout attribute");
1362 if (
failed(destDistType))
1363 return rewriter.notifyMatchFailure(
1364 op,
"failed to distribute the result vector type");
1366 Value source = adaptor.getSource();
// The converted source already has the distributed result type, so it
// replaces the op directly.
1368 if (source.
getType() == destDistType.value()) {
1369 rewriter.replaceOp(op, source);
1373 auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
1374 destDistType.value(), source);
1375 rewriter.replaceOp(op, newOp);
// Conversion pattern for vector::InsertStridedSliceOp.
// Visible behavior: requires a result layout with isForSubgroup(); obtains the
// distributed destination type; finds which destination dims changed size
// under distribution (getDistributedDims); and, when exactly one dim is
// distributed, validates layouts/sizes/offsets along it and divides the offset
// along that dim by subgroupSize. The op is then recreated with the
// distributed destination type, the updated offsets and the original strides.
// NOTE(review): the embedded line numbers in this listing have gaps, so some
// statements (e.g. how subgroupSize and the layouts are obtained) are not shown.
1383struct SgToLaneVectorInsertStridedSlice
1384 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
1385 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1388 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1389 ConversionPatternRewriter &rewriter)
const override {
1390 xegpu::DistributeLayoutAttr resultLayout =
1392 if (!resultLayout || !resultLayout.isForSubgroup())
1395 VectorType destType = op.getDestVectorType();
1396 auto distDestTyOrFailure =
1398 if (
failed(distDestTyOrFailure))
1399 return rewriter.notifyMatchFailure(
1400 op,
"unable to compute distributed vector type from lane layout");
1401 VectorType distDestTy = *distDestTyOrFailure;
// Dims whose size differs between the original and distributed dest type.
1403 SmallVector<int64_t> destDistributedDims =
1404 getDistributedDims(destType, distDestTy);
// Start from a copy of the original offsets; only the distributed dim's entry
// is rewritten below.
1406 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1407 op.getOffsets(), [](Attribute attr) { return attr; });
// When no dim is distributed the offsets are used unchanged.
1409 if (!destDistributedDims.empty()) {
1410 if (destDistributedDims.size() != 1)
1411 return rewriter.notifyMatchFailure(
1412 op,
"only single dimension distribution is supported");
1413 int64_t destDistDim = destDistributedDims[0];
1417 return rewriter.notifyMatchFailure(
1418 op,
"target attribute required to determine subgroup size");
1421 VectorType srcType = op.getSourceVectorType();
// Translate the dest dim index into the (possibly lower-rank) source by
// subtracting the rank difference; a negative index means the distributed dim
// is not one of the trailing dest dims covered by the source.
1423 int64_t sourceDistDim =
1424 destDistDim - (destType.getRank() - srcType.getRank());
1425 if (sourceDistDim < 0)
1426 return rewriter.notifyMatchFailure(
1427 op,
"distributed dimension must be in the last k dims of dest");
1431 if (!destLayout || !sourceLayout ||
1432 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1433 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1434 return rewriter.notifyMatchFailure(
1435 op,
"source or dest of insert_strided_slice lacks distribution "
1438 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1439 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
// Lane data along the distributed dim must be 1 for both dest and source; the
// check is skipped for an index outside the lane-data array.
1442 if ((destDistDim <
static_cast<int64_t
>(destLaneData.size()) &&
1443 destLaneData[destDistDim] != 1) ||
1444 (sourceDistDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1445 sourceLaneData[sourceDistDim] != 1))
1446 return rewriter.notifyMatchFailure(
1447 op,
"expecting unit lane data along the distributed dimension");
// The source extent along the distributed dim must divide evenly by
// subgroupSize.
1449 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1450 if (srcDistrDimSize % subgroupSize != 0)
1451 return rewriter.notifyMatchFailure(
1452 op,
"source distributed dim size is not a multiple of "
// Likewise the insertion offset along the distributed dim; it is then divided
// by subgroupSize to form the offset used in the new op.
1455 int64_t destDistrDimOffset =
1456 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1457 if (destDistrDimOffset % subgroupSize != 0)
1458 return rewriter.notifyMatchFailure(
1459 op,
"offset along distributed dim is not a multiple of "
1462 updatedOffsets[destDistDim] =
1463 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
// Recreate the op on the remapped operands with the distributed dest type.
1466 auto newOp = vector::InsertStridedSliceOp::create(
1467 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1469 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1470 rewriter.replaceOp(op, newOp.
getResult());
// Conversion pattern for vector::InsertOp.
// Visible behavior: the inserted value's type is dyn_cast'ed to VectorType
// (a "scalar insert not supported" failure exists for the other case); the
// layout must exist and be isForSubgroup(); every lane-layout entry except the
// last must be 1; the op is then recreated from the remapped value/dest
// operands with the original insert position.
// NOTE(review): the embedded line numbers in this listing have gaps, so some
// statements (conditions, returns, closing braces) are not shown.
1477struct SgToLaneVectorInsert :
public OpConversionPattern<vector::InsertOp> {
1478 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1481 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1482 ConversionPatternRewriter &rewriter)
const override {
1484 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1486 return rewriter.notifyMatchFailure(op,
"scalar insert not supported");
1488 xegpu::DistributeLayoutAttr layout =
1490 if (!layout || !layout.isForSubgroup())
// drop_back(1) excludes the innermost entry, so this rejects any lane layout
// that has a non-1 entry in an outer dimension.
1495 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1496 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1497 [](int64_t v) {
return v != 1; }))
1498 return rewriter.notifyMatchFailure(
1499 op,
"only innermost dimension distribution is supported for "
// Position is taken from the original op; operands come from the adaptor.
1502 auto newOp = vector::InsertOp::create(
1503 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1504 op.getMixedPosition());
1505 rewriter.replaceOp(op, newOp.
getResult());
// Conversion pattern for xegpu::ConvertLayoutOp.
// Visible behavior: for a vector result, the input layout must be compatible
// with the target layout (LayoutKind::Lane) for the result shape, otherwise the
// match fails; on success the op is replaced by its remapped source, i.e. no
// new op is created. An earlier path replaces the op with its original source.
// NOTE(review): the condition guarding that earlier path is on lines missing
// from this listing (the embedded numbering has gaps).
1511struct SgToLaneConvertLayout
1512 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
1513 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1516 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1517 ConversionPatternRewriter &rewriter)
const override {
1518 auto inputLayout = op.getInputLayoutAttr();
1519 auto targetLayout = op.getTargetLayoutAttr();
1520 Type valType = op.getResult().getType();
1523 rewriter.replaceOp(op, op.getSource());
// cast<> (not dyn_cast) is used, so valType is expected to be a VectorType by
// the time this line is reached.
1527 auto resShape = cast<VectorType>(valType).getShape();
1528 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
1529 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1530 xegpu::LayoutKind::Lane)) {
1531 return rewriter.notifyMatchFailure(
1532 op,
"lowering incompatible convert_layout not yet supported");
// Compatible layouts: forward the remapped source in place of the op.
1535 rewriter.replaceOp(op, adaptor.getSource());
// Conversion pattern for vector::InterleaveOp: recreates the op at the same
// location from the remapped lhs/rhs operands and replaces the original with
// the new result.
// NOTE(review): the embedded line numbers in this listing have gaps, so the
// return statement and closing braces are not shown.
1541struct SgToLaneVectorInterleave
1542 :
public OpConversionPattern<vector::InterleaveOp> {
1543 using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
1546 matchAndRewrite(vector::InterleaveOp op, OpAdaptor adaptor,
1547 ConversionPatternRewriter &rewriter)
const override {
1549 auto newOp = vector::InterleaveOp::create(
1550 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
1551 rewriter.replaceOp(op, newOp.
getResult());
// Conversion pattern for vector::DeinterleaveOp: recreates the op at the same
// location from the remapped source operand.
// NOTE(review): the statements that replace the original op and return are on
// lines missing from this listing (the embedded numbering has gaps).
1557struct SgToLaneVectorDeinterleave
1558 :
public OpConversionPattern<vector::DeinterleaveOp> {
1559 using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
1562 matchAndRewrite(vector::DeinterleaveOp op, OpAdaptor adaptor,
1563 ConversionPatternRewriter &rewriter)
const override {
1565 auto newOp = vector::DeinterleaveOp::create(rewriter, op.getLoc(),
1566 adaptor.getSource());
// Conversion pattern for xegpu::DpasMxOp.
// Visible behavior: fails unless the target uArch supports
// SubgroupScaledMatrixMultiplyAcc; reads the A/B/CD layouts; derives the
// expected 1D workitem vector types for the result, A and B (and for scale A /
// scale B when those operands are present) plus an expected N-D result type;
// checks that the packed bit widths of the A and B workitem types are
// multiples of the uArch instruction's packed-format sizes; creates a new
// DpasMxOp whose operands are adapted with castValueTo; and finally casts the
// new result to the expected N-D type when replacing the original op.
// NOTE(review): the embedded line numbers in this listing have gaps, so several
// initializers and operands of the create call are not shown.
1572struct SgToLaneDpasMx :
public OpConversionPattern<xegpu::DpasMxOp> {
1573 using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
1576 matchAndRewrite(xegpu::DpasMxOp op, OpAdaptor adaptor,
1577 ConversionPatternRewriter &rewriter)
const override {
1582 xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc))
1583 return rewriter.notifyMatchFailure(
1584 op,
"target uArch does not support scaled subgroup mma");
// NOTE(review): cast<> requires a non-null value of the target kind, so the
// null check that follows looks unreachable for a missing or differently-typed
// layout; dyn_cast_if_present may have been intended — confirm.
1586 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
1587 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
1588 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
1589 if (!layoutA || !layoutB || !layoutCd)
1590 return rewriter.notifyMatchFailure(
1591 op,
"missing required layout attributes for DpasMxOp distribution");
1594 auto expected1DTypeResult =
1596 auto expected1DTypeA =
1598 auto expected1DTypeB =
// Scale operands are optional; their expected types stay default-constructed
// (null) when the corresponding operand is absent.
1601 VectorType expected1DTypeScaleA, expected1DTypeScaleB;
1602 if (op.getScaleA()) {
1603 auto layoutScaleA = cast<xegpu::LayoutAttr>(op.getLayoutAScaleAttr());
1605 cast<VectorType>(op.getScaleA().getType()), layoutScaleA);
1606 if (
failed(expected1DTypeScaleAOrFailure))
1607 return rewriter.notifyMatchFailure(
1608 op,
"failed to calculate expected 1D vector type for scale A");
1609 expected1DTypeScaleA = expected1DTypeScaleAOrFailure.value();
1611 if (op.getScaleB()) {
1612 auto layoutScaleB = cast<xegpu::LayoutAttr>(op.getLayoutBScaleAttr());
1614 cast<VectorType>(op.getScaleB().getType()), layoutScaleB);
1615 if (
failed(expected1DTypeScaleBOrFailure))
1616 return rewriter.notifyMatchFailure(
1617 op,
"failed to calculate expected 1D vector type for scale B");
1618 expected1DTypeScaleB = expected1DTypeScaleBOrFailure.value();
1621 auto expectedNDTypeResult =
1623 if (
failed(expected1DTypeResult) ||
failed(expected1DTypeA) ||
1625 return rewriter.notifyMatchFailure(
1627 "failed to calculate supported workitem 1D vector types for DpasOp "
1629 if (
failed(expectedNDTypeResult))
1630 return rewriter.notifyMatchFailure(
1631 op,
"unable to compute expected workitem vector type for DpasOp from "
// The dyn_cast result is guarded only by an assert, so a mismatch is caught in
// assert-enabled builds only.
1635 const auto *uArchInstruction = dyn_cast<
1637 xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
1638 assert(uArchInstruction);
1639 auto wiAType = expected1DTypeA.value();
1640 auto wiBType = expected1DTypeB.value();
// Total bits held by one workitem for A and B: element bit width times the
// number of elements in the workitem vector.
1642 unsigned aPackedBitWidth =
1643 wiAType.getElementTypeBitWidth() * wiAType.getNumElements();
1644 unsigned bPackedBitWidth =
1645 wiBType.getElementTypeBitWidth() * wiBType.getNumElements();
1646 if (aPackedBitWidth % uArchInstruction->getPackedFormatBitSizeA())
1647 return rewriter.notifyMatchFailure(
1648 op,
"A operand packed bit width must be a multiple of uArch packed "
1649 "format requirement");
1650 if (bPackedBitWidth % uArchInstruction->getPackedFormatBitSizeB())
1651 return rewriter.notifyMatchFailure(
1652 op,
"B operand packed bit width must be a multiple of uArch packed "
1653 "format requirement");
// Build the replacement op with the 1D result type; operands are adapted to
// their expected 1D types through castValueTo.
1655 auto newOp = xegpu::DpasMxOp::create(
1656 rewriter, op->getLoc(), expected1DTypeResult.value(),
1658 expected1DTypeA.value()),
1660 expected1DTypeB.value()),
1662 ? castValueTo(rewriter,
1664 expected1DTypeResult.value())
1668 ? castValueTo(rewriter,
1670 expected1DTypeScaleA)
1673 ? castValueTo(rewriter,
1675 expected1DTypeScaleB)
// Users of the original op see the N-D workitem type, so the 1D result is cast
// to that type before replacement.
1681 rewriter.replaceOp(op, castValueTo(rewriter, newOp.
getResult(),
1682 expectedNDTypeResult.value()));
// Pass class for the subgroup-to-lane distribution; it overrides
// runOnOperation().
// NOTE(review): the base-class name is on a line missing from this listing
// (the embedded numbering has gaps).
1687struct XeGPUSgToLaneDistributePass
1689 XeGPUSgToLaneDistributePass> {
1690 void runOnOperation()
override;
// Pass entry point. Visible steps:
//  1. Record the UnrealizedConversionCastOps that exist before conversion so
//     the later cleanup walks can tell them apart from casts created here.
//  2. Build a TypeConverter whose source/target materializations insert
//     UnrealizedConversionCastOps, populate patterns/legality, and run a
//     partial conversion on the root op.
//  3. Walk the casts: for a vector->vector cast fed by another single-use
//     cast, when the inner cast's input is a vector with the same element
//     count as the outer result, create a vector::ShapeCastOp from that input
//     and redirect the outer cast's uses to it.
//  4. Walk the casts again, looking at those with no remaining uses.
// NOTE(review): the embedded line numbers in this listing have gaps, so early
// returns, the action taken on unused casts, and closing braces are not shown.
1695void XeGPUSgToLaneDistributePass::runOnOperation() {
1698 Operation *root = getOperation();
1700 signalPassFailure();
1705 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1707 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
// Materialization callback: wraps the inputs in an UnrealizedConversionCastOp
// of the requested type and returns its first result.
1711 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1712 mlir::ValueRange inputs,
1713 mlir::Location loc) -> mlir::Value {
1714 UnrealizedConversionCastOp castOp =
1715 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1716 return castOp.getResult(0);
1720 TypeConverter typeConverter;
1722 typeConverter.addSourceMaterialization(materializeCast);
1723 typeConverter.addTargetMaterialization(materializeCast);
1728 typeConverter, patterns,
target);
// Casts are kept legal so the materializations above survive the conversion.
1729 target.addLegalOp<UnrealizedConversionCastOp>();
// NOTE(review): the conversion result is discarded via (void); in the visible
// lines a failed conversion is not reported — confirm this is intended.
1730 (void)applyPartialConversion(root,
target, std::move(patterns));
1741 OpBuilder builder(root);
1742 root->
walk([&](UnrealizedConversionCastOp op) {
1744 if (existingCasts.contains(op))
1747 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1750 auto singleInput = op.getInputs()[0];
1751 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1752 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1753 if (!inputTy || !outputTy)
// Only consider a cast whose input is itself produced by a cast with exactly
// one use.
1759 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1760 if (!definingOp || !definingOp->hasOneUse())
1762 auto inputOfDefiningOp = definingOp.getInputs()[0];
// Equal element counts are required before the cast pair is bridged with a
// shape_cast from the inner input to the outer result type.
1765 auto inputOfDefiningOpTy =
1766 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1767 if (inputOfDefiningOpTy &&
1768 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1770 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1771 outputTy, inputOfDefiningOp);
1772 op.replaceAllUsesWith(
ValueRange{shapeCast.getResult()});
// NOTE(review): `changed` presumably drives a repeat-until-stable cleanup of
// unused casts; the loop header and body are not visible here.
1778 bool changed =
true;
1781 root->
walk([&](UnrealizedConversionCastOp op) {
1783 if (existingCasts.contains(op))
1785 if (op.use_empty()) {
// Type-conversion callbacks registered on `typeConverter`.
// NOTE(review): the enclosing function's signature is on a line missing from
// this listing; it is presumably
// populateXeGPUSgToLaneDistributeTypeConversions — confirm.
// First callback: branches on whether the type is a TensorDescType/VectorType;
// the visible return for the remaining path is std::nullopt.
1798 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
1799 if (!isa<TensorDescType, VectorType>(type))
1801 return std::nullopt;
// Tensor descriptors carrying a layout attribute are converted to the same
// type with the layouts dropped.
1804 typeConverter.addConversion([](TensorDescType type) ->
Type {
1805 if (type.getLayoutAttr()) {
1806 return type.dropLayouts();
// Value-based callback: non-vector values yield std::nullopt; vector values
// are checked for a layout with isForSubgroup(), and on success the computed
// type (*newTyOrFailure) is returned.
1812 typeConverter.addConversion([](
Value v) -> std::optional<Type> {
1815 if (!isa<VectorType>(type))
1816 return std::nullopt;
1818 if (!layout || !layout.isForSubgroup())
1821 auto newTyOrFailure =
1823 if (failed(newTyOrFailure))
1825 return *newTyOrFailure;
// Legality rules on `target` followed by the list of conversion patterns.
// NOTE(review): the enclosing function's signature and most legality-callback
// bodies are on lines missing from this listing; the function is presumably
// populateXeGPUSgToLaneDistributeTypeConversionAndLegality — confirm.
// CreateNdDescOp is legal once its result type carries no layout attribute.
1834 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1835 [&](xegpu::CreateNdDescOp op) {
return !op.getType().getLayoutAttr(); });
// XeGPU dialect ops: ConvertLayoutOp is special-cased; for an op implementing
// AnchorLayoutInterface the visible rule is "legal when it has no anchor
// layout".
1837 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](
Operation *op) {
1838 if (isa<xegpu::ConvertLayoutOp>(op))
1840 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1843 return !anchorOp.getAnchorLayout();
1846 target.addDynamicallyLegalOp<arith::ConstantOp>(
1847 [=](arith::ConstantOp op) ->
bool {
1849 if (!isa<VectorType>(op.getResult().getType()))
// math/arith ops: the callback inspects single-result ops and compares each
// operand's vector shape with the result's vector shape.
1855 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1856 [=](
Operation *op) -> std::optional<bool> {
1861 if (op->getNumResults() != 1)
1864 VectorType resultType =
1865 dyn_cast<VectorType>(op->getResult(0).getType());
1870 for (
Value operand : op->getOperands()) {
1871 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1872 if (!operandType || operandType.getShape() != resultType.getShape()) {
1880 target.addDynamicallyLegalOp<vector::ReductionOp>(
1881 [=](vector::ReductionOp op) ->
bool {
// A multi-dim reduction is legal exactly when it is not a valid subgroup
// multi-reduction according to isValidSubgroupMultiReductionOp.
1886 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1887 [=](vector::MultiDimReductionOp op) ->
bool {
1888 return !isValidSubgroupMultiReductionOp(op);
1890 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1891 vector::TransposeOp, vector::BitCastOp,
1892 vector::ShapeCastOp, vector::StepOp,
1893 vector::BroadcastOp>([=](
Operation *op) ->
bool {
1896 target.addDynamicallyLegalOp<vector::ExtractOp>(
1897 [=](vector::ExtractOp op) ->
bool {
1898 if (!isa<VectorType>(op.getType()))
1902 target.addDynamicallyLegalOp<vector::InsertOp>(
1903 [=](vector::InsertOp op) ->
bool {
1906 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1907 [=](vector::ExtractStridedSliceOp op) ->
bool {
1910 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1911 [=](vector::InsertStridedSliceOp op) ->
bool {
1914 target.addDynamicallyLegalOp<vector::InterleaveOp, vector::DeinterleaveOp>(
// Any op without an explicit rule is treated as legal.
1918 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
// Pattern types registered with the type converter (the call these template
// arguments belong to starts on a line not visible here).
1920 SgToLaneCreateNdDesc, SgToLaneLoadNd, SgToLaneStoreNd, SgToLaneDpas,
1921 SgToLaneElementWise, SgToLaneArithConstant, SgToLanePrefetchNd,
1922 SgToLaneLoadGather, SgToLaneStoreScatter, SgToLaneVectorReduction,
1923 SgToLaneMultiDimReduction, SgToLaneVectorExtract, SgToLaneVectorInsert,
1924 SgToLaneVectorExtractStridedSlice, SgToLaneVectorInsertStridedSlice,
1925 SgToLaneLoadMatrix, SgToLaneStoreMatrix, SgToLaneConvertLayout,
1926 SgToLaneVectorTranspose, SgToLaneVectorBitcast, SgToLaneVectorStep,
1927 SgToLaneVectorShapeCast, SgToLaneBroadcast,
1928 SgToLaneCreateMask<vector::CreateMaskOp>,
1929 SgToLaneCreateMask<vector::ConstantMaskOp>, SgToLaneVectorDeinterleave,
1930 SgToLaneVectorInterleave, SgToLaneDpasMx>(typeConverter,
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
@ SubgroupMatrixMultiplyAcc
@ SubgroupScaledMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given src and acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToLaneDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to lane distribution.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
Get and set the distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUSgToLaneDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to lane distribution and appends the require...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
bool isSupportedInstruction(InstructionKind instr) const
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const