29#include "llvm/ADT/SetVector.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
36#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
37#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
43#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
44#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Coerces `v` to `expectedTy` during dialect conversion (source fragment;
// interior lines are elided here). Visible behavior: identity if the types
// already match; a vector.shape_cast when both are vectors with the same
// element count (shape-only mismatch); otherwise an
// unrealized_conversion_cast placeholder to be reconciled later.
49static Value castValueTo(ConversionPatternRewriter &rewriter,
// Fast path: nothing to do when the type already matches.
 52 if (v.getType() == expectedTy)
 55 if (isa<VectorType>(v.getType()) &&
 56 v.getType().getNumElements() == expectedTy.getNumElements())
// Same number of elements: a shape cast is sufficient (no data movement).
 57 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
// Fallback: materialize a temporary cast; result 0 carries the value.
 60 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
 62 return newOp.getResult(0);
// Walks the region under `root` and verifies layout annotations before
// distribution (source fragment; the walk callback's surrounding lines are
// elided). Visible checks: ops implementing AnchorLayoutInterface must carry
// an anchor layout, and vector-typed results must carry a result layout.
// Returns failure() iff the walk was interrupted by a diagnostic.
66static LogicalResult verifyLayouts(
Operation *root) {
68 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
69 auto layout = anchorOp.getAnchorLayout();
// Missing anchor layout is a hard error on the op itself.
71 nestedOp->
emitError(
"expected anchor layout attribute on operation");
// Only vector results are required to carry a layout attribute.
79 if (isa<VectorType>(
result.getType())) {
83 "expected result layout attribute on vector result");
// Interrupted walk => at least one diagnostic was emitted.
90 return walkResult.wasInterrupted() ? failure() :
success();
// Returns true when a vector.multi_reduction is eligible for subgroup->WI
// distribution (source fragment; `resLayout` initialization is elided).
// Visible requirements: the result layout exists and is a subgroup layout,
// the result distributes cleanly per the lane layout, and there is exactly
// one reduction dimension.
96static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
99 if (!resLayout || !resLayout.isForSubgroup())
101 VectorType resTy = dyn_cast<VectorType>(op.getType());
// The result type must be distributable under the lane layout.
105 FailureOr<VectorType> resDistTypeOrFailure =
106 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
107 if (failed(resDistTypeOrFailure))
// Only single-dim reductions are supported by this distribution.
109 return op.getReductionDims().size() == 1;
// Decides whether the reduction can be performed lane-locally (source
// fragment; `resLayout` is defined on an elided line). A result type that
// differs from its distributed type means each lane holds a strict slice,
// so the reduction along the distributed dim stays local to the lane.
// Precondition: op already passed isValidSubgroupMultiReductionOp.
116static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
118 assert(isValidSubgroupMultiReductionOp(op) &&
"Expecting a valid subgroup "
119 "MultiDimReductionOp");
121 VectorType resTy = dyn_cast<VectorType>(op.getType());
// value() is safe here: validity was asserted above.
122 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
123 return resTy != resDistTypeOrFailure.value();
// Continuation of a helper whose header is elided (presumably
// getDistributedDims, per the call at the extract_strided_slice pattern —
// TODO confirm). Collects the indices of dimensions whose size changed
// between the original and the distributed vector type; both must have
// equal rank.
VectorType distributedType) {
130 assert(originalType.getRank() == distributedType.getRank() &&
131 "original and distributed vector types must have the same rank");
// A dim is "distributed" iff its size shrank under distribution.
133 for (
int64_t i = 0; i < originalType.getRank(); ++i) {
134 if (distributedType.getDimSize(i) != originalType.getDimSize(i))
135 distributedDims.push_back(i);
137 return distributedDims;
// Distributes xegpu.create_nd_tdesc: rebuilds the descriptor with its
// layout dropped, since per-workitem code no longer needs the subgroup
// layout annotation (source fragment; interior lines are elided).
142struct SgToWiCreateNdDesc :
public OpConversionPattern<xegpu::CreateNdDescOp> {
143 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
146 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const override {
148 xegpu::TensorDescType resultType = op.getType();
// Nothing to distribute if the descriptor carries no layout.
150 if (!resultType.getLayout())
// Same op, same operands — only the result type loses its layout.
153 auto newOp = xegpu::CreateNdDescOp::create(
154 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
156 rewriter.replaceOp(op, newOp.getResult());
// Distributes xegpu.load_nd to its per-workitem form (source fragment;
// interior lines are elided). The anchor layout must match the tensor
// descriptor's layout; the loaded vector is shrunk to the hardware-supported
// WI type, then cast to the layout-expected WI type via castValueTo.
164struct SgToWiLoadNd :
public OpConversionPattern<xegpu::LoadNdOp> {
165 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
168 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const override {
170 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// The descriptor and the anchor must agree on the layout.
176 if (op.getTensorDescType().getLayout() != layout)
177 return rewriter.notifyMatchFailure(
178 op,
"conflicting layout attributes on tensor descriptor and anchor");
// A target (uArch) attribute is needed to decide packing/transpose.
181 return rewriter.notifyMatchFailure(
182 op,
"xegpu::LoadNdOp require target attribute attached to "
183 "determine transpose "
// Supported type: what the HW load produces. Expected type: what the
// lane layout says downstream ops consume. They may differ in shape.
185 auto supportedWiResultTyOrFailure =
187 auto expectedWiResultTyOrFailure =
189 if (failed(supportedWiResultTyOrFailure))
190 return rewriter.notifyMatchFailure(
191 op,
"unable to compute the workitem vector type for LoadNdOp");
192 if (failed(expectedWiResultTyOrFailure))
193 return rewriter.notifyMatchFailure(
195 "unable to compute expected workitem vector type from lane layout");
// Recreate the load with the supported WI result type; hints preserved,
// trailing layout attr dropped (nullptr) for the WI-level op.
196 auto newOp = xegpu::LoadNdOp::create(
197 rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
198 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
199 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
200 op.getL3HintAttr(),
nullptr);
// Bridge supported -> expected shape for users of the result.
206 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
207 expectedWiResultTyOrFailure.value()));
// Distributes xegpu.store_nd (source fragment; interior lines are elided).
// Verifies the anchor layout agrees with both the tensor descriptor and the
// stored value's layout, computes the WI value type, then emits the WI-level
// store and erases the original (stores produce no results).
215struct SgToWiStoreNd :
public OpConversionPattern<xegpu::StoreNdOp> {
216 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
219 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter)
const override {
221 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
// Layout consistency checks: descriptor vs anchor, value vs anchor.
227 if (op.getTensorDescType().getLayout() != layout)
228 return rewriter.notifyMatchFailure(
229 op,
"conflicting layout attributes on tensor descriptor and anchor");
231 if (valueLayout != layout)
232 return rewriter.notifyMatchFailure(
233 op,
"conflicting layout attributes on value and anchor");
234 auto supportedWiValueTyOrFailure =
236 if (failed(supportedWiValueTyOrFailure))
237 return rewriter.notifyMatchFailure(
239 "unable to compute wi vector type for StoreNdOp value from tensor "
// New store takes the (possibly shape-cast) WI value; hints preserved,
// trailing layout attr dropped (nullptr) at WI level.
242 xegpu::StoreNdOp::create(
243 rewriter, op.getLoc(),
245 supportedWiValueTyOrFailure.value()),
246 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
247 op.getL2HintAttr(), op.getL3HintAttr(),
nullptr);
248 rewriter.eraseOp(op);
// Distributes xegpu.dpas (source fragment; interior lines are elided).
// Computes supported WI types for A, B and the result, recreates the dpas
// with cast operands, and bridges the result to the expected WI type.
256struct SgToWiDpas :
public OpConversionPattern<xegpu::DpasOp> {
257 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
260 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): cast<> asserts on null in LLVM, yet the null-check below
// suggests these attrs may be absent — dyn_cast would be the usual
// pairing; confirm against the full source.
264 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
265 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
266 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
267 if (!layoutA || !layoutB || !layoutCd)
// Supported WI types for all three operands plus the layout-expected
// result type (used for the final bridge cast).
270 auto wiResultTyOrFailure =
272 auto wiATypeOrFailure =
274 auto wiBTypeOrFailure =
276 auto expectedWiResultTyOrFailure =
278 if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
279 failed(wiBTypeOrFailure))
280 return rewriter.notifyMatchFailure(
281 op,
"failed to calculate supported workitem vector types for DpasOp "
283 if (failed(expectedWiResultTyOrFailure))
284 return rewriter.notifyMatchFailure(
285 op,
"unable to compute expected workitem vector type for DpasOp from "
// Recreate dpas with operands cast to their supported WI shapes.
287 auto newOp = xegpu::DpasOp::create(
288 rewriter, op->getLoc(), wiResultTyOrFailure.value(),
290 wiATypeOrFailure.value()),
292 wiBTypeOrFailure.value()),
294 wiResultTyOrFailure.value()),
// Bridge the supported result shape to the layout-expected one.
298 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
299 expectedWiResultTyOrFailure.value()));
// Generic elementwise distribution pattern matching ANY op (source fragment;
// interior lines are elided). For ops with a subgroup-laid-out vector
// result, it clones the op with the distributed result type, copying all
// attributes except the distribute-layout ones (now obsolete at WI level).
306struct SgToWiElementWise :
public ConversionPattern {
// MatchAnyOpTypeTag: this pattern is offered every op; benefit 1.
308 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
312 ConversionPatternRewriter &rewriter)
const override {
319 return rewriter.notifyMatchFailure(
320 op,
"operation result is not a vector type");
322 xegpu::DistributeLayoutAttr layout =
324 if (!layout || !layout.isForSubgroup())
325 return rewriter.notifyMatchFailure(
326 op,
"operation result does not have subgroup distribute layout");
328 auto wiShapeOrFailure =
331 if (failed(wiShapeOrFailure))
332 return rewriter.notifyMatchFailure(
333 op,
"unable to compute workitem vector type from the layout");
335 VectorType newResultType = wiShapeOrFailure.value();
// Rebuild the op generically via OperationState: same operands, new
// (distributed) result type.
337 state.addOperands(operands);
338 state.addTypes(newResultType);
// Strip layout attrs; keep everything else verbatim.
341 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
342 state.addAttribute(attr.getName(), attr.getValue());
344 Operation *newOp = rewriter.create(state);
// Distributes arith.constant vectors (source fragment; interior lines are
// elided). Only dense splats are handled: the splat scalar is re-broadcast
// into a constant of the smaller, distributed vector type.
353struct SgToWiArithConstant :
public OpConversionPattern<arith::ConstantOp> {
354 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
357 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
359 auto resultType = dyn_cast<VectorType>(op.getType());
// Non-splat constants would need per-lane element selection; bail out.
364 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
366 return rewriter.notifyMatchFailure(
367 op,
"only dense splat vector constants are supported");
369 xegpu::DistributeLayoutAttr layout =
371 if (!layout || !layout.isForSubgroup())
372 return rewriter.notifyMatchFailure(
373 op,
"operation result does not have subgroup distribute layout");
375 auto wiShapeOrFailure =
378 if (failed(wiShapeOrFailure))
379 return rewriter.notifyMatchFailure(
380 op,
"unable to compute workitem vector type from the layout");
382 VectorType newResultType = wiShapeOrFailure.value();
// Splat value is layout-independent, so it survives distribution as-is.
383 auto sclarValue = dense.getSplatValue<
Attribute>();
386 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
388 rewriter.replaceOp(op, newOp.
getResult());
// Distributes xegpu.prefetch_nd (source fragment; interior lines are
// elided). Prefetch has no vector payload, so it is simply recreated on the
// converted tensor descriptor with its cache hints preserved.
394struct SgToWiPrefetchNd :
public OpConversionPattern<xegpu::PrefetchNdOp> {
395 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
398 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter)
const override {
400 xegpu::DistributeLayoutOp op.getAnchorLayout() is read on the next line.
// Distributes xegpu.load (gather form) (source fragment; interior lines are
// elided). Leading dims of the loaded vector must be unit-sized; the
// distributed result, offsets and mask are all flattened to 1-D vectors
// before recreating the op, then the result is cast back if needed.
447struct SgToWiLoadGather :
public OpConversionPattern<xegpu::LoadGatherOp> {
448 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
451 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
452 ConversionPatternRewriter &rewriter)
const override {
453 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
457 VectorType origResultTy = op.getValueType();
// chunkSize > 1 loads a 2-D (lanes x chunk) tile, else a 1-D row.
462 int chunkSize = op.getChunkSize().value_or(1);
463 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
464 ArrayRef<int64_t> shape = origResultTy.getShape();
// All dims in front of the effective payload must be 1.
466 shape.take_front(origResultTy.getRank() - effectiveVecRank),
467 [](int64_t d) { return d != 1; }))
468 return rewriter.notifyMatchFailure(
469 op,
"Only unit dimensions allowed for the leading "
470 "dimensions of the load vector!");
472 auto distResultTyOrFailure =
474 if (
failed(distResultTyOrFailure))
475 return rewriter.notifyMatchFailure(
477 "unable to compute expected workitem vector type from lane layout");
// Flatten the distributed result to 1-D for the HW-level gather.
479 VectorType distResultTy = distResultTyOrFailure.value();
480 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
481 distResultTy.getElementType());
// Offsets and mask are flattened the same way.
484 Value distOffsets = adaptor.getOffsets();
485 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
486 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
487 distOffsetsTy.getElementType());
488 distOffsets = castValueTo(
491 Value distMask = adaptor.getMask();
492 auto distMaskTy = cast<VectorType>(distMask.
getType());
493 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
494 distMaskTy.getElementType());
498 Value distSource = adaptor.getSource();
// WI-level gather: hints preserved, trailing layout attr dropped.
499 auto newOp = xegpu::LoadGatherOp::create(
500 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
501 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
502 op.getL3HintAttr(),
nullptr);
// Restore the multi-dim distributed shape if flattening changed it.
505 if (distResultTy1D != distResultTy)
508 rewriter.replaceOp(op,
result);
// Distributes a rank-1 vector.reduction across the subgroup (source
// fragment; interior lines are elided). Each lane reduces its slice, then a
// subgroup-wide combine (elided helper) produces the full scalar; an
// accumulator, if present, is folded in afterwards.
517struct SgToWiVectorReduction :
public OpConversionPattern<vector::ReductionOp> {
518 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
521 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
522 ConversionPatternRewriter &rewriter)
const override {
526 if (!layout || !layout.isForSubgroup())
529 VectorType srcVecType = op.getSourceVectorType();
// Only rank-1 sources: the single dim is the one being distributed.
531 if (srcVecType.getRank() != 1)
532 return rewriter.notifyMatchFailure(
533 op,
"Only rank 1 reductions can be distributed.");
535 if (layout.getRank() != srcVecType.getRank())
536 return rewriter.notifyMatchFailure(
537 op,
"Layout rank does not match vector rank.");
// Subgroup size comes from the lane layout's single dimension.
540 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
543 return rewriter.notifyMatchFailure(
544 op,
"xegpu::ReductionOp require target attribute attached to "
545 "determine subgroup size");
// The reduced dim must split evenly across lanes.
549 srcVecType.getShape()[0] % sgSize != 0)
550 return rewriter.notifyMatchFailure(op,
551 "Invalid layout or reduction vector "
552 "dimension must match subgroup size.")
554 if (!op.getType().isIntOrFloat())
555 return rewriter.notifyMatchFailure(
556 op,
"Reduction distribution currently only supports floats and "
// Per-lane partial reduce followed by a cross-lane combine (helper
// call elided between these lines).
560 Value laneValVec = adaptor.getVector();
564 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
// Fold the optional accumulator into the subgroup-wide result.
567 if (adaptor.getAcc())
569 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
571 rewriter.replaceOp(op, fullReduce);
// Distributes vector.multi_reduction with a single reduction dim (source
// fragment; interior lines are elided). Lane-local case: the reduction dim
// is not distributed, so each lane reduces independently. Otherwise a
// cross-lane path (elided helper taking the reduction dim/size) is used.
580struct SgToWiMultiDimReduction
581 :
public OpConversionPattern<vector::MultiDimReductionOp> {
582 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
585 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter)
const override {
588 ArrayRef<int64_t> reductionDims = op.getReductionDims();
589 assert(reductionDims.size() == 1 &&
590 "Expecting single reduction dimension for subgroup multi "
593 VectorType sourceType = op.getSourceVectorType();
594 int64_t rank = sourceType.getRank();
// For rank > 2, all dims except the trailing two must be unit-sized.
596 ArrayRef<int64_t> shape = sourceType.getShape();
597 if (llvm::any_of(shape.take_front(rank - 2),
598 [](int64_t d) { return d != 1; }))
599 return rewriter.notifyMatchFailure(
600 op,
"only unit leading dimensions are supported for "
601 "multi_reduction with rank > 2");
// Lane-local: reduce each lane's slice with a plain multi_reduction on
// the distributed types.
603 if (isReductionLaneLocal(op)) {
605 VectorType resVecTy = dyn_cast<VectorType>(op.getType());
606 auto resDistVecTyOrFailure =
610 result = vector::MultiDimReductionOp::create(
611 rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
612 adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
// Cross-lane: hand off to the elided helper with dim/size context.
614 auto reductionDim = reductionDims[0];
615 VectorType sourceType = op.getSourceVectorType();
616 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
620 reductionDim, reductionDimSize, op.getLoc(), rewriter);
622 rewriter.replaceOp(op,
result);
// Continuation of a helper whose header is elided (presumably
// computeDistributedCoordsForMatrixOp, per the call sites in the matrix
// patterns — TODO confirm). Materializes gpu.lane_id, asks the layout for
// per-lane coordinates into the payload, and converts the single OFR set
// to Values.
ConversionPatternRewriter &rewriter,
Location loc,
// Unbounded lane id (no upper-bound attribute supplied).
634 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
635 mlir::IntegerAttr());
637 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
640 assert(maybeCoords.value().size() == 1 &&
641 "Expected one set of distributed offsets");
// NOTE(review): llvm::CastTo<Value> — verify this functor exists in the
// LLVM version this builds against.
645 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Distributes xegpu.load_matrix (source fragment; interior lines are
// elided). Requires explicit offsets; unless the op is a subgroup-block-IO,
// offsets are rebased to per-lane coordinates and all constant offsets
// become dynamic.
649struct SgToWiLoadMatrix :
public OpConversionPattern<xegpu::LoadMatrixOp> {
650 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
653 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter)
const override {
655 auto layout = op.getLayoutAttr();
660 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
662 return rewriter.notifyMatchFailure(
663 op,
"the matrix op payload must be a vector type");
665 auto loc = op.getLoc();
666 auto offsets = op.getMixedOffsets();
668 return rewriter.notifyMatchFailure(op,
"the load op must have offsets");
670 FailureOr<VectorType> distPayloadTyOrFailure =
672 if (
failed(distPayloadTyOrFailure))
673 return rewriter.notifyMatchFailure(
674 op,
"Failed to distribute matrix op payload based on layout.");
676 SmallVector<Value> offsetsAsValues =
// Block-IO ops address the whole subgroup tile; only non-block-IO
// loads need per-lane coordinate rebasing.
679 SmallVector<Value> newCoords = offsetsAsValues;
680 if (!op.getSubgroupBlockIoAttr()) {
681 newCoords = computeDistributedCoordsForMatrixOp(
682 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
683 if (newCoords.empty())
684 return rewriter.notifyMatchFailure(
685 op,
"Failed to compute distributed coordinates.");
// All offsets become dynamic SSA values after rebasing.
688 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
689 ShapedType::kDynamic);
691 rewriter.getDenseI64ArrayAttr(newConstOffsets);
// WI-level load: distributed payload type, layout attr cleared.
693 auto newOp = xegpu::LoadMatrixOp::create(
694 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
695 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
696 xegpu::DistributeLayoutAttr{});
697 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.transpose (source fragment; interior lines are
// elided). Valid only when the result lane layout is exactly the transpose
// of the source lane layout, in which case the data each lane holds is
// unchanged and the op reduces to a local transpose plus a bridging cast.
703struct SgToWiVectorTranspose :
public OpConversionPattern<vector::TransposeOp> {
704 using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
707 matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter)
const override {
709 xegpu::DistributeLayoutAttr sourceLayout =
711 xegpu::DistributeLayoutAttr resultLayout =
713 if (!sourceLayout || !resultLayout)
714 return rewriter.notifyMatchFailure(
715 op,
"the source or result vector of the transpose op lacks layout "
717 ArrayRef<int64_t> perm = op.getPermutation();
// Lane layouts must be related by the same permutation as the data.
719 if (!resultLayout.isTransposeOf(sourceLayout, perm,
720 xegpu::LayoutKind::Lane))
721 return rewriter.notifyMatchFailure(
722 op,
"the source or result vector layouts must be transposes of "
724 FailureOr<VectorType> distributedResultTypeOrFailure =
726 if (
failed(distributedResultTypeOrFailure))
727 return rewriter.notifyMatchFailure(
728 op,
"Failed to distribute the result vector type in "
729 "vector::Transpose op");
// Local transpose of the lane's fragment; cast fixes the result shape.
730 auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
731 adaptor.getVector(), perm);
732 rewriter.replaceOp(op, castValueTo(rewriter, newOp.
getResult(),
733 distributedResultTypeOrFailure.value()));
// Distributes vector.bitcast (source fragment; interior lines are elided).
// Bitcast is element-count-preserving per lane, so it is simply recreated
// on the distributed source with the distributed result type.
740struct SgToWiVectorBitcast :
public OpConversionPattern<vector::BitCastOp> {
741 using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
744 matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
745 ConversionPatternRewriter &rewriter)
const override {
746 xegpu::DistributeLayoutAttr resultLayout =
749 return rewriter.notifyMatchFailure(
750 op,
"result vector of the bitcast op lacks layout attribute");
751 FailureOr<VectorType> distributedResultTypeOrFailure =
753 if (
failed(distributedResultTypeOrFailure))
754 return rewriter.notifyMatchFailure(
755 op,
"Failed to distribute the result vector type in "
756 "vector::BitCast op");
757 auto newOp = vector::BitCastOp::create(
758 rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
759 adaptor.getSource());
760 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.create_mask / vector.constant_mask (source fragment;
// interior lines are elided). Each lane computes, per distributed element,
// whether its global coordinates fall inside the mask bounds, then packs
// the bits into the distributed mask vector.
788template <
typename OpType,
789 typename = std::enable_if_t<llvm::is_one_of<
790 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
791struct SgToWiCreateMask :
public OpConversionPattern<OpType> {
792 using OpConversionPattern<OpType>::OpConversionPattern;
795 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
796 ConversionPatternRewriter &rewriter)
const override {
797 xegpu::DistributeLayoutAttr layout =
799 if (!layout || !layout.isForSubgroup())
800 return rewriter.notifyMatchFailure(
801 op,
"operation result does not have subgroup distribute layout");
803 VectorType origType = op.getType();
804 FailureOr<VectorType> distTypeOrFailure =
806 if (
failed(distTypeOrFailure))
807 return rewriter.notifyMatchFailure(
808 op,
"unable to compute workitem vector type from the layout");
810 VectorType distType = distTypeOrFailure.value();
811 Location loc = op.getLoc();
// Gather per-dimension bounds: dynamic operands for create_mask,
// constant dim sizes for constant_mask.
814 SmallVector<Value> origBounds;
815 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
816 origBounds.append(op.getOperands().begin(), op.getOperands().end());
818 auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
819 for (
auto dimSize : dimSizes)
820 origBounds.push_back(
824 ArrayRef<int64_t> origShape = origType.getShape();
// Per-lane global coordinates for every distributed element.
827 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
828 mlir::IntegerAttr());
829 auto maybeCoordsVec =
830 layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
831 if (
failed(maybeCoordsVec))
832 return rewriter.notifyMatchFailure(
833 op,
"failed to compute distributed coordinates from layout");
835 SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
836 int64_t numElements = distType.getNumElements();
837 assert(
static_cast<int64_t
>(coordsVec.size()) == numElements &&
838 "number of coordinate sets must match number of distributed "
// Mask bit = AND over all dims of (coord < bound), signed compare.
844 SmallVector<Value> maskBits;
845 for (
auto &coords : coordsVec) {
846 Value inBounds = trueVal;
847 for (
size_t i = 0; i < coords.size(); ++i) {
848 Value cmp = arith::CmpIOp::create(
849 rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
850 inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
852 maskBits.push_back(inBounds);
// Single-element masks broadcast; otherwise build from elements.
857 if (numElements == 1) {
859 vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
862 vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
864 rewriter.replaceOp(op,
result);
// Distributes xegpu.store_matrix (source fragment; interior lines are
// elided). Mirrors SgToWiLoadMatrix: offsets required, lane-coordinate
// rebasing unless subgroup-block-IO, constant offsets made dynamic, data
// cast to the distributed payload type, original op erased.
870struct SgToWiStoreMatrix :
public OpConversionPattern<xegpu::StoreMatrixOp> {
871 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
874 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
875 ConversionPatternRewriter &rewriter)
const override {
876 auto layout = op.getLayoutAttr();
881 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
883 return rewriter.notifyMatchFailure(
884 op,
"the matrix op payload must be a vector type");
886 auto loc = op.getLoc();
887 auto offsets = op.getMixedOffsets();
889 return rewriter.notifyMatchFailure(op,
"the store op must have offsets");
891 FailureOr<VectorType> distPayloadTyOrFailure =
893 if (
failed(distPayloadTyOrFailure))
894 return rewriter.notifyMatchFailure(
895 op,
"Failed to distribute matrix op payload based on layout.");
897 SmallVector<Value> offsetsAsValues =
// Only non-block-IO stores need per-lane coordinate rebasing.
900 SmallVector<Value> newCoords = offsetsAsValues;
901 if (!op.getSubgroupBlockIoAttr()) {
902 newCoords = computeDistributedCoordsForMatrixOp(
903 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
904 if (newCoords.empty())
905 return rewriter.notifyMatchFailure(
906 op,
"Failed to compute distributed coordinates.");
// All offsets become dynamic SSA values after rebasing.
909 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
910 ShapedType::kDynamic);
912 rewriter.getDenseI64ArrayAttr(newConstOffsets);
// WI-level store with distributed data; layout attr cleared.
914 xegpu::StoreMatrixOp::create(
917 distPayloadTyOrFailure.value()),
918 adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr,
919 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
920 rewriter.eraseOp(op);
// Distributes xegpu.store (scatter form) (source fragment; interior lines
// are elided). Mirrors SgToWiLoadGather: leading dims must be unit-sized;
// value, offsets and mask are flattened to 1-D distributed vectors before
// recreating the store; original op erased.
959struct SgToWiStoreScatter :
public OpConversionPattern<xegpu::StoreScatterOp> {
960 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
963 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
964 ConversionPatternRewriter &rewriter)
const override {
965 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
969 VectorType origValueTy = op.getValueType();
// chunkSize > 1 stores a 2-D (lanes x chunk) tile, else a 1-D row.
974 int chunkSize = op.getChunkSize().value_or(1);
975 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
976 ArrayRef<int64_t> shape = origValueTy.getShape();
977 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
978 [](int64_t d) { return d != 1; }))
979 return rewriter.notifyMatchFailure(
980 op,
"Only unit dimensions allowed for the leading "
981 "dimensions of the store vector!");
983 auto distValueTyOrFailure =
985 if (
failed(distValueTyOrFailure))
986 return rewriter.notifyMatchFailure(
988 "unable to compute expected workitem vector type from lane layout");
// Flatten the distributed value to 1-D for the HW-level scatter.
990 VectorType distValueTy = distValueTyOrFailure.value();
991 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
992 distValueTy.getElementType());
994 Value distValue = adaptor.getValue();
995 if (distValue.
getType() != distValueTy1D)
// Offsets and mask are flattened the same way.
1000 Value distOffsets = adaptor.getOffsets();
1001 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
1002 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
1003 distOffsetsTy.getElementType());
1004 distOffsets = castValueTo(
1007 Value distMask = adaptor.getMask();
1008 auto distMaskTy = cast<VectorType>(distMask.
getType());
1009 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
1010 distMaskTy.getElementType());
// WI-level scatter store: hints preserved, layout attr dropped.
1014 Value distDest = adaptor.getDest();
1015 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
1016 distOffsets, distMask, op.getChunkSizeAttr(),
1017 op.getL1HintAttr(), op.getL2HintAttr(),
1018 op.getL3HintAttr(),
nullptr);
1019 rewriter.eraseOp(op);
// Distributes vector.step (source fragment; interior lines are elided).
// Each lane materializes its own index values: the layout yields the start
// coordinate of every contiguous lane-data block, and consecutive indices
// within a block are produced by adding small constant offsets.
1028struct SgToWiVectorStep :
public OpConversionPattern<vector::StepOp> {
1029 using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1032 matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
1033 ConversionPatternRewriter &rewriter)
const override {
1034 xegpu::DistributeLayoutAttr resultLayout =
1036 if (!resultLayout || !resultLayout.isForSubgroup())
1037 return rewriter.notifyMatchFailure(
1038 op,
"the result vector of the step op lacks subgroup layout");
1040 auto loc = op.getLoc();
1041 auto stepResultVecTy = op.getResult().getType();
1042 auto wiShapeOrFailure =
1044 if (
failed(wiShapeOrFailure))
1045 return rewriter.notifyMatchFailure(
1046 op,
"unable to compute workitem vector type from the layout");
1047 VectorType newVecTy = wiShapeOrFailure.value();
// Coordinates of this lane's data blocks within the full step vector.
1049 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
1050 mlir::IntegerAttr());
1051 auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
1052 rewriter, loc, laneId, stepResultVecTy.getShape());
1053 if (
failed(laneDataBlockCoords))
1054 return rewriter.notifyMatchFailure(
1055 op,
"failed to compute lane data block coordinates");
1057 auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
// laneData[0] elements are contiguous per block; one coord set per block.
1058 auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
1059 assert(
static_cast<int64_t
>(laneDataBlockCoordsVec.size()) ==
1060 newVecTy.getNumElements() / laneDataBlockLength);
1061 SmallVector<Value> stepVals;
// Block start + 1, +2, ... fills in the remaining contiguous indices.
1069 for (
auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
1070 auto laneDataBlockStartCoord = laneDataBlockCoords[0];
1071 stepVals.push_back(laneDataBlockStartCoord);
1072 for (
int i = 1; i < laneDataBlockLength; ++i) {
1074 stepVals.push_back(arith::AddIOp::create(
1075 rewriter, loc, laneDataBlockStartCoord, offset));
1078 assert(
static_cast<int64_t
>(stepVals.size()) == newVecTy.getNumElements() &&
1079 "Expecting the number of step values to match the number of "
1080 "elements in the vector");
1082 vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
1083 rewriter.replaceOp(op, stepOpVal);
// Distributes vector.extract with a vector result (source fragment;
// interior lines are elided). Supported only when distribution happens on
// the innermost dim (all outer lane-layout entries are 1), so the extract
// positions are unaffected and the op can be recreated verbatim.
1090struct SgToWiVectorExtract :
public OpConversionPattern<vector::ExtractOp> {
1091 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
1094 matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
1095 ConversionPatternRewriter &rewriter)
const override {
1097 auto resultType = dyn_cast<VectorType>(op.getType());
1099 return rewriter.notifyMatchFailure(op,
"scalar extract not supported");
1101 xegpu::DistributeLayoutAttr layout =
1103 if (!layout || !layout.isForSubgroup())
// Outer dims must not be distributed, else positions would shift.
1108 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1109 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1110 [](int64_t v) {
return v != 1; }))
1111 return rewriter.notifyMatchFailure(
1112 op,
"only innermost dimension distribution is supported for "
// Positions unchanged: the extracted slice lives in the same lanes.
1115 auto newOp = vector::ExtractOp::create(
1116 rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
1117 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.shape_cast (source fragment; interior lines are
// elided): recreated on the distributed source with the distributed result
// type derived from the result layout.
1123struct SgToWiVectorShapeCast :
public OpConversionPattern<vector::ShapeCastOp> {
1124 using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1127 matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
1128 ConversionPatternRewriter &rewriter)
const override {
1129 xegpu::DistributeLayoutAttr resultLayout =
1131 if (!resultLayout || !resultLayout.isForSubgroup())
1132 return rewriter.notifyMatchFailure(
1133 op,
"the result vector of the shape_cast op lacks subgroup layout");
1136 resultLayout, op.getResultVectorType());
1137 if (
failed(resultDistTypeOrFailure))
1138 return rewriter.notifyMatchFailure(
1139 op,
"failed to get distributed vector type for result");
1141 Value source = adaptor.getSource();
1142 auto newShapeCast = vector::ShapeCastOp::create(
1143 rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
1144 rewriter.replaceOp(op, newShapeCast);
// Distributes vector.extract_strided_slice (source fragment; interior lines
// are elided). Sizes/offsets/strides are first padded to full source rank;
// if exactly one dim is distributed, its size and offset are rescaled by
// the subgroup size (offset and source extent must divide evenly, and lane
// data along that dim must be 1).
1152struct SgToWiVectorExtractStridedSlice
1153 :
public OpConversionPattern<vector::ExtractStridedSliceOp> {
1154 using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
1157 matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
1158 ConversionPatternRewriter &rewriter)
const override {
1159 xegpu::DistributeLayoutAttr resultLayout =
1161 if (!resultLayout || !resultLayout.isForSubgroup())
1164 VectorType resultType = op.getType();
1165 auto distResultTyOrFailure =
1167 if (
failed(distResultTyOrFailure))
1168 return rewriter.notifyMatchFailure(
1169 op,
"unable to compute distributed vector type from lane layout");
1170 VectorType distResultTy = *distResultTyOrFailure;
// Which dims shrank under distribution (helper defined above).
1172 SmallVector<int64_t> distributedDims =
1173 getDistributedDims(resultType, distResultTy);
// Pad the slice attrs to full source rank with identity entries so
// indexing by dim below is uniform.
1176 int64_t sourceRank = op.getSourceVectorType().getRank();
1177 SmallVector<Attribute> updatedSizes =
1178 llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
1179 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1180 op.getOffsets(), [](Attribute attr) { return attr; });
1181 SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
1182 op.getStrides(), [](Attribute attr) { return attr; });
1183 for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
1184 updatedSizes.push_back(
1185 rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
1186 updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
1187 updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
// Only a single distributed dim is handled; rescale its slice params.
1192 if (!distributedDims.empty()) {
1193 if (distributedDims.size() != 1)
1194 return rewriter.notifyMatchFailure(
1195 op,
"only single dimension distribution is supported");
1196 int64_t distDim = distributedDims[0];
1199 return rewriter.notifyMatchFailure(
1200 op,
"target attribute required to determine subgroup size");
1203 if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1204 return rewriter.notifyMatchFailure(
1205 op,
"source of extract_strided_slice lacks distribution layout");
// Even division guarantees every lane's slice is well-defined.
1206 int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
1207 if (sourceDistrDimSize % subgroupSize != 0)
1208 return rewriter.notifyMatchFailure(
1209 op,
"source size along distributed dim is not a multiple of "
1211 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
// Non-unit lane data would interleave elements across lanes.
1214 if (distDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1215 sourceLaneData[distDim] != 1)
1216 return rewriter.notifyMatchFailure(
1217 op,
"expecting unit lane data along the distributed dimension");
1218 int64_t distrDimOffset =
1219 cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
1220 if (distrDimOffset % subgroupSize != 0)
1221 return rewriter.notifyMatchFailure(
1222 op,
"offset along distributed dim is not a multiple of "
// Rescale: per-lane size from the distributed type, offset divided
// by the subgroup size.
1225 updatedSizes[distDim] =
1226 rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
1227 updatedOffsets[distDim] =
1228 rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
1231 auto newOp = vector::ExtractStridedSliceOp::create(
1232 rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
1233 ArrayAttr::get(rewriter.getContext(), updatedOffsets),
1234 ArrayAttr::get(rewriter.getContext(), updatedSizes),
1235 ArrayAttr::get(rewriter.getContext(), updatedStrides));
1236 rewriter.replaceOp(op, newOp.
getResult());
// Distributes vector.broadcast (source fragment; interior lines are
// elided). Rank-expanding broadcasts require the source layout to be a
// slice of the result layout; same-rank broadcasts normalize the source
// layout over the broadcasted unit dims; scalar sources must be unlaid-out.
1298struct SgToWiBroadcast :
public OpConversionPattern<vector::BroadcastOp> {
1299 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
1302 matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
1303 ConversionPatternRewriter &rewriter)
const override {
1304 xegpu::DistributeLayoutAttr resultLayout =
1306 if (!resultLayout || !resultLayout.isForSubgroup())
1307 return rewriter.notifyMatchFailure(
1308 op,
"result does not have subgroup distribute layout");
1310 VectorType destType = op.getResultVectorType();
1311 VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
1313 xegpu::DistributeLayoutAttr sourceLayout =
1317 int64_t rankDiff = destType.getRank() - sourceType.getRank();
// Rank expansion: source layout must be a slice of the result's.
1320 if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
1322 "broadcast source layout must be a slice of result layout");
1323 }
else if (rankDiff == 0) {
// Same rank: unit dims being broadcast must carry sg_data of 1; then
// relax the lane layout over those dims.
1325 auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
1326 SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
1327 broadcastUnitDimsSet.end());
1328 assert(sourceLayout.isEqualTo(
1329 sourceLayout.setUnitDimData(broadcastUnitDims)) &&
1330 "The sg_data for unit dimensions should be set as 1");
1331 sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
1336 return rewriter.notifyMatchFailure(
1337 op,
"broadcast from scalar must not have a layout attribute");
1342 if (
failed(destDistType))
1343 return rewriter.notifyMatchFailure(
1344 op,
"failed to distribute the result vector type");
1346 Value source = adaptor.getSource();
// Degenerate broadcast: source already has the distributed type.
1348 if (source.
getType() == destDistType.value()) {
1349 rewriter.replaceOp(op, source);
1353 auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
1354 destDistType.value(), source);
1355 rewriter.replaceOp(op, newOp);
1363struct SgToWiVectorInsertStridedSlice
1364 :
public OpConversionPattern<vector::InsertStridedSliceOp> {
1365 using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
1368 matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
1369 ConversionPatternRewriter &rewriter)
const override {
1370 xegpu::DistributeLayoutAttr resultLayout =
1372 if (!resultLayout || !resultLayout.isForSubgroup())
1375 VectorType destType = op.getDestVectorType();
1376 auto distDestTyOrFailure =
1378 if (
failed(distDestTyOrFailure))
1379 return rewriter.notifyMatchFailure(
1380 op,
"unable to compute distributed vector type from lane layout");
1381 VectorType distDestTy = *distDestTyOrFailure;
1383 SmallVector<int64_t> destDistributedDims =
1384 getDistributedDims(destType, distDestTy);
1386 SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
1387 op.getOffsets(), [](Attribute attr) { return attr; });
1389 if (!destDistributedDims.empty()) {
1390 if (destDistributedDims.size() != 1)
1391 return rewriter.notifyMatchFailure(
1392 op,
"only single dimension distribution is supported");
1393 int64_t destDistDim = destDistributedDims[0];
1397 return rewriter.notifyMatchFailure(
1398 op,
"target attribute required to determine subgroup size");
1401 VectorType srcType = op.getSourceVectorType();
1403 int64_t sourceDistDim =
1404 destDistDim - (destType.getRank() - srcType.getRank());
1405 if (sourceDistDim < 0)
1406 return rewriter.notifyMatchFailure(
1407 op,
"distributed dimension must be in the last k dims of dest");
1411 if (!destLayout || !sourceLayout ||
1412 destLayout.getEffectiveLaneLayoutAsInt().empty() ||
1413 sourceLayout.getEffectiveLaneLayoutAsInt().empty())
1414 return rewriter.notifyMatchFailure(
1415 op,
"source or dest of insert_strided_slice lacks distribution "
1418 auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
1419 auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
1422 if ((destDistDim <
static_cast<int64_t
>(destLaneData.size()) &&
1423 destLaneData[destDistDim] != 1) ||
1424 (sourceDistDim <
static_cast<int64_t
>(sourceLaneData.size()) &&
1425 sourceLaneData[sourceDistDim] != 1))
1426 return rewriter.notifyMatchFailure(
1427 op,
"expecting unit lane data along the distributed dimension");
1429 int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
1430 if (srcDistrDimSize % subgroupSize != 0)
1431 return rewriter.notifyMatchFailure(
1432 op,
"source distributed dim size is not a multiple of "
1435 int64_t destDistrDimOffset =
1436 cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
1437 if (destDistrDimOffset % subgroupSize != 0)
1438 return rewriter.notifyMatchFailure(
1439 op,
"offset along distributed dim is not a multiple of "
1442 updatedOffsets[destDistDim] =
1443 rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
1446 auto newOp = vector::InsertStridedSliceOp::create(
1447 rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
1449 ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
1450 rewriter.replaceOp(op, newOp.
getResult());
1457struct SgToWiVectorInsert :
public OpConversionPattern<vector::InsertOp> {
1458 using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
1461 matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter)
const override {
1464 auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
1466 return rewriter.notifyMatchFailure(op,
"scalar insert not supported");
1468 xegpu::DistributeLayoutAttr layout =
1470 if (!layout || !layout.isForSubgroup())
1475 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
1476 if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
1477 [](int64_t v) {
return v != 1; }))
1478 return rewriter.notifyMatchFailure(
1479 op,
"only innermost dimension distribution is supported for "
1482 auto newOp = vector::InsertOp::create(
1483 rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
1484 op.getMixedPosition());
1485 rewriter.replaceOp(op, newOp.
getResult());
1491struct SgToWiConvertLayout
1492 :
public OpConversionPattern<xegpu::ConvertLayoutOp> {
1493 using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
1496 matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
1497 ConversionPatternRewriter &rewriter)
const override {
1498 auto inputLayout = op.getInputLayoutAttr();
1499 auto targetLayout = op.getTargetLayoutAttr();
1500 auto resShape = cast<VectorType>(op.getResult().getType()).getShape();
1501 SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
1503 if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
1504 xegpu::LayoutKind::Lane)) {
1505 return rewriter.notifyMatchFailure(
1506 op,
"lowering incompatible convert_layout not yet supported");
1508 rewriter.replaceOp(op, adaptor.getSource());
1513struct XeGPUSgToWiDistributeExperimentalPass
1515 XeGPUSgToWiDistributeExperimentalPass> {
1516 void runOnOperation()
override;
1521void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
1524 Operation *root = getOperation();
1526 signalPassFailure();
1532 if (
failed(verifyLayouts(root))) {
1533 LLVM_DEBUG(
DBGS() <<
"XeGPUSgToWiDistributeExperimentalPass: layout "
1534 "verification failed\n");
1535 signalPassFailure();
1539 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
1541 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
1545 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
1546 mlir::ValueRange inputs,
1547 mlir::Location loc) -> mlir::Value {
1548 UnrealizedConversionCastOp castOp =
1549 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1550 return castOp.getResult(0);
1554 TypeConverter typeConverter;
1556 typeConverter.addSourceMaterialization(materializeCast);
1557 typeConverter.addTargetMaterialization(materializeCast);
1562 typeConverter, patterns,
target);
1563 target.addLegalOp<UnrealizedConversionCastOp>();
1564 (void)applyPartialConversion(root,
target, std::move(patterns));
1575 OpBuilder builder(root);
1576 root->
walk([&](UnrealizedConversionCastOp op) {
1578 if (existingCasts.contains(op))
1581 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
1584 auto singleInput = op.getInputs()[0];
1585 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
1586 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
1587 if (!inputTy || !outputTy)
1593 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
1594 if (!definingOp || !definingOp->hasOneUse())
1596 auto inputOfDefiningOp = definingOp.getInputs()[0];
1599 auto inputOfDefiningOpTy =
1600 dyn_cast<VectorType>(inputOfDefiningOp.getType());
1601 if (inputOfDefiningOpTy &&
1602 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
1604 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
1605 outputTy, inputOfDefiningOp);
1606 op.replaceAllUsesWith(
ValueRange{shapeCast.getResult()});
1612 bool changed =
true;
1615 root->
walk([&](UnrealizedConversionCastOp op) {
1617 if (existingCasts.contains(op))
1619 if (op.use_empty()) {
1630 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
1631 if (!isa<TensorDescType, VectorType>(type))
1633 return std::nullopt;
1636 typeConverter.addConversion([](TensorDescType type) ->
Type {
1637 if (type.getLayoutAttr()) {
1638 return type.dropLayouts();
1644 typeConverter.addConversion([](
Value v) -> std::optional<Type> {
1647 if (!isa<VectorType>(type))
1648 return std::nullopt;
1650 if (!layout || !layout.isForSubgroup())
1653 auto newTyOrFailure =
1655 if (failed(newTyOrFailure))
1657 return *newTyOrFailure;
1666 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
1667 [&](xegpu::CreateNdDescOp op) {
return !op.getType().getLayoutAttr(); });
1669 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](
Operation *op) {
1670 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
1673 return !anchorOp.getAnchorLayout();
1676 target.addDynamicallyLegalOp<arith::ConstantOp>(
1677 [=](arith::ConstantOp op) ->
bool {
1679 if (!isa<VectorType>(op.getResult().getType()))
1685 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1686 [=](
Operation *op) -> std::optional<bool> {
1691 if (op->getNumResults() != 1)
1694 VectorType resultType =
1695 dyn_cast<VectorType>(op->getResult(0).getType());
1700 for (
Value operand : op->getOperands()) {
1701 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1702 if (!operandType || operandType.getShape() != resultType.getShape()) {
1710 target.addDynamicallyLegalOp<vector::ReductionOp>(
1711 [=](vector::ReductionOp op) ->
bool {
1716 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1717 [=](vector::MultiDimReductionOp op) ->
bool {
1718 return !isValidSubgroupMultiReductionOp(op);
1720 target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
1721 vector::TransposeOp, vector::BitCastOp,
1722 vector::ShapeCastOp, vector::StepOp,
1723 vector::BroadcastOp>([=](
Operation *op) ->
bool {
1726 target.addDynamicallyLegalOp<vector::ExtractOp>(
1727 [=](vector::ExtractOp op) ->
bool {
1728 if (!isa<VectorType>(op.getType()))
1732 target.addDynamicallyLegalOp<vector::InsertOp>(
1733 [=](vector::InsertOp op) ->
bool {
1736 target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
1737 [=](vector::ExtractStridedSliceOp op) ->
bool {
1740 target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
1741 [=](vector::InsertStridedSliceOp op) ->
bool {
1744 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
1745 patterns.
add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1746 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1747 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1748 SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
1749 SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
1750 SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
1751 SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
1752 SgToWiVectorShapeCast, SgToWiBroadcast,
1753 SgToWiCreateMask<vector::CreateMaskOp>,
1754 SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to workitem distribution.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to workitem distribution and appends the req...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
virtual int getSubgroupSize() const =0