29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
36#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
43#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Reconcile a value's type with the type a consumer expects after
// distribution. Same type: returned unchanged. Vector types with equal
// element counts: bridged with a vector.shape_cast. Anything else: wrapped in
// an unrealized_conversion_cast to be resolved later.
// NOTE(review): interleaved source lines are missing in this chunk
// (signature tail, early return, cast operand list) — verify against the
// full file before editing.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
52 if (v.getType() == expectedTy)
// Same element count but different shape — a shape_cast is sufficient.
55 if (isa<VectorType>(v.getType()) &&
56 v.getType().getNumElements() == expectedTy.getNumElements())
57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
// Fallback: materialize a cast to be reconciled by later passes.
60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
62 return newOp.getResult(0);
// Returns true if a vector.multi_reduction op is eligible for
// subgroup-to-workitem distribution: its result must carry a subgroup
// distribute layout, its result type must be distributable, and it must
// reduce along exactly one dimension.
68static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
71 if (!resLayout || !resLayout.isForSubgroup())
// Scalar (full) reductions only need a single reduction dimension.
74 if (op.getType().isIntOrFloat())
75 return op.getReductionDims().size() == 1;
76 VectorType resTy = dyn_cast<VectorType>(op.getType());
// Vector results must additionally be distributable per the lane layout.
80 FailureOr<VectorType> resDistTypeOrFailure =
81 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
82 if (failed(resDistTypeOrFailure))
84 return op.getReductionDims().size() == 1;
// Returns true when the reduction result is itself distributed across lanes
// (distributed type differs from the original), i.e. each lane can reduce
// its own slice locally without cross-lane shuffles. Precondition: op passed
// isValidSubgroupMultiReductionOp.
91static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
93 assert(isValidSubgroupMultiReductionOp(op) &&
"Expecting a valid subgroup "
94 "MultiDimReductionOp");
96 VectorType resTy = dyn_cast<VectorType>(op.getType());
97 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
// If distribution changed the result type, the result lives per-lane.
98 return resTy != resDistTypeOrFailure.value();
// (Header line missing from this chunk — presumably getDistributedDims,
// judging from call sites below; TODO confirm.) Collects the indices of
// dimensions whose size changed between the original and the distributed
// vector type, i.e. the dimensions that were split across lanes.
104 VectorType distributedType) {
105 assert(originalType.getRank() == distributedType.getRank() &&
106 "original and distributed vector types must have the same rank");
108 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
// A size mismatch marks dimension i as distributed.
109 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
110 distributedDims.push_back(i);
112 return distributedDims;
// Pattern: rewrite xegpu.create_nd_tdesc by dropping layout attributes from
// its tensor-descriptor result type; the descriptor itself is lane-invariant.
// NOTE(review): trailing lines (return success(); closing braces) are missing
// from this chunk.
117struct SgToWiCreateNdDesc :
public OpConversionPattern<xegpu::CreateNdDescOp> {
118 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
121 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const override {
123 xegpu::TensorDescType resultType = op.getType();
// Only descriptors that carry a layout need rewriting.
125 if (!resultType.getLayout())
128 auto newOp = xegpu::CreateNdDescOp::create(
129 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
131 rewriter.replaceOp(op, newOp.getResult());
// Pattern: distribute xegpu.load_nd to a single workitem. Computes both the
// hardware-supported per-lane result type and the type expected from the
// lane layout, creates the lane-level load, then casts the result to the
// expected type via castValueTo.
139struct SgToWiLoadNd :
public OpConversionPattern<xegpu::LoadNdOp> {
140 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
143 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter)
const override {
145 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// Tensor-descriptor layout must agree with the anchor layout.
151 if (op.getTensorDescType().getLayout() != layout)
152 return rewriter.notifyMatchFailure(
153 op,
"conflicting layout attributes on tensor descriptor and anchor");
156 return rewriter.notifyMatchFailure(
157 op,
"xegpu::LoadNdOp require target attribute attached to "
158 "determine transpose "
160 auto supportedWiResultTyOrFailure =
162 auto expectedWiResultTyOrFailure =
164 if (failed(supportedWiResultTyOrFailure))
165 return rewriter.notifyMatchFailure(
166 op,
"unable to compute the workitem vector type for LoadNdOp");
167 if (failed(expectedWiResultTyOrFailure))
168 return rewriter.notifyMatchFailure(
170 "unable to compute expected workitem vector type from lane layout");
// Lane-level load: same offsets/hints, distributed result type, no layout.
171 auto newOp = xegpu::LoadNdOp::create(
172 rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
173 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
174 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
175 op.getL3HintAttr(),
nullptr);
// Bridge any mismatch between supported and expected per-lane types.
181 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
182 expectedWiResultTyOrFailure.value()));
// Pattern: distribute xegpu.store_nd. Verifies that the descriptor, the
// stored value, and the anchor all agree on the layout, casts the adapted
// value to the supported per-lane type, and emits a lane-level store.
190struct SgToWiStoreNd :
public OpConversionPattern<xegpu::StoreNdOp> {
191 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
194 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter)
const override {
196 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
202 if (op.getTensorDescType().getLayout() != layout)
203 return rewriter.notifyMatchFailure(
204 op,
"conflicting layout attributes on tensor descriptor and anchor");
206 if (valueLayout != layout)
207 return rewriter.notifyMatchFailure(
208 op,
"conflicting layout attributes on value and anchor");
209 auto supportedWiValueTyOrFailure =
211 if (failed(supportedWiValueTyOrFailure))
212 return rewriter.notifyMatchFailure(
214 "unable to compute wi vector type for StoreNdOp value from tensor "
// Store has no results — create the new op and erase the original.
217 xegpu::StoreNdOp::create(
218 rewriter, op.getLoc(),
220 supportedWiValueTyOrFailure.value()),
221 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
222 op.getL2HintAttr(), op.getL3HintAttr(),
nullptr);
223 rewriter.eraseOp(op);
// Pattern: distribute xegpu.dpas. Requires layout attributes on A, B and
// C/D operands, computes the supported per-lane vector type for each, casts
// the adapted operands accordingly, and casts the result back to the type
// expected from the lane layout.
231struct SgToWiDpas :
public OpConversionPattern<xegpu::DpasOp> {
232 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
235 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): cast<> on a possibly-null attribute followed by a null
// check looks suspicious — presumably dyn_cast was intended; verify in
// the full file.
239 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
240 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
241 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
242 if (!layoutA || !layoutB || !layoutCd)
245 auto wiResultTyOrFailure =
247 auto wiATypeOrFailure =
249 auto wiBTypeOrFailure =
251 auto expectedWiResultTyOrFailure =
253 if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
254 failed(wiBTypeOrFailure))
255 return rewriter.notifyMatchFailure(
256 op,
"failed to calculate supported workitem vector types for DpasOp "
258 if (failed(expectedWiResultTyOrFailure))
259 return rewriter.notifyMatchFailure(
260 op,
"unable to compute expected workitem vector type for DpasOp from "
// Lane-level dpas with each operand cast to its supported lane type.
262 auto newOp = xegpu::DpasOp::create(
263 rewriter, op->getLoc(), wiResultTyOrFailure.value(),
265 wiATypeOrFailure.value()),
267 wiBTypeOrFailure.value()),
269 wiResultTyOrFailure.value()),
273 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
274 expectedWiResultTyOrFailure.value()));
// (Struct header missing from this chunk — a generic pattern, presumably for
// elementwise ops; TODO confirm.) Rewrites an op whose vector result carries
// a subgroup distribute layout: recreates the op via an OperationState with
// the distributed result type and with layout attributes stripped.
287 ConversionPatternRewriter &rewriter)
const override {
294 return rewriter.notifyMatchFailure(
295 op,
"operation result is not a vector type");
297 xegpu::DistributeLayoutAttr layout =
299 if (!layout || !layout.isForSubgroup())
300 return rewriter.notifyMatchFailure(
301 op,
"operation result does not have subgroup distribute layout");
303 auto wiShapeOrFailure =
306 if (failed(wiShapeOrFailure))
307 return rewriter.notifyMatchFailure(
308 op,
"unable to compute workitem vector type from the layout");
310 VectorType newResultType = wiShapeOrFailure.value();
// Drop distribute-layout attributes when cloning the op.
316 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
319 Operation *newOp = rewriter.create(state);
321 rewriter.replaceOp(op, newOp->
getResult(0));
// Pattern: distribute arith.constant. Only dense splat vector constants are
// supported — the splat scalar is simply re-broadcast into the smaller
// per-lane vector type derived from the layout.
328struct SgToWiArithConstant :
public OpConversionPattern<arith::ConstantOp> {
329 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
332 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
333 ConversionPatternRewriter &rewriter)
const override {
334 auto resultType = dyn_cast<VectorType>(op.getType());
339 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
341 return rewriter.notifyMatchFailure(
342 op,
"only dense splat vector constants are supported");
344 xegpu::DistributeLayoutAttr layout =
346 if (!layout || !layout.isForSubgroup())
347 return rewriter.notifyMatchFailure(
348 op,
"operation result does not have subgroup distribute layout");
350 auto wiShapeOrFailure =
353 if (failed(wiShapeOrFailure))
354 return rewriter.notifyMatchFailure(
355 op,
"unable to compute workitem vector type from the layout");
357 VectorType newResultType = wiShapeOrFailure.value();
// Splat value is lane-invariant, so it carries over unchanged.
358 auto sclarValue = dense.getSplatValue<
Attribute>();
361 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
363 rewriter.replaceOp(op, newOp.getResult());
// Pattern: distribute xegpu.prefetch_nd. Prefetch has no vector payload, so
// the op is simply recreated over the converted tensor descriptor with the
// same offsets and cache hints, and the original is erased.
369struct SgToWiPrefetchNd :
public OpConversionPattern<xegpu::PrefetchNdOp> {
370 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
373 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
374 ConversionPatternRewriter &rewriter)
const override {
375 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
380 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
381 op.getMixedOffsets(), op.getL1HintAttr(),
382 op.getL2HintAttr(), op.getL3HintAttr(),
384 rewriter.eraseOp(op);
// Pattern: distribute xegpu.load (gather form). Leading dims beyond the
// effective vector rank (1, or 2 when chunk_size > 1) must be unit. The
// per-lane result, offsets and mask are flattened to 1-D vectors for the
// lane-level gather, then the result is shape-cast back if needed.
422struct SgToWiLoadGather :
public OpConversionPattern<xegpu::LoadGatherOp> {
423 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
426 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
427 ConversionPatternRewriter &rewriter)
const override {
428 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
432 VectorType origResultTy = op.getValueType();
// chunk_size > 1 means each lane loads a contiguous chunk => rank 2.
437 int chunkSize = op.getChunkSize().value_or(1);
438 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
441 shape.take_front(origResultTy.getRank() - effectiveVecRank),
442 [](
int64_t d) { return d != 1; }))
443 return rewriter.notifyMatchFailure(
444 op,
"Only unit dimensions allowed for the leading "
445 "dimensions of the load vector!");
447 auto distResultTyOrFailure =
449 if (failed(distResultTyOrFailure))
450 return rewriter.notifyMatchFailure(
452 "unable to compute expected workitem vector type from lane layout");
// Flatten the distributed result to 1-D for the lane-level load.
454 VectorType distResultTy = distResultTyOrFailure.value();
455 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
456 distResultTy.getElementType());
459 Value distOffsets = adaptor.getOffsets();
460 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
461 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
462 distOffsetsTy.getElementType());
463 distOffsets = castValueTo(
466 Value distMask = adaptor.getMask();
467 auto distMaskTy = cast<VectorType>(distMask.
getType());
468 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
469 distMaskTy.getElementType());
473 Value distSource = adaptor.getSource();
474 auto newOp = xegpu::LoadGatherOp::create(
475 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
476 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
477 op.getL3HintAttr(),
nullptr);
// Restore the multi-dim distributed shape when it differs from 1-D.
480 if (distResultTy1D != distResultTy)
483 rewriter.replaceOp(op,
result);
// Pattern: distribute a rank-1 vector.reduction across the subgroup. The
// source dim must be a multiple of the subgroup size and the result must be
// int/float scalar. Emits a cross-lane reduction helper, then folds in the
// accumulator (if any) with the same reduction kind.
492struct SgToWiVectorReduction :
public OpConversionPattern<vector::ReductionOp> {
493 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
496 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
497 ConversionPatternRewriter &rewriter)
const override {
501 if (!layout || !layout.isForSubgroup())
504 VectorType srcVecType = op.getSourceVectorType();
506 if (srcVecType.getRank() != 1)
507 return rewriter.notifyMatchFailure(
508 op,
"Only rank 1 reductions can be distributed.");
510 if (layout.getRank() != srcVecType.getRank())
511 return rewriter.notifyMatchFailure(
512 op,
"Layout rank does not match vector rank.");
// Subgroup size comes from the lane layout's single dimension.
515 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
518 return rewriter.notifyMatchFailure(
519 op,
"xegpu::ReductionOp require target attribute attached to "
520 "determine subgroup size");
524 srcVecType.getShape()[0] % sgSize != 0)
525 return rewriter.notifyMatchFailure(op,
526 "Invalid layout or reduction vector "
527 "dimension must match subgroup size.");
529 if (!op.getType().isIntOrFloat())
530 return rewriter.notifyMatchFailure(
531 op,
"Reduction distribution currently only supports floats and "
535 Value laneValVec = adaptor.getVector();
// Cross-lane reduction over each lane's partial values.
539 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
542 if (adaptor.getAcc())
544 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
546 rewriter.replaceOp(op, fullReduce);
// Pattern: distribute vector.multi_reduction (single reduction dim, unit
// leading dims for rank > 2). Three cases: (1) scalar result — full
// cross-lane reduction; (2) lane-local reduction — each lane reduces its own
// slice; (3) otherwise — reduction requires cross-lane data movement.
555struct SgToWiMultiDimReduction
556 :
public OpConversionPattern<vector::MultiDimReductionOp> {
557 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
560 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
561 ConversionPatternRewriter &rewriter)
const override {
564 assert(reductionDims.size() == 1 &&
565 "Expecting single reduction dimension for subgroup multi "
568 VectorType sourceType = op.getSourceVectorType();
569 int64_t rank = sourceType.getRank();
// Only the last two dims may be non-unit for higher-rank sources.
572 if (llvm::any_of(
shape.take_front(rank - 2),
573 [](
int64_t d) { return d != 1; }))
574 return rewriter.notifyMatchFailure(
575 op,
"only unit leading dimensions are supported for "
576 "multi_reduction with rank > 2");
// Case 1: scalar result — reduce fully, then fold the accumulator.
580 if (op.getType().isIntOrFloat()) {
581 auto reductionDim = reductionDims[0];
582 VectorType origSourceType = op.getSourceVectorType();
583 int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
587 op.getKind(), reductionDimSize);
589 if (adaptor.getAcc())
591 result, adaptor.getAcc());
592 }
// Case 2: each lane holds a full reduction slice — reduce locally.
else if (isReductionLaneLocal(op)) {
596 auto reductionDim = reductionDims[0];
600 reductionDim, op.getLoc(), rewriter);
// Case 3: reduction dimension spans lanes — cross-lane reduction.
602 auto reductionDim = reductionDims[0];
603 VectorType sourceType = op.getSourceVectorType();
604 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
608 reductionDim, reductionDimSize, op.getLoc(), rewriter);
610 rewriter.replaceOp(op,
result);
// (Header line missing from this chunk — presumably
// computeDistributedCoordsForMatrixOp, judging from call sites below; TODO
// confirm.) Computes this lane's coordinates into the matrix payload from
// gpu.lane_id and the layout; expects exactly one set of offsets.
619 ConversionPatternRewriter &rewriter,
Location loc,
622 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
623 mlir::IntegerAttr());
625 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
626 if (failed(maybeCoords))
628 assert(maybeCoords.value().size() == 1 &&
629 "Expected one set of distributed offsets");
// Convert OpFoldResults to concrete Values for the op builder.
633 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Pattern: distribute xegpu.load_matrix. Requires offsets and a
// distributable vector payload. Without the subgroup_block_io attribute,
// per-lane coordinates are computed explicitly; the new op is emitted with
// the distributed payload type and no layout attribute.
637struct SgToWiLoadMatrix :
public OpConversionPattern<xegpu::LoadMatrixOp> {
638 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
641 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
642 ConversionPatternRewriter &rewriter)
const override {
643 auto layout = op.getLayoutAttr();
648 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
650 return rewriter.notifyMatchFailure(
651 op,
"the matrix op payload must be a vector type");
653 auto loc = op.getLoc();
654 auto offsets = op.getMixedOffsets();
656 return rewriter.notifyMatchFailure(op,
"the load op must have offsets");
658 FailureOr<VectorType> distPayloadTyOrFailure =
659 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
660 if (failed(distPayloadTyOrFailure))
661 return rewriter.notifyMatchFailure(
662 op,
"Failed to distribute matrix op payload based on layout.");
// Block-IO loads are lane-uniform; otherwise offset by lane coords.
668 if (!op.getSubgroupBlockIoAttr()) {
669 newCoords = computeDistributedCoordsForMatrixOp(
670 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
671 if (newCoords.empty())
672 return rewriter.notifyMatchFailure(
673 op,
"Failed to compute distributed coordinates.");
677 ShapedType::kDynamic);
679 rewriter.getDenseI64ArrayAttr(newConstOffsets);
681 auto newOp = xegpu::LoadMatrixOp::create(
682 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
683 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
684 xegpu::DistributeLayoutAttr{});
685 rewriter.replaceOp(op, newOp.getResult());
// Pattern: distribute vector.transpose. Valid only when the result layout is
// the transpose (under the op's permutation) of the source layout, so the
// lane-level op keeps the same permutation; the result is then cast to the
// distributed result type.
691struct SgToWiVectorTranspose :
public OpConversionPattern<vector::TransposeOp> {
692 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
695 matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
696 ConversionPatternRewriter &rewriter)
const override {
697 xegpu::DistributeLayoutAttr sourceLayout =
699 xegpu::DistributeLayoutAttr resultLayout =
701 if (!sourceLayout || !resultLayout)
702 return rewriter.notifyMatchFailure(
703 op,
"the source or result vector of the transpose op lacks layout "
707 if (!resultLayout.isTransposeOf(sourceLayout, perm,
709 return rewriter.notifyMatchFailure(
710 op,
"the source or result vector layouts must be transposes of "
712 FailureOr<VectorType> distributedResultTypeOrFailure =
713 getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
714 if (failed(distributedResultTypeOrFailure))
715 return rewriter.notifyMatchFailure(
716 op,
"Failed to distribute the result vector type in "
717 "vector::Transpose op");
718 auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
719 adaptor.getVector(), perm);
720 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
721 distributedResultTypeOrFailure.value()));
// Pattern: distribute vector.bitcast by recreating it on the converted
// source with the distributed result type derived from the result layout.
728struct SgToWiVectorBitcast :
public OpConversionPattern<vector::BitCastOp> {
729 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
732 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter)
const override {
734 xegpu::DistributeLayoutAttr resultLayout =
737 return rewriter.notifyMatchFailure(
738 op,
"result vector of the bitcast op lacks layout attribute");
739 FailureOr<VectorType> distributedResultTypeOrFailure =
740 getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
741 if (failed(distributedResultTypeOrFailure))
742 return rewriter.notifyMatchFailure(
743 op,
"Failed to distribute the result vector type in "
744 "vector::BitCast op");
745 auto newOp = vector::BitCastOp::create(
746 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
747 adaptor.getSource());
748 rewriter.replaceOp(op, newOp.getResult());
// Pattern (templated over vector.create_mask / vector.constant_mask):
// distribute a mask op. Collects the per-dim mask bounds, computes each
// lane's element coordinates, and materializes the per-lane mask by
// comparing coordinates against the bounds element-by-element.
776template <
typename OpType,
777 typename = std::enable_if_t<llvm::is_one_of<
778 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
779struct SgToWiCreateMask :
public OpConversionPattern<OpType> {
780 using OpConversionPattern<OpType>::OpConversionPattern;
783 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
784 ConversionPatternRewriter &rewriter)
const override {
785 xegpu::DistributeLayoutAttr layout =
787 if (!layout || !layout.isForSubgroup())
788 return rewriter.notifyMatchFailure(
789 op,
"operation result does not have subgroup distribute layout");
791 VectorType origType = op.getType();
792 FailureOr<VectorType> distTypeOrFailure =
793 getDistVecTypeBasedOnLaneLayout(layout, origType);
794 if (failed(distTypeOrFailure))
795 return rewriter.notifyMatchFailure(
796 op,
"unable to compute workitem vector type from the layout");
798 VectorType distType = distTypeOrFailure.value();
// create_mask bounds are SSA operands; constant_mask bounds are a
// static attribute converted to constants.
803 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
804 origBounds.append(op.getOperands().begin(), op.getOperands().end());
806 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
807 for (
auto dimSize : dimSizes)
808 origBounds.push_back(
815 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
816 mlir::IntegerAttr());
817 auto maybeCoordsVec =
818 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
819 if (failed(maybeCoordsVec))
820 return rewriter.notifyMatchFailure(
821 op,
"failed to compute distributed coordinates from layout");
824 int64_t numElements = distType.getNumElements();
825 assert(
static_cast<int64_t>(coordsVec.size()) == numElements &&
826 "number of coordinate sets must match number of distributed "
// An element is set iff every coordinate is strictly below its bound.
833 for (
auto &coords : coordsVec) {
834 Value inBounds = trueVal;
835 for (
size_t i = 0; i < coords.size(); ++i) {
836 Value cmp = arith::CmpIOp::create(
837 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
838 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
840 maskBits.push_back(inBounds);
// Single element: broadcast; otherwise assemble via from_elements.
845 if (numElements == 1) {
847 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
850 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
852 rewriter.replaceOp(op,
result);
// Pattern: distribute xegpu.store_matrix — mirror of SgToWiLoadMatrix.
// Requires offsets and a distributable vector payload; computes per-lane
// coordinates unless subgroup_block_io is set, casts the value to the
// distributed type, and emits the lane-level store (no layout attribute).
858struct SgToWiStoreMatrix :
public OpConversionPattern<xegpu::StoreMatrixOp> {
859 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
862 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
863 ConversionPatternRewriter &rewriter)
const override {
864 auto layout = op.getLayoutAttr();
869 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
871 return rewriter.notifyMatchFailure(
872 op,
"the matrix op payload must be a vector type");
874 auto loc = op.getLoc();
875 auto offsets = op.getMixedOffsets();
877 return rewriter.notifyMatchFailure(op,
"the store op must have offsets");
879 FailureOr<VectorType> distPayloadTyOrFailure =
880 getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
881 if (failed(distPayloadTyOrFailure))
882 return rewriter.notifyMatchFailure(
883 op,
"Failed to distribute matrix op payload based on layout.");
889 if (!op.getSubgroupBlockIoAttr()) {
890 newCoords = computeDistributedCoordsForMatrixOp(
891 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
892 if (newCoords.empty())
893 return rewriter.notifyMatchFailure(
894 op,
"Failed to compute distributed coordinates.");
898 ShapedType::kDynamic);
900 rewriter.getDenseI64ArrayAttr(newConstOffsets);
902 xegpu::StoreMatrixOp::create(
905 distPayloadTyOrFailure.value()),
906 adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr,
907 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
908 rewriter.eraseOp(op);
// Pattern: distribute xegpu.store (scatter form) — mirror of
// SgToWiLoadGather. Leading dims beyond the effective vector rank must be
// unit; value, offsets and mask are flattened to 1-D per-lane vectors and a
// lane-level scatter store is emitted.
947struct SgToWiStoreScatter :
public OpConversionPattern<xegpu::StoreScatterOp> {
948 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
951 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
952 ConversionPatternRewriter &rewriter)
const override {
953 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
957 VectorType origValueTy = op.getValueType();
// chunk_size > 1 => each lane stores a contiguous chunk (rank 2).
962 int chunkSize = op.getChunkSize().value_or(1);
963 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
965 if (llvm::any_of(
shape.take_front(origValueTy.getRank() - effectiveVecRank),
966 [](
int64_t d) { return d != 1; }))
967 return rewriter.notifyMatchFailure(
968 op,
"Only unit dimensions allowed for the leading "
969 "dimensions of the store vector!");
971 auto distValueTyOrFailure =
973 if (failed(distValueTyOrFailure))
974 return rewriter.notifyMatchFailure(
976 "unable to compute expected workitem vector type from lane layout");
// Flatten value, offsets and mask to 1-D lane-level vectors.
978 VectorType distValueTy = distValueTyOrFailure.value();
979 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
980 distValueTy.getElementType());
982 Value distValue = adaptor.getValue();
983 if (distValue.
getType() != distValueTy1D)
988 Value distOffsets = adaptor.getOffsets();
989 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
990 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
991 distOffsetsTy.getElementType());
992 distOffsets = castValueTo(
995 Value distMask = adaptor.getMask();
996 auto distMaskTy = cast<VectorType>(distMask.
getType());
997 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
998 distMaskTy.getElementType());
1002 Value distDest = adaptor.getDest();
1003 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1004 distOffsets, distMask, op.getChunkSizeAttr(),
1005 op.getL1HintAttr(), op.getL2HintAttr(),
1006 op.getL3HintAttr(),
nullptr);
1007 rewriter.eraseOp(op);
// Pattern: distribute vector.step. Each lane's portion of the step sequence
// is reconstructed from the layout: for every lane-data block, the block's
// start coordinate is taken as the first element and consecutive elements
// are generated by adding 1..laneDataBlockLength-1.
1016struct SgToWiVectorStep :
public OpConversionPattern<vector::StepOp> {
1017 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1020 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter)
const override {
1022 xegpu::DistributeLayoutAttr resultLayout =
1024 if (!resultLayout || !resultLayout.isForSubgroup())
1025 return rewriter.notifyMatchFailure(
1026 op,
"the result vector of the step op lacks subgroup layout");
1028 auto loc = op.getLoc();
1029 auto stepResultVecTy = op.getResult().getType();
1030 auto wiShapeOrFailure =
1032 if (failed(wiShapeOrFailure))
1033 return rewriter.notifyMatchFailure(
1034 op,
"unable to compute workitem vector type from the layout");
1035 VectorType newVecTy = wiShapeOrFailure.value();
1037 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1038 mlir::IntegerAttr());
1039 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1040 rewriter, loc, laneId, stepResultVecTy.getShape());
1041 if (failed(laneDataBlockCoords))
1042 return rewriter.notifyMatchFailure(
1043 op,
"failed to compute lane data block coordinates");
1045 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
1046 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
// One coordinate set per lane-data block of the distributed vector.
1047 assert(
static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
1048 newVecTy.getNumElements() / laneDataBlockLength);
1057 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1058 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1059 stepVals.push_back(laneDataBlockStartCoord);
// Remaining elements of the block are start + i.
1060 for (
int i = 1; i < laneDataBlockLength; ++i) {
1062 stepVals.push_back(arith::AddIOp::create(
1063 rewriter, loc, laneDataBlockStartCoord, offset));
1066 assert(
static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
1067 "Expecting the number of step values to match the number of "
1068 "elements in the vector");
1070 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1071 rewriter.replaceOp(op, stepOpVal);
// Pattern: distribute vector.extract with a vector result. Supported only
// when no leading (outer) dimension is lane-distributed, so the same static
// positions remain valid on the distributed source.
1078struct SgToWiVectorExtract :
public OpConversionPattern<vector::ExtractOp> {
1079 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1082 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1083 ConversionPatternRewriter &rewriter)
const override {
1085 auto resultType = dyn_cast<VectorType>(op.getType());
1087 return rewriter.notifyMatchFailure(op,
"scalar extract not supported");
1089 xegpu::DistributeLayoutAttr layout =
1091 if (!layout || !layout.isForSubgroup())
// Only the innermost dimension may carry a non-unit lane layout.
1096 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1098 [](
int64_t v) {
return v != 1; }))
1099 return rewriter.notifyMatchFailure(
1100 op,
"only innermost dimension distribution is supported for "
1103 auto newOp = vector::ExtractOp::create(
1104 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1105 rewriter.replaceOp(op, newOp.getResult());
// Pattern: distribute vector.shape_cast by recreating it on the converted
// source with the distributed result type derived from the result layout.
1111struct SgToWiVectorShapeCast :
public OpConversionPattern<vector::ShapeCastOp> {
1112 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1115 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1116 ConversionPatternRewriter &rewriter)
const override {
1117 xegpu::DistributeLayoutAttr resultLayout =
1119 if (!resultLayout || !resultLayout.isForSubgroup())
1120 return rewriter.notifyMatchFailure(
1121 op,
"the result vector of the shape_cast op lacks subgroup layout");
1124 resultLayout, op.getResultVectorType());
1125 if (failed(resultDistTypeOrFailure))
1126 return rewriter.notifyMatchFailure(
1127 op,
"failed to get distributed vector type for result");
1129 Value source = adaptor.getSource();
1130 auto newShapeCast = vector::ShapeCastOp::create(
1131 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1132 rewriter.replaceOp(op, newShapeCast);
// Pattern: distribute vector.extract_strided_slice. Pads sizes/offsets/
// strides to full source rank, then — if exactly one dimension is
// distributed — rescales that dimension's size and offset by the subgroup
// size (offset and source size must be multiples of it, lane data must be
// unit along it) so the slice applies to the per-lane vector.
1140struct SgToWiVectorExtractStridedSlice
1141 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
1142 using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
1145 matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
1146 ConversionPatternRewriter &rewriter)
const override {
1147 xegpu::DistributeLayoutAttr resultLayout =
1149 if (!resultLayout || !resultLayout.isForSubgroup())
1152 VectorType resultType = op.getType();
1153 auto distResultTyOrFailure =
1155 if (failed(distResultTyOrFailure))
1156 return rewriter.notifyMatchFailure(
1157 op,
"unable to compute distributed vector type from lane layout");
1158 VectorType distResultTy = *distResultTyOrFailure;
// Dims whose size changed under distribution.
1161 getDistributedDims(resultType, distResultTy);
// Extend the slice attributes to cover all source dims (identity on
// trailing dims).
1164 int64_t sourceRank = op.getSourceVectorType().getRank();
1166 llvm::map_to_vector(op.getSizes(), [](
Attribute attr) { return attr; });
1168 op.getOffsets(), [](
Attribute attr) { return attr; });
1170 op.getStrides(), [](
Attribute attr) { return attr; });
1171 for (
int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
1172 updatedSizes.push_back(
1173 rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
1174 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1175 updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
1180 if (!distributedDims.empty()) {
1181 if (distributedDims.size() != 1)
1182 return rewriter.notifyMatchFailure(
1183 op,
"only single dimension distribution is supported");
1184 int64_t distDim = distributedDims[0];
1187 return rewriter.notifyMatchFailure(
1188 op,
"target attribute required to determine subgroup size");
1191 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1192 return rewriter.notifyMatchFailure(
1193 op,
"source of extract_strided_slice lacks distribution layout");
1194 int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
1195 if (sourceDistrDimSize % subgroupSize != 0)
1196 return rewriter.notifyMatchFailure(
1197 op,
"source size along distributed dim is not a multiple of "
1199 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1202 if (distDim <
static_cast<int64_t>(sourceLaneData.size()) &&
1203 sourceLaneData[distDim] != 1)
1204 return rewriter.notifyMatchFailure(
1205 op,
"expecting unit lane data along the distributed dimension");
1207 cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
1208 if (distrDimOffset % subgroupSize != 0)
1209 return rewriter.notifyMatchFailure(
1210 op,
"offset along distributed dim is not a multiple of "
// Rescale size and offset into per-lane coordinates.
1213 updatedSizes[distDim] =
1214 rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
1215 updatedOffsets[distDim] =
1216 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1219 auto newOp = vector::ExtractStridedSliceOp::create(
1220 rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
1221 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1222 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1223 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1224 rewriter.replaceOp(op, newOp.getResult());
1286struct SgToWiBroadcast :
public OpConversionPattern<vector::BroadcastOp> {
1287 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
1290 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
1291 ConversionPatternRewriter &rewriter)
const override {
1292 xegpu::DistributeLayoutAttr resultLayout =
1294 if (!resultLayout || !resultLayout.isForSubgroup())
1295 return rewriter.notifyMatchFailure(
1296 op,
"result does not have subgroup distribute layout");
1298 VectorType destType = op.getResultVectorType();
1299 VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
1301 xegpu::DistributeLayoutAttr sourceLayout =
1305 int64_t rankDiff = destType.getRank() - sourceType.getRank();
1308 if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
1310 "broadcast source layout must be a slice of result layout");
1311 }
else if (rankDiff == 0) {
1313 auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
1315 broadcastUnitDimsSet.end());
1316 assert(sourceLayout.isEqualTo(
1317 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1318 "The sg_data for unit dimensions should be set as 1");
1319 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1324 return rewriter.notifyMatchFailure(
1325 op,
"broadcast from scalar must not have a layout attribute");
1330 if (failed(destDistType))
1331 return rewriter.notifyMatchFailure(
1332 op,
"failed to distribute the result vector type");
1334 Value source = adaptor.getSource();
1336 if (source.
getType() == destDistType.value()) {
1337 rewriter.replaceOp(op, source);
1341 auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
1342 destDistType.value(), source);
1343 rewriter.replaceOp(op, newOp);
// Conversion pattern distributing vector.insert_strided_slice across the
// workitems of a subgroup. Supports exactly one distributed dimension; along
// that dimension it requires unit lane data, a source size divisible by the
// subgroup size, and an offset divisible by the subgroup size. The offset
// along the distributed dim is rescaled by 1/subgroupSize and the op is
// rebuilt with the distributed destination type.
// NOTE(review): extracted listing with missing lines; do not edit without the
// full source.
1351struct SgToWiVectorInsertStridedSlice
1352 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
1353 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1356 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1357 ConversionPatternRewriter &rewriter)
const override {
1358 xegpu::DistributeLayoutAttr resultLayout =
1360 if (!resultLayout || !resultLayout.isForSubgroup())
1363 VectorType destType = op.getDestVectorType();
1364 auto distDestTyOrFailure =
1366 if (failed(distDestTyOrFailure))
1367 return rewriter.notifyMatchFailure(
1368 op,
"unable to compute distributed vector type from lane layout");
1369 VectorType distDestTy = *distDestTyOrFailure;
// Dimensions whose size changed under distribution (see getDistributedDims).
1372 getDistributedDims(destType, distDestTy);
1375 op.getOffsets(), [](
Attribute attr) { return attr; });
1377 if (!destDistributedDims.empty()) {
1378 if (destDistributedDims.size() != 1)
1379 return rewriter.notifyMatchFailure(
1380 op,
"only single dimension distribution is supported");
1381 int64_t destDistDim = destDistributedDims[0];
1385 return rewriter.notifyMatchFailure(
1386 op,
"target attribute required to determine subgroup size");
1389 VectorType srcType = op.getSourceVectorType();
// Map the dest's distributed dim to the right-aligned source dim; negative
// means the distributed dim lies outside the source's trailing dims.
1392 destDistDim - (destType.getRank() - srcType.getRank());
1393 if (sourceDistDim < 0)
1394 return rewriter.notifyMatchFailure(
1395 op,
"distributed dimension must be in the last k dims of dest");
1399 if (!destLayout || !sourceLayout ||
1400 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1401 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1402 return rewriter.notifyMatchFailure(
1403 op,
"source or dest of insert_strided_slice lacks distribution "
1406 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1407 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
// Per-lane data along the distributed dim must be 1 so the slice boundary
// stays lane-aligned.
1410 if ((destDistDim <
static_cast<int64_t>(destLaneData.size()) &&
1411 destLaneData[destDistDim] != 1) ||
1412 (sourceDistDim <
static_cast<int64_t>(sourceLaneData.size()) &&
1413 sourceLaneData[sourceDistDim] != 1))
1414 return rewriter.notifyMatchFailure(
1415 op,
"expecting unit lane data along the distributed dimension");
1417 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1418 if (srcDistrDimSize % subgroupSize != 0)
1419 return rewriter.notifyMatchFailure(
1420 op,
"source distributed dim size is not a multiple of "
1424 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1425 if (destDistrDimOffset % subgroupSize != 0)
1426 return rewriter.notifyMatchFailure(
1427 op,
"offset along distributed dim is not a multiple of "
// Rescale the offset into the per-lane index space.
1430 updatedOffsets[destDistDim] =
1431 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1434 auto newOp = vector::InsertStridedSliceOp::create(
1435 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1437 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1438 rewriter.replaceOp(op, newOp.getResult());
// Conversion pattern for vector.insert with a vector-valued payload. Only
// layouts whose effective lane layout is unit in every dimension except
// (presumably) the innermost are accepted — TODO confirm against the elided
// range expression on the original line 1465. The op is recreated with the
// adapted (distributed) value/dest operands at the same mixed position.
// NOTE(review): extracted listing with missing lines; non-compilable reference.
1445struct SgToWiVectorInsert :
public OpConversionPattern<vector::InsertOp> {
1446 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1449 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1450 ConversionPatternRewriter &rewriter)
const override {
// Scalar inserts are rejected: only vector payloads are distributed here.
1452 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1454 return rewriter.notifyMatchFailure(op,
"scalar insert not supported");
1456 xegpu::DistributeLayoutAttr layout =
1458 if (!layout || !layout.isForSubgroup())
1463 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1465 [](
int64_t v) {
return v != 1; }))
1466 return rewriter.notifyMatchFailure(
1467 op,
"only innermost dimension distribution is supported for "
1470 auto newOp = vector::InsertOp::create(
1471 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1472 op.getMixedPosition());
1473 rewriter.replaceOp(op, newOp.getResult());
// Conversion pattern for xegpu.convert_layout. When input and target layouts
// are compatible for the result shape, the conversion is a per-lane no-op and
// the op is folded away by forwarding its (adapted) source; incompatible
// layout pairs are not yet lowered and produce a match failure.
// NOTE(review): extracted listing with missing lines; non-compilable reference.
1479struct SgToWiConvertLayout
1480 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
1481 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1484 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1485 ConversionPatternRewriter &rewriter)
const override {
1486 auto inputLayout = op.getInputLayoutAttr();
1487 auto targetLayout = op.getTargetLayoutAttr();
1488 Type valType = op.getResult().getType();
// Non-vector (or otherwise trivial) case: replace with the original source.
1491 rewriter.replaceOp(op, op.getSource());
1495 auto resShape = cast<VectorType>(valType).getShape();
1497 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1499 return rewriter.notifyMatchFailure(
1500 op,
"lowering incompatible convert_layout not yet supported");
// Compatible layouts: forward the type-converted source value.
1503 rewriter.replaceOp(op, adaptor.getSource());
// Pass wrapper generated from Passes.td (see GEN_PASS_DEF_* in the header
// section): drives the subgroup-to-workitem distribution via runOnOperation,
// defined out-of-line below.
1508struct XeGPUSgToWiDistributeExperimentalPass
1509 :
public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
1510 XeGPUSgToWiDistributeExperimentalPass> {
1511 void runOnOperation()
override;
// Pass driver. Phases visible in this (gappy) extraction: record pre-existing
// unrealized casts, set up cast materializations and the type converter, run
// a partial dialect conversion, fold residual back-to-back casts into
// vector.shape_cast where element counts match, then register type
// conversions and dynamic-legality rules and the SgToWi* patterns.
// NOTE(review): extracted listing — many lines are missing (fused line
// numbers jump); non-compilable reference only.
1516void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
1519 Operation *root = getOperation();
1521 signalPassFailure();
// Remember casts that existed before the conversion so the cleanup walks
// below do not touch them.
1526 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1528 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
// Materialize all source/target conversions as unrealized casts; they are
// resolved or folded after the conversion runs.
1532 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1533 mlir::ValueRange inputs,
1534 mlir::Location loc) -> mlir::Value {
1535 UnrealizedConversionCastOp castOp =
1536 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1537 return castOp.getResult(0);
1541 TypeConverter typeConverter;
1543 typeConverter.addSourceMaterialization(materializeCast);
1544 typeConverter.addTargetMaterialization(materializeCast);
1549 typeConverter, patterns,
target);
1550 target.addLegalOp<UnrealizedConversionCastOp>();
// Partial conversion: failure is tolerated here (result is discarded).
1551 (void)applyPartialConversion(root,
target, std::move(patterns));
// Cleanup: collapse cast(cast(x)) chains into a single vector.shape_cast
// when the innermost input and the outer result have equal element counts.
1562 OpBuilder builder(root);
1563 root->
walk([&](UnrealizedConversionCastOp op) {
1565 if (existingCasts.contains(op))
1568 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1571 auto singleInput = op.getInputs()[0];
1572 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1573 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1574 if (!inputTy || !outputTy)
1580 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1581 if (!definingOp || !definingOp->hasOneUse())
1583 auto inputOfDefiningOp = definingOp.getInputs()[0];
1586 auto inputOfDefiningOpTy =
1587 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1588 if (inputOfDefiningOpTy &&
1589 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1591 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1592 outputTy, inputOfDefiningOp);
1593 op.replaceAllUsesWith(
ValueRange{shapeCast.getResult()});
1599 bool changed =
true;
// Second walk: erase (presumably) now-unused casts — TODO confirm against
// the elided body on original lines 1606+.
1602 root->
walk([&](UnrealizedConversionCastOp op) {
1604 if (existingCasts.contains(op))
1606 if (op.use_empty()) {
// Type conversions: non-TensorDesc/Vector types pass through unchanged;
// TensorDescType drops its layout attribute; vector values are distributed
// per their subgroup lane layout.
1619 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
1620 if (!isa<TensorDescType, VectorType>(type))
1622 return std::nullopt;
1625 typeConverter.addConversion([](TensorDescType type) ->
Type {
1626 if (type.getLayoutAttr()) {
1627 return type.dropLayouts();
1633 typeConverter.addConversion([](
Value v) -> std::optional<Type> {
1636 if (!isa<VectorType>(type))
1637 return std::nullopt;
1639 if (!layout || !layout.isForSubgroup())
1642 auto newTyOrFailure =
1644 if (failed(newTyOrFailure))
1646 return *newTyOrFailure;
// Dynamic legality: ops are legal once their layout attributes are gone or
// they do not carry a subgroup distribute layout.
1655 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1656 [&](xegpu::CreateNdDescOp op) {
return !op.getType().getLayoutAttr(); });
1658 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](
Operation *op) {
1659 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1662 return !anchorOp.getAnchorLayout();
1665 target.addDynamicallyLegalOp<arith::ConstantOp>(
1666 [=](arith::ConstantOp op) ->
bool {
1668 if (!isa<VectorType>(op.getResult().getType()))
// Elementwise math/arith ops: only single-result ops whose operand shapes
// all match the result shape are candidates for distribution.
1674 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1675 [=](
Operation *op) -> std::optional<bool> {
1680 if (op->getNumResults() != 1)
1683 VectorType resultType =
1684 dyn_cast<VectorType>(op->getResult(0).getType());
1689 for (
Value operand : op->getOperands()) {
1690 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1691 if (!operandType || operandType.getShape() != resultType.getShape()) {
1699 target.addDynamicallyLegalOp<vector::ReductionOp>(
1700 [=](vector::ReductionOp op) ->
bool {
1705 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1706 [=](vector::MultiDimReductionOp op) ->
bool {
1707 return !isValidSubgroupMultiReductionOp(op);
1709 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1710 vector::TransposeOp, vector::BitCastOp,
1711 vector::ShapeCastOp, vector::StepOp,
1712 vector::BroadcastOp>([=](
Operation *op) ->
bool {
1715 target.addDynamicallyLegalOp<vector::ExtractOp>(
1716 [=](vector::ExtractOp op) ->
bool {
1717 if (!isa<VectorType>(op.getType()))
1721 target.addDynamicallyLegalOp<vector::InsertOp>(
1722 [=](vector::InsertOp op) ->
bool {
1725 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1726 [=](vector::ExtractStridedSliceOp op) ->
bool {
1729 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1730 [=](vector::InsertStridedSliceOp op) ->
bool {
// Everything not covered above is left alone by this pass.
1733 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
// Register all subgroup-to-workitem distribution patterns.
1734 patterns.
add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1735 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1736 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1737 SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
1738 SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
1739 SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
1740 SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
1741 SgToWiVectorShapeCast, SgToWiBroadcast,
1742 SgToWiCreateMask<vector::CreateMaskOp>,
1743 SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to workitem distribution.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given a src and an acc argument from a vector::MultiDimReductionOp, lower to a set of vector::ReductionOps.
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to workitem distribution and appends the req...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
virtual int getSubgroupSize() const =0