29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
36#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
43#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Coerce `v` to `expectedTy`. If the type already matches, `v` is returned
// unchanged (return truncated from this view). If both are vector types with
// the same element count, a vector.shape_cast is used; otherwise an
// unrealized_conversion_cast bridges the mismatch.
// NOTE(review): several interior lines of this helper are not visible here.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
52 if (v.getType() == expectedTy)
// Same number of elements, different shape: a plain shape cast suffices.
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
// Fallback: materialize the coercion as an unrealized cast to be resolved
// later by the conversion framework.
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
62 return newOp.getResult(0);
// Returns true when a vector.multi_reduction op is a valid candidate for
// subgroup-to-workitem distribution: the result must carry a subgroup
// distribute layout, the op must reduce over exactly one dimension, and a
// vector-typed result must be distributable under its lane layout.
// NOTE(review): the layout lookup lines are truncated in this view.
68static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
71 if (!resLayout || !resLayout.isForSubgroup())
// Scalar (full) reductions only need the single-reduction-dim check.
74 if (op.getType().isIntOrFloat())
75 return op.getReductionDims().size() == 1;
76 VectorType resTy = dyn_cast<VectorType>(op.getType());
// Vector results must additionally have a computable distributed type.
80 FailureOr<VectorType> resDistTypeOrFailure =
81 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
82 if (failed(resDistTypeOrFailure))
84 return op.getReductionDims().size() == 1;
// Returns true when the reduction can be performed entirely within a lane,
// i.e. the distributed result type differs from the original result type
// (the result dimension itself is distributed, so no cross-lane combine is
// needed). Precondition: op passed isValidSubgroupMultiReductionOp.
// NOTE(review): `resLayout`'s declaration is truncated from this view.
91static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
93 assert(isValidSubgroupMultiReductionOp(op) &&
"Expecting a valid subgroup "
94 "MultiDimReductionOp");
96 VectorType resTy = dyn_cast<VectorType>(op.getType());
97 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
98 return resTy != resDistTypeOrFailure.value();
// (Header truncated in this view.) Collects the indices of dimensions whose
// size changed between the original subgroup-level vector type and the
// distributed (per-lane) vector type — i.e. the dims the layout distributes.
104 VectorType distributedType) {
105 assert(originalType.getRank() == distributedType.getRank() &&
106 "original and distributed vector types must have the same rank");
108 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
// A differing dim size means this dimension was split across lanes.
109 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
110 distributedDims.push_back(i);
112 return distributedDims;
// Distributes xegpu.create_nd_tdesc from subgroup to workitem scope by
// recreating the descriptor with its layout attribute dropped; the layout
// only has meaning at subgroup level.
117struct SgToWiCreateNdDesc :
public OpConversionPattern<xegpu::CreateNdDescOp> {
118 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
121 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const override {
123 xegpu::TensorDescType resultType = op.getType();
// No layout: nothing to distribute (failure path truncated in this view).
125 if (!resultType.getLayout())
128 auto newOp = xegpu::CreateNdDescOp::create(
129 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
131 rewriter.replaceOp(op, newOp.getResult());
// Distributes xegpu.load_nd to workitem scope: validates that the anchor
// layout matches the tensor-descriptor layout, computes both the
// hardware-supported per-lane result type and the layout-expected per-lane
// type, emits the WI-level load, and casts its result to the expected type.
139struct SgToWiLoadNd :
public OpConversionPattern<xegpu::LoadNdOp> {
140 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
143 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter)
const override {
145 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// The descriptor and the anchor must agree on the distribution layout.
151 if (op.getTensorDescType().getLayout() != layout)
152 return rewriter.notifyMatchFailure(
153 op,
"conflicting layout attributes on tensor descriptor and anchor");
// A uArch/target attribute is required to decide transpose/packing behavior.
156 return rewriter.notifyMatchFailure(
157 op,
"xegpu::LoadNdOp require target attribute attached to "
158 "determine transpose "
160 auto supportedWiResultTyOrFailure =
162 auto expectedWiResultTyOrFailure =
164 if (failed(supportedWiResultTyOrFailure))
165 return rewriter.notifyMatchFailure(
166 op,
"unable to compute the workitem vector type for LoadNdOp");
167 if (failed(expectedWiResultTyOrFailure))
168 return rewriter.notifyMatchFailure(
170 "unable to compute expected workitem vector type from lane layout");
// Recreate the load at WI scope; the trailing nullptr drops the layout attr.
171 auto newOp = xegpu::LoadNdOp::create(
172 rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
173 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
174 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
175 op.getL3HintAttr(),
nullptr);
// Bridge any supported-vs-expected type mismatch with castValueTo.
181 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
182 expectedWiResultTyOrFailure.value()));
// Distributes xegpu.store_nd to workitem scope. Checks that the anchor,
// tensor-descriptor, and stored-value layouts all agree, computes the
// supported per-lane value type, emits the WI-level store (casting the value
// as needed), and erases the original op (stores produce no results).
190struct SgToWiStoreNd :
public OpConversionPattern<xegpu::StoreNdOp> {
191 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
194 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter)
const override {
196 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
202 if (op.getTensorDescType().getLayout() != layout)
203 return rewriter.notifyMatchFailure(
204 op,
"conflicting layout attributes on tensor descriptor and anchor");
206 if (valueLayout != layout)
207 return rewriter.notifyMatchFailure(
208 op,
"conflicting layout attributes on value and anchor");
209 auto supportedWiValueTyOrFailure =
211 if (failed(supportedWiValueTyOrFailure))
212 return rewriter.notifyMatchFailure(
214 "unable to compute wi vector type for StoreNdOp value from tensor "
// The trailing nullptr drops the layout attribute at WI scope.
217 xegpu::StoreNdOp::create(
218 rewriter, op.getLoc(),
220 supportedWiValueTyOrFailure.value()),
221 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
222 op.getL2HintAttr(), op.getL3HintAttr(),
nullptr);
223 rewriter.eraseOp(op);
// Distributes xegpu.dpas (matrix-multiply-accumulate) to workitem scope:
// computes per-lane types for A, B, and the result, validates A/B packed bit
// widths against the uArch packed-format requirements when uArch info is
// available, then recreates the op with casted operands.
231struct SgToWiDpas :
public OpConversionPattern<xegpu::DpasOp> {
232 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
235 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): cast<> asserts on a null/mismatched attribute, so the
// null checks below look unreachable — dyn_cast(_if_present) is likely
// intended here; confirm against the full file.
238 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
239 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
240 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
241 if (!layoutA || !layoutB || !layoutCd)
243 auto wiResultTyOrFailure =
245 auto wiATypeOrFailure =
247 auto wiBTypeOrFailure =
249 auto expectedWiResultTyOrFailure =
251 if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
252 failed(wiBTypeOrFailure))
253 return rewriter.notifyMatchFailure(
254 op,
"failed to calculate supported workitem vector types for DpasOp "
256 if (failed(expectedWiResultTyOrFailure))
257 return rewriter.notifyMatchFailure(
258 op,
"unable to compute expected workitem vector type for DpasOp from "
// Query the target uArch instruction description, if one is attached.
264 const auto *uArchInstruction =
265 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
268 if (uArchInstruction) {
269 auto wiAType = wiATypeOrFailure.value();
270 auto wiBType = wiBTypeOrFailure.value();
// Per-lane operand payloads must be whole multiples of the uArch's
// packed-format bit sizes for A and B.
272 unsigned aPackedBitWidth =
273 wiAType.getElementTypeBitWidth() * wiAType.getNumElements();
274 unsigned bPackedBitWidth =
275 wiBType.getElementTypeBitWidth() * wiBType.getNumElements();
276 unsigned expectedABitSize = uArchInstruction->getPackedFormatBitSizeA();
277 unsigned expectedBBitSize = uArchInstruction->getPackedFormatBitSizeB();
279 if (aPackedBitWidth % expectedABitSize != 0)
280 return rewriter.notifyMatchFailure(
282 "A operand packed bit width must be a multiple of uArch packed "
283 "format requirement");
284 if (bPackedBitWidth % expectedBBitSize != 0)
285 return rewriter.notifyMatchFailure(
287 "B operand packed bit width must be a multiple of uArch packed "
288 "format requirement");
// Recreate the dpas with operands cast to their per-lane types.
292 auto newOp = xegpu::DpasOp::create(
293 rewriter, op->getLoc(), wiResultTyOrFailure.value(),
295 wiATypeOrFailure.value()),
297 wiBTypeOrFailure.value()),
299 wiResultTyOrFailure.value()),
303 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
304 expectedWiResultTyOrFailure.value()));
// Generic pattern (matches any op) for distributing elementwise operations:
// if the single vector result carries a subgroup layout, the op is rebuilt
// via OperationState with the distributed result type, copying all
// attributes except the distribute-layout ones.
311struct SgToWiElementWise :
public ConversionPattern {
313 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
317 ConversionPatternRewriter &rewriter)
const override {
324 return rewriter.notifyMatchFailure(
325 op,
"operation result is not a vector type");
327 xegpu::DistributeLayoutAttr layout =
329 if (!layout || !layout.isForSubgroup())
330 return rewriter.notifyMatchFailure(
331 op,
"operation result does not have subgroup distribute layout");
333 auto wiShapeOrFailure =
336 if (failed(wiShapeOrFailure))
337 return rewriter.notifyMatchFailure(
338 op,
"unable to compute workitem vector type from the layout");
340 VectorType newResultType = wiShapeOrFailure.value();
// Rebuild the op generically with the remapped operands and WI result type.
342 state.addOperands(operands);
343 state.addTypes(newResultType);
// Drop layout attributes; they are meaningless after distribution.
346 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
347 state.addAttribute(attr.getName(), attr.getValue());
349 Operation *newOp = rewriter.create(state);
351 rewriter.replaceOp(op, newOp->
getResult(0));
// Distributes arith.constant with a subgroup-layout vector result. Only
// dense splat constants are supported: the splat scalar is re-wrapped into
// a new dense attribute of the distributed (per-lane) vector type.
358struct SgToWiArithConstant :
public OpConversionPattern<arith::ConstantOp> {
359 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
362 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 auto resultType = dyn_cast<VectorType>(op.getType());
369 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
371 return rewriter.notifyMatchFailure(
372 op,
"only dense splat vector constants are supported");
374 xegpu::DistributeLayoutAttr layout =
376 if (!layout || !layout.isForSubgroup())
377 return rewriter.notifyMatchFailure(
378 op,
"operation result does not have subgroup distribute layout");
380 auto wiShapeOrFailure =
383 if (
failed(wiShapeOrFailure))
384 return rewriter.notifyMatchFailure(
385 op,
"unable to compute workitem vector type from the layout");
387 VectorType newResultType = wiShapeOrFailure.value();
// Splat value is shape-independent, so it can be reused for the WI type.
388 auto sclarValue = dense.getSplatValue<Attribute>();
391 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
393 rewriter.replaceOp(op, newOp.
getResult());
// Distributes xegpu.prefetch_nd: recreates the prefetch on the converted
// (layout-free) tensor descriptor and erases the original op.
399struct SgToWiPrefetchNd :
public OpConversionPattern<xegpu::PrefetchNdOp> {
400 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
403 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter)
const override {
405 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
410 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
411 op.getMixedOffsets(), op.getL1HintAttr(),
412 op.getL2HintAttr(), op.getL3HintAttr(),
414 rewriter.eraseOp(op);
// Distributes xegpu.load (gather form). Leading dimensions beyond the
// effective vector rank (1 without chunking, 2 with a chunk size) must be
// unit. The distributed result, offsets, and mask are flattened to 1-D for
// the WI-level load, then the result is shape-cast back if needed.
452struct SgToWiLoadGather :
public OpConversionPattern<xegpu::LoadGatherOp> {
453 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
456 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter)
const override {
458 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
462 VectorType origResultTy = op.getValueType();
// chunkSize > 1 adds a trailing per-lane chunk dimension.
467 int chunkSize = op.getChunkSize().value_or(1);
468 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
469 ArrayRef<int64_t> shape = origResultTy.getShape();
471 shape.take_front(origResultTy.getRank() - effectiveVecRank),
472 [](int64_t d) { return d != 1; }))
473 return rewriter.notifyMatchFailure(
474 op,
"Only unit dimensions allowed for the leading "
475 "dimensions of the load vector!");
477 auto distResultTyOrFailure =
479 if (
failed(distResultTyOrFailure))
480 return rewriter.notifyMatchFailure(
482 "unable to compute expected workitem vector type from lane layout");
// Flatten the distributed result type to 1-D for the lane-level load.
484 VectorType distResultTy = distResultTyOrFailure.value();
485 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
486 distResultTy.getElementType());
// Offsets and mask are likewise flattened to 1-D vectors.
489 Value distOffsets = adaptor.getOffsets();
490 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
491 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
492 distOffsetsTy.getElementType());
493 distOffsets = castValueTo(
496 Value distMask = adaptor.getMask();
497 auto distMaskTy = cast<VectorType>(distMask.
getType());
498 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
499 distMaskTy.getElementType());
503 Value distSource = adaptor.getSource();
504 auto newOp = xegpu::LoadGatherOp::create(
505 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
506 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
507 op.getL3HintAttr(),
nullptr);
// Restore the (non-1-D) distributed shape if it differed.
510 if (distResultTy1D != distResultTy)
513 rewriter.replaceOp(op,
result);
// Distributes rank-1 vector.reduction across the subgroup: each lane reduces
// its local slice, then a cross-lane combine over the subgroup size produces
// the full value; an accumulator operand, if present, is folded in last.
522struct SgToWiVectorReduction :
public OpConversionPattern<vector::ReductionOp> {
523 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
526 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
527 ConversionPatternRewriter &rewriter)
const override {
531 if (!layout || !layout.isForSubgroup())
534 VectorType srcVecType = op.getSourceVectorType();
536 if (srcVecType.getRank() != 1)
537 return rewriter.notifyMatchFailure(
538 op,
"Only rank 1 reductions can be distributed.");
540 if (layout.getRank() != srcVecType.getRank())
541 return rewriter.notifyMatchFailure(
542 op,
"Layout rank does not match vector rank.");
// Subgroup size comes from the lane layout's single dimension.
545 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
548 return rewriter.notifyMatchFailure(
549 op,
"xegpu::ReductionOp require target attribute attached to "
550 "determine subgroup size");
// The reduced dimension must divide evenly across the lanes.
554 srcVecType.getShape()[0] % sgSize != 0)
555 return rewriter.notifyMatchFailure(op,
556 "Invalid layout or reduction vector "
557 "dimension must match subgroup size.")
559 if (!op.getType().isIntOrFloat())
560 return rewriter.notifyMatchFailure(
561 op,
"Reduction distribution currently only supports floats and "
// Lane-local partial reduce followed by a cross-lane combine.
565 Value laneValVec = adaptor.getVector();
569 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
572 if (adaptor.getAcc())
574 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
576 rewriter.replaceOp(op, fullReduce);
// Distributes vector.multi_reduction (single reduction dim; leading dims
// beyond the last two must be unit). Three paths are visible: scalar result
// (full reduce over the reduction dim), lane-local reduction (result dim is
// distributed), and otherwise a cross-lane reduction.
585struct SgToWiMultiDimReduction
586 :
public OpConversionPattern<vector::MultiDimReductionOp> {
587 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
590 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter)
const override {
593 ArrayRef<int64_t> reductionDims = op.getReductionDims();
594 assert(reductionDims.size() == 1 &&
595 "Expecting single reduction dimension for subgroup multi "
598 VectorType sourceType = op.getSourceVectorType();
599 int64_t rank = sourceType.getRank();
601 ArrayRef<int64_t> shape = sourceType.getShape();
602 if (llvm::any_of(shape.take_front(rank - 2),
603 [](int64_t d) { return d != 1; }))
604 return rewriter.notifyMatchFailure(
605 op,
"only unit leading dimensions are supported for "
606 "multi_reduction with rank > 2");
// Scalar result: reduce the whole reduction dimension.
610 if (op.getType().isIntOrFloat()) {
611 auto reductionDim = reductionDims[0];
612 VectorType origSourceType = op.getSourceVectorType();
613 int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
617 op.getKind(), reductionDimSize);
619 if (adaptor.getAcc())
621 result, adaptor.getAcc());
622 }
// Lane-local: each lane holds its own reduction slice, no shuffles needed.
else if (isReductionLaneLocal(op)) {
626 auto reductionDim = reductionDims[0];
630 reductionDim, op.getLoc(), rewriter);
// Otherwise: cross-lane reduction over the reduction dimension.
632 auto reductionDim = reductionDims[0];
633 VectorType sourceType = op.getSourceVectorType();
634 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
638 reductionDim, reductionDimSize, op.getLoc(), rewriter);
640 rewriter.replaceOp(op,
result);
// (Header truncated in this view.) Helper for the matrix load/store
// patterns: materializes gpu.lane_id and asks the layout for the per-lane
// distributed coordinates of the payload, expecting exactly one coordinate
// set, then converts OpFoldResults to Values.
649 ConversionPatternRewriter &rewriter,
Location loc,
652 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
653 mlir::IntegerAttr());
655 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
658 assert(maybeCoords.value().size() == 1 &&
659 "Expected one set of distributed offsets");
663 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Distributes xegpu.load_matrix: distributes the vector payload per the
// layout and, unless subgroup_block_io is set, rebases the offsets to
// per-lane coordinates; constant offsets become all-dynamic.
667struct SgToWiLoadMatrix :
public OpConversionPattern<xegpu::LoadMatrixOp> {
668 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
671 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
672 ConversionPatternRewriter &rewriter)
const override {
673 auto layout = op.getLayoutAttr();
678 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
680 return rewriter.notifyMatchFailure(
681 op,
"the matrix op payload must be a vector type");
683 auto loc = op.getLoc();
684 auto offsets = op.getMixedOffsets();
686 return rewriter.notifyMatchFailure(op,
"the load op must have offsets");
688 FailureOr<VectorType> distPayloadTyOrFailure =
690 if (
failed(distPayloadTyOrFailure))
691 return rewriter.notifyMatchFailure(
692 op,
"Failed to distribute matrix op payload based on layout.");
694 SmallVector<Value> offsetsAsValues =
// With subgroup_block_io the hardware handles per-lane addressing, so the
// subgroup offsets are used unchanged.
697 SmallVector<Value> newCoords = offsetsAsValues;
698 if (!op.getSubgroupBlockIoAttr()) {
699 newCoords = computeDistributedCoordsForMatrixOp(
700 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
701 if (newCoords.empty())
702 return rewriter.notifyMatchFailure(
703 op,
"Failed to compute distributed coordinates.");
// All offsets become dynamic values after rebasing.
706 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
707 ShapedType::kDynamic);
709 rewriter.getDenseI64ArrayAttr(newConstOffsets);
711 auto newOp = xegpu::LoadMatrixOp::create(
712 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
713 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
714 xegpu::DistributeLayoutAttr{});
715 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.transpose: legal only when the result lane layout is
// the permuted source lane layout; the transpose is recreated on the
// distributed operand and cast to the distributed result type.
721struct SgToWiVectorTranspose :
public OpConversionPattern<vector::TransposeOp> {
722 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
725 matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter)
const override {
727 xegpu::DistributeLayoutAttr sourceLayout =
729 xegpu::DistributeLayoutAttr resultLayout =
731 if (!sourceLayout || !resultLayout)
732 return rewriter.notifyMatchFailure(
733 op,
"the source or result vector of the transpose op lacks layout "
735 ArrayRef<int64_t> perm = op.getPermutation();
// Lane layouts must be related by exactly this permutation.
737 if (!resultLayout.isTransposeOf(sourceLayout, perm,
738 xegpu::LayoutKind::Lane))
739 return rewriter.notifyMatchFailure(
740 op,
"the source or result vector layouts must be transposes of "
742 FailureOr<VectorType> distributedResultTypeOrFailure =
744 if (
failed(distributedResultTypeOrFailure))
745 return rewriter.notifyMatchFailure(
746 op,
"Failed to distribute the result vector type in "
747 "vector::Transpose op");
748 auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
749 adaptor.getVector(), perm);
750 rewriter.replaceOp(op, castValueTo(rewriter, newOp.
getResult(),
751 distributedResultTypeOrFailure.value()));
// Distributes vector.bitcast: recreates the bitcast with the distributed
// result type on the converted source.
758struct SgToWiVectorBitcast :
public OpConversionPattern<vector::BitCastOp> {
759 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
762 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
763 ConversionPatternRewriter &rewriter)
const override {
764 xegpu::DistributeLayoutAttr resultLayout =
767 return rewriter.notifyMatchFailure(
768 op,
"result vector of the bitcast op lacks layout attribute");
769 FailureOr<VectorType> distributedResultTypeOrFailure =
771 if (
failed(distributedResultTypeOrFailure))
772 return rewriter.notifyMatchFailure(
773 op,
"Failed to distribute the result vector type in "
774 "vector::BitCast op");
775 auto newOp = vector::BitCastOp::create(
776 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
777 adaptor.getSource());
778 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.create_mask / vector.constant_mask: for each per-lane
// element, the lane's global coordinates (from the layout) are compared
// against the mask bounds (operands or constant dim sizes), and the
// resulting bits are assembled into the distributed mask vector.
806template <
typename OpType,
807 typename = std::enable_if_t<llvm::is_one_of<
808 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
809struct SgToWiCreateMask :
public OpConversionPattern<OpType> {
810 using OpConversionPattern<OpType>::OpConversionPattern;
813 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
814 ConversionPatternRewriter &rewriter)
const override {
815 xegpu::DistributeLayoutAttr layout =
817 if (!layout || !layout.isForSubgroup())
818 return rewriter.notifyMatchFailure(
819 op,
"operation result does not have subgroup distribute layout");
821 VectorType origType = op.getType();
822 FailureOr<VectorType> distTypeOrFailure =
824 if (
failed(distTypeOrFailure))
825 return rewriter.notifyMatchFailure(
826 op,
"unable to compute workitem vector type from the layout");
828 VectorType distType = distTypeOrFailure.value();
829 Location loc = op.getLoc();
// Collect per-dimension mask bounds: dynamic operands for create_mask,
// constant dim sizes for constant_mask.
832 SmallVector<Value> origBounds;
833 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
834 origBounds.append(op.getOperands().begin(), op.getOperands().end());
836 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
837 for (
auto dimSize : dimSizes)
838 origBounds.push_back(
842 ArrayRef<int64_t> origShape = origType.getShape();
// Each lane derives its element coordinates from lane id + layout.
845 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
846 mlir::IntegerAttr());
847 auto maybeCoordsVec =
848 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
849 if (
failed(maybeCoordsVec))
850 return rewriter.notifyMatchFailure(
851 op,
"failed to compute distributed coordinates from layout");
853 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
854 int64_t numElements = distType.getNumElements();
855 assert(
static_cast<int64_t
>(coordsVec.size()) == numElements &&
856 "number of coordinate sets must match number of distributed "
// An element is set iff every coordinate is strictly below its bound.
862 SmallVector<Value> maskBits;
863 for (
auto &coords : coordsVec) {
864 Value inBounds = trueVal;
865 for (
size_t i = 0; i < coords.size(); ++i) {
866 Value cmp = arith::CmpIOp::create(
867 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
868 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
870 maskBits.push_back(inBounds);
// Single-element distributed masks use broadcast; otherwise from_elements.
875 if (numElements == 1) {
877 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
880 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
882 rewriter.replaceOp(op,
result);
// Distributes xegpu.store_matrix — mirror image of SgToWiLoadMatrix:
// distributes the data payload, rebases offsets to per-lane coordinates
// unless subgroup_block_io is set, and erases the original store.
888struct SgToWiStoreMatrix :
public OpConversionPattern<xegpu::StoreMatrixOp> {
889 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
892 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter)
const override {
894 auto layout = op.getLayoutAttr();
899 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
901 return rewriter.notifyMatchFailure(
902 op,
"the matrix op payload must be a vector type");
904 auto loc = op.getLoc();
905 auto offsets = op.getMixedOffsets();
907 return rewriter.notifyMatchFailure(op,
"the store op must have offsets");
909 FailureOr<VectorType> distPayloadTyOrFailure =
911 if (
failed(distPayloadTyOrFailure))
912 return rewriter.notifyMatchFailure(
913 op,
"Failed to distribute matrix op payload based on layout.");
915 SmallVector<Value> offsetsAsValues =
// subgroup_block_io keeps the subgroup offsets; otherwise rebase per lane.
918 SmallVector<Value> newCoords = offsetsAsValues;
919 if (!op.getSubgroupBlockIoAttr()) {
920 newCoords = computeDistributedCoordsForMatrixOp(
921 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
922 if (newCoords.empty())
923 return rewriter.notifyMatchFailure(
924 op,
"Failed to compute distributed coordinates.");
// All offsets become dynamic values after rebasing.
927 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
928 ShapedType::kDynamic);
930 rewriter.getDenseI64ArrayAttr(newConstOffsets);
932 xegpu::StoreMatrixOp::create(
935 distPayloadTyOrFailure.value()),
936 adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr,
937 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
938 rewriter.eraseOp(op);
// Distributes xegpu.store (scatter form) — mirror of SgToWiLoadGather:
// unit leading dims required, value/offsets/mask flattened to 1-D for the
// WI-level store, then the original op is erased.
977struct SgToWiStoreScatter :
public OpConversionPattern<xegpu::StoreScatterOp> {
978 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
981 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
982 ConversionPatternRewriter &rewriter)
const override {
983 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
987 VectorType origValueTy = op.getValueType();
// chunkSize > 1 adds a trailing per-lane chunk dimension.
992 int chunkSize = op.getChunkSize().value_or(1);
993 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
994 ArrayRef<int64_t> shape = origValueTy.getShape();
995 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
996 [](int64_t d) { return d != 1; }))
997 return rewriter.notifyMatchFailure(
998 op,
"Only unit dimensions allowed for the leading "
999 "dimensions of the store vector!");
1001 auto distValueTyOrFailure =
1003 if (
failed(distValueTyOrFailure))
1004 return rewriter.notifyMatchFailure(
1006 "unable to compute expected workitem vector type from lane layout");
// Flatten value, offsets and mask to 1-D for the lane-level store.
1008 VectorType distValueTy = distValueTyOrFailure.value();
1009 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
1010 distValueTy.getElementType());
1012 Value distValue = adaptor.getValue();
1013 if (distValue.
getType() != distValueTy1D)
1018 Value distOffsets = adaptor.getOffsets();
1019 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
1020 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
1021 distOffsetsTy.getElementType());
1022 distOffsets = castValueTo(
1025 Value distMask = adaptor.getMask();
1026 auto distMaskTy = cast<VectorType>(distMask.
getType());
1027 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
1028 distMaskTy.getElementType());
1032 Value distDest = adaptor.getDest();
1033 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1034 distOffsets, distMask, op.getChunkSizeAttr(),
1035 op.getL1HintAttr(), op.getL2HintAttr(),
1036 op.getL3HintAttr(),
nullptr);
1037 rewriter.eraseOp(op);
// Distributes vector.step: each lane materializes its own slice of the step
// sequence by taking the start coordinate of every lane-data block (from the
// layout's distributed coords) and adding 0..laneDataBlockLength-1 offsets,
// then assembling the values with vector.from_elements.
1046struct SgToWiVectorStep :
public OpConversionPattern<vector::StepOp> {
1047 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1050 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1051 ConversionPatternRewriter &rewriter)
const override {
1052 xegpu::DistributeLayoutAttr resultLayout =
1054 if (!resultLayout || !resultLayout.isForSubgroup())
1055 return rewriter.notifyMatchFailure(
1056 op,
"the result vector of the step op lacks subgroup layout");
1058 auto loc = op.getLoc();
1059 auto stepResultVecTy = op.getResult().getType();
1060 auto wiShapeOrFailure =
1062 if (
failed(wiShapeOrFailure))
1063 return rewriter.notifyMatchFailure(
1064 op,
"unable to compute workitem vector type from the layout");
1065 VectorType newVecTy = wiShapeOrFailure.value();
// Lane coordinates identify where each lane-data block starts globally.
1067 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1068 mlir::IntegerAttr());
1069 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1070 rewriter, loc, laneId, stepResultVecTy.getShape())
1071 if (
failed(laneDataBlockCoords))
1072 return rewriter.notifyMatchFailure(
1073 op,
"failed to compute lane data block coordinates");
1075 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
1076 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
1077 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
1078 newVecTy.getNumElements() / laneDataBlockLength);
1079 SmallVector<Value> stepVals;
// Expand each block start into block-start + i for i in [0, blockLength).
1087 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1088 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1089 stepVals.push_back(laneDataBlockStartCoord);
1090 for (
int i = 1; i < laneDataBlockLength; ++i) {
1092 stepVals.push_back(arith::AddIOp::create(
1093 rewriter, loc, laneDataBlockStartCoord, offset));
1096 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
1097 "Expecting the number of step values to match the number of "
1098 "elements in the vector");
1100 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1101 rewriter.replaceOp(op, stepOpVal);
// Distributes vector.extract with a vector result: only supported when the
// lane layout distributes solely the innermost dimension, so the extract
// positions stay valid on the distributed source.
1108struct SgToWiVectorExtract :
public OpConversionPattern<vector::ExtractOp> {
1109 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1112 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1113 ConversionPatternRewriter &rewriter)
const override {
1115 auto resultType = dyn_cast<VectorType>(op.getType());
1117 return rewriter.notifyMatchFailure(op,
"scalar extract not supported");
1119 xegpu::DistributeLayoutAttr layout =
1121 if (!layout || !layout.isForSubgroup())
// All lane-layout dims except the innermost must be 1 (undistributed).
1126 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1127 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1128 [](int64_t v) {
return v != 1; }))
1129 return rewriter.notifyMatchFailure(
1130 op,
"only innermost dimension distribution is supported for "
1133 auto newOp = vector::ExtractOp::create(
1134 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1135 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.shape_cast: recreates the cast with the distributed
// result type on the already-converted source.
1141struct SgToWiVectorShapeCast :
public OpConversionPattern<vector::ShapeCastOp> {
1142 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1145 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1146 ConversionPatternRewriter &rewriter)
const override {
1147 xegpu::DistributeLayoutAttr resultLayout =
1149 if (!resultLayout || !resultLayout.isForSubgroup())
1150 return rewriter.notifyMatchFailure(
1151 op,
"the result vector of the shape_cast op lacks subgroup layout");
1154 resultLayout, op.getResultVectorType());
1155 if (
failed(resultDistTypeOrFailure))
1156 return rewriter.notifyMatchFailure(
1157 op,
"failed to get distributed vector type for result");
1159 Value source = adaptor.getSource();
1160 auto newShapeCast = vector::ShapeCastOp::create(
1161 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1162 rewriter.replaceOp(op, newShapeCast);
// Distributes vector.extract_strided_slice. Sizes/offsets/strides are padded
// to full source rank; for a (single) distributed dim, the slice size and
// offset are rescaled by the subgroup size, with legality checks: unit lane
// data on that dim, and size/offset divisible by the subgroup size.
1170struct SgToWiVectorExtractStridedSlice
1171 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
1172 using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
1175 matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter)
const override {
1177 xegpu::DistributeLayoutAttr resultLayout =
1179 if (!resultLayout || !resultLayout.isForSubgroup())
1182 VectorType resultType = op.getType();
1183 auto distResultTyOrFailure =
1185 if (
failed(distResultTyOrFailure))
1186 return rewriter.notifyMatchFailure(
1187 op,
"unable to compute distributed vector type from lane layout");
1188 VectorType distResultTy = *distResultTyOrFailure;
// Dims whose size changed under distribution need offset/size rescaling.
1190 SmallVector<int64_t> distributedDims =
1191 getDistributedDims(resultType, distResultTy);
// Pad the slice attributes out to the full source rank so every dimension
// is explicit (missing dims slice the whole dimension).
1194 int64_t sourceRank = op.getSourceVectorType().getRank();
1195 SmallVector<Attribute> updatedSizes =
1196 llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
1197 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1198 op.getOffsets(), [](Attribute attr) { return attr; });
1199 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1200 op.getStrides(), [](Attribute attr) { return attr; });
1201 for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
1202 updatedSizes.push_back(
1203 rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
1204 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1205 updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
1210 if (!distributedDims.empty()) {
1211 if (distributedDims.size() != 1)
1212 return rewriter.notifyMatchFailure(
1213 op,
"only single dimension distribution is supported");
1214 int64_t distDim = distributedDims[0];
1217 return rewriter.notifyMatchFailure(
1218 op,
"target attribute required to determine subgroup size");
1221 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1222 return rewriter.notifyMatchFailure(
1223 op,
"source of extract_strided_slice lacks distribution layout");
1224 int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
1225 if (sourceDistrDimSize % subgroupSize != 0)
1226 return rewriter.notifyMatchFailure(
1227 op,
"source size along distributed dim is not a multiple of "
1229 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
// Non-unit lane data would interleave elements across lanes, making a
// simple offset/size division incorrect.
1232 if (distDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1233 sourceLaneData[distDim] != 1)
1234 return rewriter.notifyMatchFailure(
1235 op,
"expecting unit lane data along the distributed dimension");
1236 int64_t distrDimOffset =
1237 cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
1238 if (distrDimOffset % subgroupSize != 0)
1239 return rewriter.notifyMatchFailure(
1240 op,
"offset along distributed dim is not a multiple of "
// Rescale size/offset along the distributed dim to per-lane units.
1243 updatedSizes[distDim] =
1244 rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
1245 updatedOffsets[distDim] =
1246 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1249 auto newOp = vector::ExtractStridedSliceOp::create(
1250 rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
1251 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1252 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1253 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1254 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.broadcast. For vector sources: with a rank difference
// the source layout must be a slice of the result layout; with equal ranks
// the broadcast unit dims are normalized in the source layout. Scalar
// sources must carry no layout. If the distributed source already matches
// the distributed result type, the broadcast folds away.
1316struct SgToWiBroadcast :
public OpConversionPattern<vector::BroadcastOp> {
1317 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
1320 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
1321 ConversionPatternRewriter &rewriter)
const override {
1322 xegpu::DistributeLayoutAttr resultLayout =
1324 if (!resultLayout || !resultLayout.isForSubgroup())
1325 return rewriter.notifyMatchFailure(
1326 op,
"result does not have subgroup distribute layout");
1328 VectorType destType = op.getResultVectorType();
1329 VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
1331 xegpu::DistributeLayoutAttr sourceLayout =
1335 int64_t rankDiff = destType.getRank() - sourceType.getRank();
// Rank-expanding broadcast: source layout must slice the result layout.
1338 if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
1340 "broadcast source layout must be a slice of result layout");
1341 }
// Same-rank broadcast: broadcasting happens along unit dims only.
else if (rankDiff == 0) {
1343 auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
1344 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1345 broadcastUnitDimsSet.end());
1346 assert(sourceLayout.isEqualTo(
1347 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1348 "The sg_data for unit dimensions should be set as 1");
1349 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
// Scalar source path (condition truncated in this view).
1354 return rewriter.notifyMatchFailure(
1355 op,
"broadcast from scalar must not have a layout attribute");
1360 if (
failed(destDistType))
1361 return rewriter.notifyMatchFailure(
1362 op,
"failed to distribute the result vector type");
1364 Value source = adaptor.getSource();
// No-op broadcast after distribution: forward the source directly.
1366 if (source.
getType() == destDistType.value()) {
1367 rewriter.replaceOp(op, source);
1371 auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
1372 destDistType.value(), source);
1373 rewriter.replaceOp(op, newOp);
1381struct SgToWiVectorInsertStridedSlice
1382 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
1383 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1386 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1387 ConversionPatternRewriter &rewriter)
const override {
1388 xegpu::DistributeLayoutAttr resultLayout =
1390 if (!resultLayout || !resultLayout.isForSubgroup())
1393 VectorType destType = op.getDestVectorType();
1394 auto distDestTyOrFailure =
1396 if (
failed(distDestTyOrFailure))
1397 return rewriter.notifyMatchFailure(
1398 op,
"unable to compute distributed vector type from lane layout");
1399 VectorType distDestTy = *distDestTyOrFailure;
1401 SmallVector<int64_t> destDistributedDims =
1402 getDistributedDims(destType, distDestTy);
1404 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1405 op.getOffsets(), [](Attribute attr) { return attr; });
1407 if (!destDistributedDims.empty()) {
1408 if (destDistributedDims.size() != 1)
1409 return rewriter.notifyMatchFailure(
1410 op,
"only single dimension distribution is supported");
1411 int64_t destDistDim = destDistributedDims[0];
1415 return rewriter.notifyMatchFailure(
1416 op,
"target attribute required to determine subgroup size");
1419 VectorType srcType = op.getSourceVectorType();
1421 int64_t sourceDistDim =
1422 destDistDim - (destType.getRank() - srcType.getRank());
1423 if (sourceDistDim < 0)
1424 return rewriter.notifyMatchFailure(
1425 op,
"distributed dimension must be in the last k dims of dest");
1429 if (!destLayout || !sourceLayout ||
1430 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1431 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1432 return rewriter.notifyMatchFailure(
1433 op,
"source or dest of insert_strided_slice lacks distribution "
1436 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1437 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1440 if ((destDistDim <
static_cast<int64_t
>(destLaneData.size()) &&
1441 destLaneData[destDistDim] != 1) ||
1442 (sourceDistDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1443 sourceLaneData[sourceDistDim] != 1))
1444 return rewriter.notifyMatchFailure(
1445 op,
"expecting unit lane data along the distributed dimension");
1447 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1448 if (srcDistrDimSize % subgroupSize != 0)
1449 return rewriter.notifyMatchFailure(
1450 op,
"source distributed dim size is not a multiple of "
1453 int64_t destDistrDimOffset =
1454 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1455 if (destDistrDimOffset % subgroupSize != 0)
1456 return rewriter.notifyMatchFailure(
1457 op,
"offset along distributed dim is not a multiple of "
1460 updatedOffsets[destDistDim] =
1461 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1464 auto newOp = vector::InsertStridedSliceOp::create(
1465 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1467 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1468 rewriter.replaceOp(op, newOp.
getResult());
1475struct SgToWiVectorInsert :
public OpConversionPattern<vector::InsertOp> {
1476 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1479 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1480 ConversionPatternRewriter &rewriter)
const override {
1482 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1484 return rewriter.notifyMatchFailure(op,
"scalar insert not supported");
1486 xegpu::DistributeLayoutAttr layout =
1488 if (!layout || !layout.isForSubgroup())
1493 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1494 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1495 [](int64_t v) {
return v != 1; }))
1496 return rewriter.notifyMatchFailure(
1497 op,
"only innermost dimension distribution is supported for "
1500 auto newOp = vector::InsertOp::create(
1501 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1502 op.getMixedPosition());
1503 rewriter.replaceOp(op, newOp.
getResult());
1509struct SgToWiConvertLayout
1510 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
1511 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1514 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1515 ConversionPatternRewriter &rewriter)
const override {
1516 auto inputLayout = op.getInputLayoutAttr();
1517 auto targetLayout = op.getTargetLayoutAttr();
1518 Type valType = op.getResult().getType();
1521 rewriter.replaceOp(op, op.getSource());
1525 auto resShape = cast<VectorType>(valType).getShape();
1526 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
1527 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1528 xegpu::LayoutKind::Lane)) {
1529 return rewriter.notifyMatchFailure(
1530 op,
"lowering incompatible convert_layout not yet supported");
1533 rewriter.replaceOp(op, adaptor.getSource());
1538struct XeGPUSgToWiDistributeExperimentalPass
1540 XeGPUSgToWiDistributeExperimentalPass> {
1541 void runOnOperation()
override;
1546void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
1549 Operation *root = getOperation();
1551 signalPassFailure();
1556 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1558 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1562 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1563 mlir::ValueRange inputs,
1564 mlir::Location loc) -> mlir::Value {
1565 UnrealizedConversionCastOp castOp =
1566 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1567 return castOp.getResult(0);
1571 TypeConverter typeConverter;
1573 typeConverter.addSourceMaterialization(materializeCast);
1574 typeConverter.addTargetMaterialization(materializeCast);
1579 typeConverter, patterns,
target);
1580 target.addLegalOp<UnrealizedConversionCastOp>();
1581 (void)applyPartialConversion(root,
target, std::move(patterns));
1592 OpBuilder builder(root);
1593 root->
walk([&](UnrealizedConversionCastOp op) {
1595 if (existingCasts.contains(op))
1598 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1601 auto singleInput = op.getInputs()[0];
1602 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1603 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1604 if (!inputTy || !outputTy)
1610 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1611 if (!definingOp || !definingOp->hasOneUse())
1613 auto inputOfDefiningOp = definingOp.getInputs()[0];
1616 auto inputOfDefiningOpTy =
1617 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1618 if (inputOfDefiningOpTy &&
1619 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1621 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1622 outputTy, inputOfDefiningOp);
1623 op.replaceAllUsesWith(
ValueRange{shapeCast.getResult()});
1629 bool changed =
true;
1632 root->
walk([&](UnrealizedConversionCastOp op) {
1634 if (existingCasts.contains(op))
1636 if (op.use_empty()) {
1649 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
1650 if (!isa<TensorDescType, VectorType>(type))
1652 return std::nullopt;
1655 typeConverter.addConversion([](TensorDescType type) ->
Type {
1656 if (type.getLayoutAttr()) {
1657 return type.dropLayouts();
1663 typeConverter.addConversion([](
Value v) -> std::optional<Type> {
1666 if (!isa<VectorType>(type))
1667 return std::nullopt;
1669 if (!layout || !layout.isForSubgroup())
1672 auto newTyOrFailure =
1674 if (failed(newTyOrFailure))
1676 return *newTyOrFailure;
1685 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1686 [&](xegpu::CreateNdDescOp op) {
return !op.getType().getLayoutAttr(); });
1688 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](
Operation *op) {
1689 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1692 return !anchorOp.getAnchorLayout();
1695 target.addDynamicallyLegalOp<arith::ConstantOp>(
1696 [=](arith::ConstantOp op) ->
bool {
1698 if (!isa<VectorType>(op.getResult().getType()))
1704 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1705 [=](
Operation *op) -> std::optional<bool> {
1710 if (op->getNumResults() != 1)
1713 VectorType resultType =
1714 dyn_cast<VectorType>(op->getResult(0).getType());
1719 for (
Value operand : op->getOperands()) {
1720 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1721 if (!operandType || operandType.getShape() != resultType.getShape()) {
1729 target.addDynamicallyLegalOp<vector::ReductionOp>(
1730 [=](vector::ReductionOp op) ->
bool {
1735 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1736 [=](vector::MultiDimReductionOp op) ->
bool {
1737 return !isValidSubgroupMultiReductionOp(op);
1739 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1740 vector::TransposeOp, vector::BitCastOp,
1741 vector::ShapeCastOp, vector::StepOp,
1742 vector::BroadcastOp>([=](
Operation *op) ->
bool {
1745 target.addDynamicallyLegalOp<vector::ExtractOp>(
1746 [=](vector::ExtractOp op) ->
bool {
1747 if (!isa<VectorType>(op.getType()))
1751 target.addDynamicallyLegalOp<vector::InsertOp>(
1752 [=](vector::InsertOp op) ->
bool {
1755 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1756 [=](vector::ExtractStridedSliceOp op) ->
bool {
1759 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1760 [=](vector::InsertStridedSliceOp op) ->
bool {
1763 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
1764 patterns.
add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1765 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1766 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1767 SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
1768 SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
1769 SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
1770 SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
1771 SgToWiVectorShapeCast, SgToWiBroadcast,
1772 SgToWiCreateMask<vector::CreateMaskOp>,
1773 SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to workitem distribution.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to workitem distribution and appends the req...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const