28#include "llvm/ADT/SetVector.h"
29#include "llvm/Support/LogicalResult.h"
30#include "llvm/Support/raw_ostream.h"
35#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
36#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
42#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
43#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Convert `v` to `expectedTy`: identity when the type already matches,
// a vector.shape_cast when both are vectors with the same element count,
// otherwise an unrealized_conversion_cast placeholder to be resolved later.
// NOTE(review): the leading numerals on each line are line-number artifacts
// from extraction; several original source lines are missing in this view.
48static Value castValueTo(ConversionPatternRewriter &rewriter,
51 if (v.getType() == expectedTy)
54 if (isa<VectorType>(v.getType()) &&
55 v.getType().getNumElements() == expectedTy.getNumElements())
56 return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
// Fallback: materialize a placeholder cast for later resolution.
59 auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
61 return newOp.getResult(0);
// Pre-conversion validation: walk every op under `root` and check that ops
// implementing AnchorLayoutInterface carry an anchor layout attribute, and
// (apparently) that vector-typed results carry a result layout attribute.
// An interrupted walk is mapped to failure().
65static LogicalResult verifyLayouts(
Operation *root) {
67 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
68 auto layout = anchorOp.getAnchorLayout();
70 nestedOp->
emitError(
"expected anchor layout attribute on operation");
78 if (isa<VectorType>(
result.getType())) {
82 "expected result layout attribute on vector result");
89 return walkResult.wasInterrupted() ? failure() :
success();
// A MultiDimReductionOp is distributable here iff: its result layout exists
// and is a subgroup layout, the result is a vector that can be distributed
// per the lane layout, and exactly one dimension is reduced.
95static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
98 if (!resLayout || !resLayout.isForSubgroup())
100 VectorType resTy = dyn_cast<VectorType>(op.getType());
104 FailureOr<VectorType> resDistTypeOrFailure =
105 getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
106 if (failed(resDistTypeOrFailure))
108 return op.getReductionDims().size() == 1;
// Returns true when the lane-distributed result type differs from the
// subgroup-level result type. NOTE(review): presumably this distinguishes
// reductions each lane can do locally from ones needing cross-lane
// exchange — confirm against the caller (SgToWiMultiDimReduction).
115static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
117 assert(isValidSubgroupMultiReductionOp(op) &&
"Expecting a valid subgroup "
118 "MultiDimReductionOp");
120 VectorType resTy = dyn_cast<VectorType>(op.getType());
121 auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
122 return resTy != resDistTypeOrFailure.value();
// Conversion: xegpu.create_nd_tdesc with a layout-carrying TensorDescType
// is replaced by the same op producing a layout-free descriptor
// (resultType.dropLayouts()); ops without a layout are left alone.
127struct SgToWiCreateNdDesc :
public OpConversionPattern<xegpu::CreateNdDescOp> {
128 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
131 matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const override {
133 xegpu::TensorDescType resultType = op.getType();
135 if (!resultType.getLayout())
138 auto newOp = xegpu::CreateNdDescOp::create(
139 rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
141 rewriter.replaceOp(op, newOp.getResult());
// Conversion: xegpu.load_nd from subgroup to workitem form. Verifies the
// tensor-descriptor layout matches the anchor layout, computes both the
// hardware-supported and the layout-expected per-workitem result types,
// emits the new load with cache hints preserved, then shape-casts the
// result to the expected type via castValueTo.
// NOTE(review): diagnostic text "require target attribute" should read
// "requires ..." — string left untouched here.
149struct SgToWiLoadNd :
public OpConversionPattern<xegpu::LoadNdOp> {
150 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
153 matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter)
const override {
155 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
161 if (op.getTensorDescType().getLayout() != layout)
162 return rewriter.notifyMatchFailure(
163 op,
"conflicting layout attributes on tensor descriptor and anchor");
166 return rewriter.notifyMatchFailure(
167 op,
"xegpu::LoadNdOp require target attribute attached to "
168 "determine transpose "
170 auto supportedWiResultTyOrFailure =
172 auto expectedWiResultTyOrFailure =
174 if (failed(supportedWiResultTyOrFailure))
175 return rewriter.notifyMatchFailure(
176 op,
"unable to compute the workitem vector type for LoadNdOp");
177 if (failed(expectedWiResultTyOrFailure))
178 return rewriter.notifyMatchFailure(
180 "unable to compute expected workitem vector type from lane layout");
// Recreate the load with the supported workitem result type; the trailing
// nullptr drops the layout attribute from the new op.
181 auto newOp = xegpu::LoadNdOp::create(
182 rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
183 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
184 op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
185 op.getL3HintAttr(),
nullptr);
191 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
192 expectedWiResultTyOrFailure.value()));
// Conversion: xegpu.store_nd from subgroup to workitem form. Checks that
// descriptor layout and value layout both agree with the anchor layout,
// computes the per-workitem value type, recreates the store (casting the
// stored value to that type), and erases the original op.
200struct SgToWiStoreNd :
public OpConversionPattern<xegpu::StoreNdOp> {
201 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
204 matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override {
206 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
212 if (op.getTensorDescType().getLayout() != layout)
213 return rewriter.notifyMatchFailure(
214 op,
"conflicting layout attributes on tensor descriptor and anchor");
216 if (valueLayout != layout)
217 return rewriter.notifyMatchFailure(
218 op,
"conflicting layout attributes on value and anchor");
219 auto supportedWiValueTyOrFailure =
221 if (failed(supportedWiValueTyOrFailure))
222 return rewriter.notifyMatchFailure(
224 "unable to compute wi vector type for StoreNdOp value from tensor "
// The trailing nullptr drops the layout attribute on the new store.
227 xegpu::StoreNdOp::create(
228 rewriter, op.getLoc(),
230 supportedWiValueTyOrFailure.value()),
231 adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
232 op.getL2HintAttr(), op.getL3HintAttr(),
nullptr);
233 rewriter.eraseOp(op);
// Conversion: xegpu.dpas from subgroup to workitem form. Computes supported
// workitem vector types for A, B and the result, recreates the op with
// operands cast to those types, then casts the result to the type expected
// from the lane layout.
// NOTE(review): `cast<xegpu::LayoutAttr>` asserts when the attribute is
// null, which makes the `if (!layoutA || ...)` null checks below dead;
// `dyn_cast_or_null` was likely intended — confirm against upstream.
241struct SgToWiDpas :
public OpConversionPattern<xegpu::DpasOp> {
242 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
245 matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
249 auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
250 auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
251 auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
252 if (!layoutA || !layoutB || !layoutCd)
255 auto wiResultTyOrFailure =
257 auto wiATypeOrFailure =
259 auto wiBTypeOrFailure =
261 auto expectedWiResultTyOrFailure =
263 if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
264 failed(wiBTypeOrFailure))
265 return rewriter.notifyMatchFailure(
266 op,
"failed to calculate supported workitem vector types for DpasOp "
268 if (failed(expectedWiResultTyOrFailure))
269 return rewriter.notifyMatchFailure(
270 op,
"unable to compute expected workitem vector type for DpasOp from "
272 auto newOp = xegpu::DpasOp::create(
273 rewriter, op->getLoc(), wiResultTyOrFailure.value(),
275 wiATypeOrFailure.value()),
277 wiBTypeOrFailure.value()),
279 wiResultTyOrFailure.value()),
283 rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
284 expectedWiResultTyOrFailure.value()));
// Conversion for elementwise ops (matches any op type, benefit 1): shrinks
// the single vector result to the workitem shape derived from its subgroup
// distribute layout, cloning the op via OperationState and copying all
// attributes except DistributeLayoutAttr ones.
291struct SgToWiElementWise :
public ConversionPattern {
293 : ConversionPattern(MatchAnyOpTypeTag(), 1, ctx) {}
297 ConversionPatternRewriter &rewriter)
const override {
304 return rewriter.notifyMatchFailure(
305 op,
"operation result is not a vector type");
307 xegpu::DistributeLayoutAttr layout =
309 if (!layout || !layout.isForSubgroup())
310 return rewriter.notifyMatchFailure(
311 op,
"operation result does not have subgroup distribute layout");
313 auto wiShapeOrFailure =
316 if (failed(wiShapeOrFailure))
317 return rewriter.notifyMatchFailure(
318 op,
"unable to compute workitem vector type from the layout");
320 VectorType newResultType = wiShapeOrFailure.value();
322 state.addOperands(operands);
323 state.addTypes(newResultType);
// Drop layout attributes; everything else is carried over verbatim.
326 if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
327 state.addAttribute(attr.getName(), attr.getValue());
329 Operation *newOp = rewriter.create(state);
331 rewriter.replaceOp(op, newOp->
getResult(0));
// Conversion: arith.constant producing a dense *splat* vector with a
// subgroup distribute layout is rebuilt at the workitem vector shape by
// re-splatting the scalar value.
// NOTE(review): local `sclarValue` is a typo for `scalarValue`; left
// unchanged here since this view only adds comments.
338struct SgToWiArithConstant :
public OpConversionPattern<arith::ConstantOp> {
339 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
342 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter)
const override {
344 auto resultType = dyn_cast<VectorType>(op.getType());
349 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
351 return rewriter.notifyMatchFailure(
352 op,
"only dense splat vector constants are supported");
354 xegpu::DistributeLayoutAttr layout =
356 if (!layout || !layout.isForSubgroup())
357 return rewriter.notifyMatchFailure(
358 op,
"operation result does not have subgroup distribute layout");
360 auto wiShapeOrFailure =
363 if (failed(wiShapeOrFailure))
364 return rewriter.notifyMatchFailure(
365 op,
"unable to compute workitem vector type from the layout");
367 VectorType newResultType = wiShapeOrFailure.value();
368 auto sclarValue = dense.getSplatValue<
Attribute>();
371 auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
373 rewriter.replaceOp(op, newOp.
getResult());
// Conversion: xegpu.prefetch_nd — recreated against the converted (layout
// free) tensor descriptor with cache hints preserved; original op erased.
379struct SgToWiPrefetchNd :
public OpConversionPattern<xegpu::PrefetchNdOp> {
380 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
383 matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
384 ConversionPatternRewriter &rewriter)
const override {
385 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
390 xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
391 op.getMixedOffsets(), op.getL1HintAttr(),
392 op.getL2HintAttr(), op.getL3HintAttr(),
394 rewriter.eraseOp(op);
// Conversion: xegpu.load (gather) to workitem form. Requires all leading
// dims (beyond the effective rank implied by chunk size: 1 without
// chunking, 2 with) to be unit. Result, offsets and mask are flattened to
// 1-D per-lane vectors; the result is cast back when the distributed type
// was not already 1-D.
432struct SgToWiLoadGather :
public OpConversionPattern<xegpu::LoadGatherOp> {
433 using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
436 matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
437 ConversionPatternRewriter &rewriter)
const override {
438 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
442 VectorType origResultTy = op.getValueType();
447 int chunkSize = op.getChunkSize().value_or(1);
448 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
449 ArrayRef<int64_t> shape = origResultTy.getShape();
451 shape.take_front(origResultTy.getRank() - effectiveVecRank),
452 [](int64_t d) { return d != 1; }))
453 return rewriter.notifyMatchFailure(
454 op,
"Only unit dimensions allowed for the leading "
455 "dimensions of the load vector!");
457 auto distResultTyOrFailure =
459 if (
failed(distResultTyOrFailure))
460 return rewriter.notifyMatchFailure(
462 "unable to compute expected workitem vector type from lane layout");
// Flatten the distributed result to 1-D for the per-lane load.
464 VectorType distResultTy = distResultTyOrFailure.value();
465 VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
466 distResultTy.getElementType());
469 Value distOffsets = adaptor.getOffsets();
470 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
471 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
472 distOffsetsTy.getElementType());
473 distOffsets = castValueTo(
476 Value distMask = adaptor.getMask();
477 auto distMaskTy = cast<VectorType>(distMask.
getType());
478 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
479 distMaskTy.getElementType());
483 Value distSource = adaptor.getSource();
484 auto newOp = xegpu::LoadGatherOp::create(
485 rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
486 distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
487 op.getL3HintAttr(),
nullptr);
489 Value
result = newOp->getResult(0);
// Cast back to the (possibly multi-dim) distributed type if needed.
490 if (distResultTy1D != distResultTy)
493 rewriter.replaceOp(op,
result);
// Conversion: rank-1 vector.reduction with a subgroup layout, lowered to a
// cross-lane subgroup reduction. Preconditions: layout rank matches the
// vector rank, the source dim is a multiple of the subgroup size (taken
// from the effective lane layout / target attribute), and the element type
// is int or float. An accumulator, if present, is folded in afterwards.
502struct SgToWiVectorReduction :
public OpConversionPattern<vector::ReductionOp> {
503 using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
506 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
507 ConversionPatternRewriter &rewriter)
const override {
511 if (!layout || !layout.isForSubgroup())
514 VectorType srcVecType = op.getSourceVectorType();
516 if (srcVecType.getRank() != 1)
517 return rewriter.notifyMatchFailure(
518 op,
"Only rank 1 reductions can be distributed.")
519
520 if (layout.getRank() != srcVecType.getRank())
521 return rewriter.notifyMatchFailure(
522 op,
"Layout rank does not match vector rank.");
525 int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
528 return rewriter.notifyMatchFailure(
529 op,
"xegpu::ReductionOp require target attribute attached to "
530 "determine subgroup size");
534 srcVecType.getShape()[0] % sgSize != 0)
535 return rewriter.notifyMatchFailure(op,
536 "Invalid layout or reduction vector "
537 "dimension must match subgroup size.");
539 if (!op.getType().isIntOrFloat())
540 return rewriter.notifyMatchFailure(
541 op,
"Reduction distribution currently only supports floats and "
// Perform the cross-lane reduction on the per-lane data.
545 Value laneValVec = adaptor.getVector();
549 op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
552 if (adaptor.getAcc())
554 rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
556 rewriter.replaceOp(op, fullReduce);
// Conversion: single-dim vector.multi_reduction under a subgroup layout.
// Lane-local case: re-emit a multi_reduction at the distributed type.
// Otherwise: lower to cross-lane shuffles over the reduction dimension.
565struct SgToWiMultiDimReduction
566 :
public OpConversionPattern<vector::MultiDimReductionOp> {
567 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
570 matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
571 ConversionPatternRewriter &rewriter)
const override {
573 ArrayRef<int64_t> reductionDims = op.getReductionDims();
574 assert(reductionDims.size() == 1 &&
575 "Expecting single reduction dimension for subgroup multi "
577 if (isReductionLaneLocal(op)) {
// Each lane reduces its own slice; no cross-lane traffic needed.
579 VectorType resVecTy = dyn_cast<VectorType>(op.getType());
580 auto resDistVecTyOrFailure =
584 result = vector::MultiDimReductionOp::create(
585 rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
586 adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
588 auto reductionDim = reductionDims[0];
589 VectorType sourceType = op.getSourceVectorType();
590 int64_t reductionDimSize = sourceType.getShape()[reductionDim];
594 reductionDim, reductionDimSize, op.getLoc(), rewriter);
596 rewriter.replaceOp(op,
result);
605 ConversionPatternRewriter &rewriter,
Location loc,
608 Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
609 mlir::IntegerAttr());
611 layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
614 assert(maybeCoords.value().size() == 1 &&
615 "Expected one set of distributed offsets");
619 return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
// Conversion: xegpu.load_matrix to workitem form. Unless the op uses
// subgroup-block IO, per-lane coordinates are derived from the lane id and
// layout via computeDistributedCoordsForMatrixOp; constant offsets are all
// replaced by kDynamic since coordinates become runtime values. The new op
// carries no distribute layout.
623struct SgToWiLoadMatrix :
public OpConversionPattern<xegpu::LoadMatrixOp> {
624 using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
627 matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
628 ConversionPatternRewriter &rewriter)
const override {
629 auto layout = op.getLayoutAttr();
634 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
636 return rewriter.notifyMatchFailure(
637 op,
"the matrix op payload must be a vector type");
639 auto loc = op.getLoc();
640 auto offsets = op.getMixedOffsets();
642 return rewriter.notifyMatchFailure(op,
"the load op must have offsets");
644 FailureOr<VectorType> distPayloadTyOrFailure =
646 if (
failed(distPayloadTyOrFailure))
647 return rewriter.notifyMatchFailure(
648 op,
"Failed to distribute matrix op payload based on layout.");
650 SmallVector<Value> offsetsAsValues =
653 SmallVector<Value> newCoords = offsetsAsValues;
654 if (!op.getSubgroupBlockIoAttr()) {
655 newCoords = computeDistributedCoordsForMatrixOp(
656 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
657 if (newCoords.empty())
658 return rewriter.notifyMatchFailure(
659 op,
"Failed to compute distributed coordinates.");
// All offsets become dynamic once per-lane coordinates are computed.
662 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
663 ShapedType::kDynamic);
665 rewriter.getDenseI64ArrayAttr(newConstOffsets);
667 auto newOp = xegpu::LoadMatrixOp::create(
668 rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
669 ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
670 xegpu::DistributeLayoutAttr{});
671 rewriter.replaceOp(op, newOp.getResult());
// Conversion: xegpu.store_matrix to workitem form — mirror image of
// SgToWiLoadMatrix: distribute the payload, compute per-lane coordinates
// (unless subgroup-block IO), mark offsets dynamic, emit the new store
// without a distribute layout, and erase the original op.
677struct SgToWiStoreMatrix :
public OpConversionPattern<xegpu::StoreMatrixOp> {
678 using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
681 matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter)
const override {
683 auto layout = op.getLayoutAttr();
688 VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
690 return rewriter.notifyMatchFailure(
691 op,
"the matrix op payload must be a vector type");
693 auto loc = op.getLoc();
694 auto offsets = op.getMixedOffsets();
696 return rewriter.notifyMatchFailure(op,
"the store op must have offsets");
698 FailureOr<VectorType> distPayloadTyOrFailure =
700 if (
failed(distPayloadTyOrFailure))
701 return rewriter.notifyMatchFailure(
702 op,
"Failed to distribute matrix op payload based on layout.");
704 SmallVector<Value> offsetsAsValues =
707 SmallVector<Value> newCoords = offsetsAsValues;
708 if (!op.getSubgroupBlockIoAttr()) {
709 newCoords = computeDistributedCoordsForMatrixOp(
710 rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
711 if (newCoords.empty())
712 return rewriter.notifyMatchFailure(
713 op,
"Failed to compute distributed coordinates.");
716 SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
717 ShapedType::kDynamic);
719 rewriter.getDenseI64ArrayAttr(newConstOffsets);
721 xegpu::StoreMatrixOp::create(
724 distPayloadTyOrFailure.value()),
725 adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr,
726 op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
727 rewriter.eraseOp(op);
// Conversion: xegpu.store (scatter) to workitem form — mirror of
// SgToWiLoadGather: leading dims beyond the effective rank (1, or 2 with
// chunking) must be unit; value, offsets and mask are flattened to 1-D
// per-lane vectors before emitting the new store; original op erased.
766struct SgToWiStoreScatter :
public OpConversionPattern<xegpu::StoreScatterOp> {
767 using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
770 matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
771 ConversionPatternRewriter &rewriter)
const override {
772 xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
776 VectorType origValueTy = op.getValueType();
781 int chunkSize = op.getChunkSize().value_or(1);
782 int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
783 ArrayRef<int64_t> shape = origValueTy.getShape();
784 if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
785 [](int64_t d) { return d != 1; }))
786 return rewriter.notifyMatchFailure(
787 op,
"Only unit dimensions allowed for the leading "
788 "dimensions of the store vector!");
790 auto distValueTyOrFailure =
792 if (
failed(distValueTyOrFailure))
793 return rewriter.notifyMatchFailure(
795 "unable to compute expected workitem vector type from lane layout");
// Flatten the distributed value to 1-D for the per-lane store.
797 VectorType distValueTy = distValueTyOrFailure.value();
798 VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
799 distValueTy.getElementType());
801 Value distValue = adaptor.getValue();
802 if (distValue.
getType() != distValueTy1D)
807 Value distOffsets = adaptor.getOffsets();
808 auto distOffsetsTy = cast<VectorType>(distOffsets.
getType());
809 VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
810 distOffsetsTy.getElementType());
811 distOffsets = castValueTo(
814 Value distMask = adaptor.getMask();
815 auto distMaskTy = cast<VectorType>(distMask.
getType());
816 VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
817 distMaskTy.getElementType());
821 Value distDest = adaptor.getDest();
822 xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
823 distOffsets, distMask, op.getChunkSizeAttr(),
824 op.getL1HintAttr(), op.getL2HintAttr(),
825 op.getL3HintAttr(),
nullptr);
826 rewriter.eraseOp(op);
// Pass driver for the experimental subgroup-to-workitem distribution;
// derives from the generated pass base (line between the two fragments is
// missing in this view) and implements runOnOperation below.
831struct XeGPUSgToWiDistributeExperimentalPass
833 XeGPUSgToWiDistributeExperimentalPass> {
834 void runOnOperation()
override;
// Pass entry point: (1) verify layouts up front; (2) remember the
// unrealized_conversion_casts that existed before the pass so cleanup does
// not touch them; (3) run partial dialect conversion with cast
// materializations; (4) post-pass, fold newly introduced back-to-back
// unrealized casts between vectors of equal element count into a single
// vector.shape_cast and drop dead casts.
839void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
843 Operation *root = getOperation();
844 if (
failed(verifyLayouts(root))) {
845 LLVM_DEBUG(
DBGS() <<
"XeGPUSgToWiDistributeExperimentalPass: layout "
846 "verification failed\n");
// Record pre-existing casts so they are excluded from the cleanup walks.
851 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
853 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
// Both source and target materializations produce unrealized casts that
// the post-pass cleanup below resolves.
857 auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
858 mlir::ValueRange inputs,
859 mlir::Location loc) -> mlir::Value {
860 UnrealizedConversionCastOp castOp =
861 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
862 return castOp.getResult(0);
866 TypeConverter typeConverter;
868 typeConverter.addSourceMaterialization(materializeCast);
869 typeConverter.addTargetMaterialization(materializeCast);
874 typeConverter, patterns,
target);
875 target.addLegalOp<UnrealizedConversionCastOp>();
876 (void)applyPartialConversion(root,
target, std::move(patterns));
// Cleanup: collapse cast(cast(x)) chains into a shape_cast when element
// counts match, skipping casts that pre-date this pass.
887 OpBuilder builder(root);
888 root->
walk([&](UnrealizedConversionCastOp op) {
890 if (existingCasts.contains(op))
893 if (op.getNumOperands() != 1 || op.getNumResults() != 1)
896 auto singleInput = op.getInputs()[0];
897 auto inputTy = dyn_cast<VectorType>(singleInput.getType());
898 auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
899 if (!inputTy || !outputTy)
905 auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
906 if (!definingOp || !definingOp->hasOneUse())
908 auto inputOfDefiningOp = definingOp.getInputs()[0];
911 auto inputOfDefiningOpTy =
912 dyn_cast<VectorType>(inputOfDefiningOp.getType());
913 if (inputOfDefiningOpTy &&
914 inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
916 auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
917 outputTy, inputOfDefiningOp);
918 op.replaceAllUsesWith(
ValueRange{shapeCast.getResult()});
// Second walk: erase casts introduced by this pass that ended up unused.
927 root->
walk([&](UnrealizedConversionCastOp op) {
929 if (existingCasts.contains(op))
931 if (op.use_empty()) {
// Fragment of the populate* helpers (function headers missing in this
// view): registers type conversions (drop layouts from TensorDescType;
// distribute vector types per their subgroup layout) and the legality
// rules — XeGPU ops are legal once they have no anchor layout, elementwise
// math/arith ops once operand/result vector shapes agree, reductions once
// they no longer match the distribution predicates — then registers all
// SgToWi* patterns.
942 typeConverter.addConversion([](
Type type) -> std::optional<Type> {
943 if (!isa<TensorDescType, VectorType>(type))
948 typeConverter.addConversion([](TensorDescType type) ->
Type {
949 if (type.getLayoutAttr()) {
950 return type.dropLayouts();
956 typeConverter.addConversion([](
Value v) -> std::optional<Type> {
959 if (!isa<VectorType>(type))
962 if (!layout || !layout.isForSubgroup())
965 auto newTyOrFailure =
967 if (failed(newTyOrFailure))
969 return *newTyOrFailure;
978 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
979 [&](xegpu::CreateNdDescOp op) {
return !op.getType().getLayoutAttr(); });
981 target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](
Operation *op) {
982 auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
985 return !anchorOp.getAnchorLayout();
988 target.addDynamicallyLegalOp<arith::ConstantOp>(
989 [=](arith::ConstantOp op) ->
bool {
991 if (!isa<VectorType>(op.getResult().getType()))
997 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
998 [=](
Operation *op) -> std::optional<bool> {
1003 if (op->getNumResults() != 1)
1006 VectorType resultType =
1007 dyn_cast<VectorType>(op->getResult(0).getType());
1012 for (
Value operand : op->getOperands()) {
1013 VectorType operandType = dyn_cast<VectorType>(operand.getType());
1014 if (!operandType || operandType.getShape() != resultType.getShape()) {
1022 target.addDynamicallyLegalOp<vector::ReductionOp>(
1023 [=](vector::ReductionOp op) ->
bool {
1028 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1029 [=](vector::MultiDimReductionOp op) ->
bool {
1030 return !isValidSubgroupMultiReductionOp(op);
1032 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
1033 patterns.
add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
1034 SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
1035 SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
1036 SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
Attributes are known-constant values of operations.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
const uArch * getUArch(llvm::StringRef archName)
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter)
Define only the type conversions needed for XeGPU subgroup to workitem distribution.
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a subgroup-wide reduction of the given combining kind over `size` lanes.
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
bool requirePacked(const LayoutAttr layout)
Helper function to check if the layout is packed.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Defines type conversions and legality for XeGPU subgroup to workitem distribution and appends the req...
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
Get and set the distribute layout attribute for non-anchor operations (and the offsets/masks of load/store ops).
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
virtual int getSubgroupSize() const =0