32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// --- LayoutInfo (struct interior; declaration line is outside this chunk) ---
// Wraps an optional xegpu::DistributeLayoutAttr as the lattice value used by
// the layout-propagation analysis. "Unassigned" is represented by a null
// storage attribute.
// NOTE(review): this fragment is truncated — the embedded original line
// numbers show gaps, so several member signatures/bodies are missing here.
// The underlying layout attribute; null means "no layout assigned yet".
83 xegpu::DistributeLayoutAttr storage =
nullptr;
86 LayoutInfo() =
default;
87 LayoutInfo(
const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
// Equality for the dataflow framework: two infos compare equal when both are
// assigned or both are unassigned (the attribute contents are not compared).
91 bool operator==(
const LayoutInfo &other)
const {
92 return this->isAssigned() == other.isAssigned();
// Lattice meet — defined out-of-line below.
95 static LayoutInfo meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
// Lattice join — defined out-of-line below (unreachable for this analysis).
97 static LayoutInfo
join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
// True once a concrete layout attribute has been recorded.
101 bool isAssigned()
const {
return storage !=
nullptr; }
// True when the stored layout is a slice layout (xegpu::SliceAttr).
117 bool isSliceLayout()
const {
120 return isa<xegpu::SliceAttr>(storage);
126 return storage.getRank();
// Replaces the stored layout attribute.
130 void set(
const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
// The getters below copy the attribute's effective int64 parameter vectors
// into SmallVector<int> (lane layout / lane data / inst data / sg layout /
// sg data / order respectively).
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](
int64_t val) { return static_cast<int>(val); });
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](
int64_t val) { return static_cast<int>(val); });
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](
int64_t val) { return static_cast<int>(val); });
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](
int64_t val) { return static_cast<int>(val); });
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](
int64_t val) { return static_cast<int>(val); });
// Order is optional on the attribute; guard both unassigned info and a
// missing order before dereferencing.
169 if (!isAssigned() || !storage.getOrder())
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](
int64_t val) { return static_cast<int>(val); });
// print(): unassigned infos render as the literal text below.
179 os <<
"Not assigned.";
// LayoutInfo::meet — lattice meet used by the (backward) propagation: when
// lhs carries no layout yet, the result comes from rhs (body truncated here;
// original lines after 184 are missing from this chunk).
183LayoutInfo LayoutInfo::meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
184 if (!
lhs.isAssigned())
// LayoutInfo::join — intentionally unimplemented: this analysis only ever
// meets values, so reaching join indicates a framework misuse.
190LayoutInfo LayoutInfo::join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
191 llvm_unreachable(
"Join should not be triggered by layout propagation.");
// Permutes the stored layout parameters (lane layout/data, inst data,
// sg layout/data, order) by `permutation` and rebuilds a LayoutAttr.
// Presumably this is LayoutInfo::transpose — the signature line is missing
// from this chunk; TODO confirm against the full source.
// Validate that `permutation` is a true permutation of [0, size): no
// duplicates, every index in range.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](
int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
206 if (!withinRange || hasDuplicates) {
// Invalid permutations are a programmer error, not a recoverable state.
207 assert(
false &&
"Invalid permutation for transpose.");
// Gather each parameter vector in permuted order; each family is copied
// only when the current layout actually carries it.
218 for (
int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(
static_cast<int32_t
>(getLaneLayout()[idx]));
221 laneData.push_back(
static_cast<int32_t
>(getLaneData()[idx]));
223 if (getInstData().size())
224 instData.push_back(
static_cast<int32_t
>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(
static_cast<int32_t
>(getSgLayout()[idx]));
227 sgData.push_back(
static_cast<int32_t
>(getSgData()[idx]));
229 if (getOrder().size()) {
230 order.push_back(
static_cast<int32_t
>(getOrder()[idx]));
// Rebuild the attribute from whichever parameter family is present.
// NOTE(review): lines are missing between the fragments below (e.g. the
// orderAttr construction and the lane-layout assignment are incomplete).
233 auto orderAttr = order.size()
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
249 return LayoutInfo(layoutAttr);
// Lattice wrapper binding LayoutInfo into the MLIR dataflow framework.
257struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
// Fragment of getDefaultSIMTLayoutInfo (signature missing from this chunk):
// builds the default per-lane layout for rank-1 ([subgroupSize] x [1]) and
// rank-2 ([1, subgroupSize] x [1, 1]) vectors.
272 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
275 xegpu::LayoutAttr::get(ctx, {
uArch->getSubgroupSize()}, {1}));
278 xegpu::LayoutAttr::get(ctx, {1,
uArch->getSubgroupSize()}, {1, 1}));
// Computes the SIMT layout for 2D block I/O instructions: lanes are laid out
// along the innermost dimension, and narrow element types are packed so each
// lane handles packingSize/bitwidth elements.
// NOTE(review): the parameter list is truncated here (a uArch parameter is
// referenced but its declaration line is missing from this chunk).
282template <
typename Ty>
283static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
285 unsigned packingSize) {
287 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
288 "Expected 1D or 2D vector.");
290 assert(ty.getElementType().isIntOrFloat() &&
291 "Expected int or float element type.");
// 1D falls back to the default SIMT layout; no packing applies.
293 if (ty.getRank() == 1)
294 return getDefaultSIMTLayoutInfo(ty.getContext(), 1,
uArch);
// Pack sub-packingSize element types: each lane owns packingFactor
// consecutive elements in the innermost dimension.
296 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
297 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
298 return LayoutInfo(xegpu::LayoutAttr::get(
299 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
// Backward sparse dataflow analysis that propagates xegpu distribute layouts
// from consumers to producers. One visit* handler per anchor/shape op; the
// class body below is truncated (base class, ctor args, and several member
// declarations fall in the missing original lines).
311class LayoutInfoPropagation
318 unsigned indexBitWidth;
// Per-op propagation handlers (declarations; definitions follow below).
322 void visitDpasMxOp(xegpu::DpasMxOp dpasMx,
326 void visitStoreNdOp(xegpu::StoreNdOp store,
330 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
334 void visitLoadNdOp(xegpu::LoadNdOp
load,
338 void visitLoadGatherOp(xegpu::LoadGatherOp
load,
342 void visitTransposeOp(vector::TransposeOp transpose,
346 void visitVectorBitcastOp(vector::BitCastOp bitcast,
350 void visitVectorInterleaveOp(vector::InterleaveOp interleave,
354 void visitVectorDeinterleaveOp(vector::DeinterleaveOp deinterleave,
358 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
362 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
366 void visitVectorReductionOp(vector::ReductionOp reduction,
370 void visitVectorBroadCastOp(vector::BroadcastOp
broadcast,
373 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
377 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
381 void visitLoadMatrixOp(xegpu::LoadMatrixOp
load,
385 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
// NOTE(review): the two declarations below reuse gather/scatter handler
// names with LoadMatrixOp/StoreMatrixOp parameters — likely overloads;
// confirm against the full source.
389 void visitLoadGatherOp(xegpu::LoadMatrixOp
load,
393 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
397 void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
// True when `anchorLayout` already carries the parameters relevant to the
// configured layoutKind (inst-data / lane / subgroup).
401 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
408 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
// Dataflow framework entry point: dispatches to the visit* handlers.
412 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
413 ArrayRef<const LayoutInfoLattice *> results)
override;
// Control-flow hooks deliberately do nothing — layouts are propagated
// through values, not branches/calls.
415 void visitBranchOperand(OpOperand &operand)
override {};
417 void visitCallOperand(OpOperand &operand)
override {};
420 visitNonControlFlowArguments(RegionSuccessor &successor,
421 ArrayRef<BlockArgument> arguments)
override {};
423 void visitExternalCall(CallOpInterface call,
424 ArrayRef<LayoutInfoLattice *> operands,
425 ArrayRef<const LayoutInfoLattice *> results)
override {
// Exit state: meet with an unassigned LayoutInfo (no constraint).
428 void setToExitState(LayoutInfoLattice *lattice)
override {
429 (void)lattice->meet(LayoutInfo());
// Dispatch hub: routes each supported op to its dedicated visit* handler via
// TypeSwitch. The Default case handles all remaining ops generically by
// propagating each assigned result layout to every tensor-desc/vector
// operand (elementwise-style pass-through).
434LogicalResult LayoutInfoPropagation::visitOperation(
435 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
436 ArrayRef<const LayoutInfoLattice *> results) {
439 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
440 .Case([&](xegpu::DpasMxOp dpasMxOp) {
441 visitDpasMxOp(dpasMxOp, operands, results);
443 .Case([&](xegpu::StoreNdOp storeNdOp) {
444 visitStoreNdOp(storeNdOp, operands, results);
446 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
447 visitStoreScatterOp(storeScatterOp, operands, results);
449 .Case([&](xegpu::LoadNdOp loadNdOp) {
450 visitLoadNdOp(loadNdOp, operands, results);
452 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
453 visitLoadGatherOp(loadGatherOp, operands, results);
455 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
456 visitPrefetchNdOp(prefetchNdOp, operands, results);
458 .Case([&](vector::TransposeOp transposeOp) {
459 visitTransposeOp(transposeOp, operands, results);
461 .Case([&](vector::BitCastOp bitcastOp) {
462 visitVectorBitcastOp(bitcastOp, operands, results);
464 .Case([&](vector::InterleaveOp interleaveOp) {
465 visitVectorInterleaveOp(interleaveOp, operands, results);
467 .Case([&](vector::DeinterleaveOp deinterleaveOp) {
468 visitVectorDeinterleaveOp(deinterleaveOp, operands, results);
470 .Case([&](vector::MultiDimReductionOp reductionOp) {
471 visitVectorMultiReductionOp(reductionOp, operands, results);
473 .Case([&](vector::ReductionOp reductionOp) {
474 visitVectorReductionOp(reductionOp, operands, results);
476 .Case([&](vector::BroadcastOp broadcastOp) {
477 visitVectorBroadCastOp(broadcastOp, operands, results);
479 .Case([&](vector::ShapeCastOp shapeCastOp) {
480 visitShapeCastOp(shapeCastOp, operands, results);
482 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
483 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
485 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
486 visitLoadMatrixOp(loadMatrixOp, operands, results);
488 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
489 visitStoreMatrixOp(storeMatrixOp, operands, results);
491 .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
492 visitConvertLayoutOp(convertLayoutOp, operands, results);
// Generic fallback: copy assigned result layouts backward onto the
// operands that can carry a layout.
495 .Default([&](Operation *op) {
496 for (
const LayoutInfoLattice *resultInfo : results) {
497 if (!resultInfo->getValue().isAssigned())
499 for (
auto [operandInfo, operand] :
// Only tensor descriptors and vectors carry layouts; skip others.
503 if (!isa<xegpu::TensorDescType, VectorType>(
504 operand.get().getType()))
507 meet(operandInfo, *resultInfo);
// Returns true when the anchor layout already supplies the parameter family
// required by the configured layoutKind, i.e. when the user-provided layout
// can be taken as-is instead of being computed by the analysis.
515bool LayoutInfoPropagation::hasParamsOfLayoutKind(
516 xegpu::DistributeLayoutAttr anchorLayout) {
// A null attribute can't satisfy any kind.
517 if (anchorLayout ==
nullptr) {
520 if (layoutKind == xegpu::LayoutKind::InstData) {
521 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
// Lane kind needs both lane-layout and lane-data.
523 if (layoutKind == xegpu::LayoutKind::Lane) {
524 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
525 anchorLayout.getEffectiveLaneDataAsInt().empty());
// Subgroup kind needs both sg-layout and sg-data.
527 if (layoutKind == xegpu::LayoutKind::Subgroup) {
528 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
529 anchorLayout.getEffectiveSgDataAsInt().empty());
// Fragment of a helper that enumerates 2D subgroup-layout factorizations
// (sgLayout0 x sgLayout1 == sgCount) that evenly tile wgShape and whose
// per-subgroup data is a multiple of instData, then ranks candidates by how
// "square" they are (smallest |sgLayout0 - sgLayout1| first, ties broken by
// the smaller first factor). Enclosing signature is outside this chunk.
545 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
// Only exact divisors of sgCount form a valid 2D factorization.
546 if (sgCount % sgLayout0)
548 int sgLayout1 = sgCount / sgLayout0;
549 int sgData0 = wgShape[0] / sgLayout0;
550 int sgData1 = wgShape[1] / sgLayout1;
// Reject layouts that don't evenly divide the workgroup shape or whose
// per-subgroup tile isn't a multiple of the instruction tile.
551 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
552 (sgData0 % instData[0] || sgData1 % instData[1]))
554 candidates.emplace_back(sgLayout0, sgLayout1);
// Prefer near-square layouts; deterministic tie-break on the first factor.
559 llvm::sort(candidates, [](
const std::pair<int, int> &
lhs,
560 const std::pair<int, int> &
rhs) {
561 int diffLhs = std::abs(
lhs.first -
lhs.second);
562 int diffRhs = std::abs(
rhs.first -
rhs.second);
563 if (diffLhs != diffRhs)
564 return diffLhs < diffRhs;
565 return lhs.first <
rhs.first;
// Fragment of the subgroup-count helper: derives the number of subgroups
// from the GPU function's known block size divided by the subgroup size.
// Signature and the failure path are outside this chunk.
575 auto knownBlockSize = gpuFunc.getKnownBlockSize();
576 if (!knownBlockSize.has_value())
578 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
579 return flatBlockSize / sgSize;
// Determines the layout for a PrefetchNdOp's tensor descriptor: honors an
// existing anchor layout of the right kind; otherwise derives inst-data from
// the uArch 2D-block-prefetch parameters (or a SIMT block-IO layout for the
// lane kind), records it on the op, and meets it into the tdesc operand.
582void LayoutInfoPropagation::visitPrefetchNdOp(
583 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
584 ArrayRef<const LayoutInfoLattice *> results) {
586 LayoutInfo prefetchLayout;
587 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
// A pre-set anchor layout of the configured kind wins outright.
588 if (hasParamsOfLayoutKind(anchorLayout)) {
589 prefetchLayout = LayoutInfo(anchorLayout);
593 auto tdescTy = prefetch.getTensorDescType();
// Query the target uArch for the 2D block prefetch instruction limits.
598 const auto *uArchInstruction =
599 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
601 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
604 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
606 prefetch.emitWarning(
"No known block params found for the element type.");
607 auto [bWidth, bHeight, bCount] = blockWHC.value();
608 SmallVector<int> instData;
// Fit the innermost dim to a supported block width...
610 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
612 prefetch.emitWarning(
613 "No suitable instruction multiple found for the given shape.");
614 if (tdescTy.getRank() == 1)
615 instData = {instWidth};
// ...and the next-outer dim to a supported block height for rank 2.
618 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
619 if (instHeight == -1)
620 prefetch.emitWarning(
621 "No suitable instruction multiple found for the given shape.");
622 instData = {instHeight, instWidth};
625 if (layoutKind == xegpu::LayoutKind::InstData)
627 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
// Lane kind: packed SIMT block-IO layout instead of inst-data.
629 prefetchLayout = getSIMTLayoutInfoBlockIO(
630 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
// Record the computed layout back on the op as its anchor.
632 prefetch.setLayoutAttr(
633 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
// Propagate to the tensor-descriptor operand.
636 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
// Propagates layouts through vector.multi_reduction: from the consumer's
// result layout, derives the source layout for the given reduction dims
// (optionally with a subgroup count for the Subgroup kind), then meets the
// source layout into operand 0 and the result layout into the acc operand.
639void LayoutInfoPropagation::visitVectorMultiReductionOp(
640 vector::MultiDimReductionOp reduction,
641 ArrayRef<LayoutInfoLattice *> operands,
642 ArrayRef<const LayoutInfoLattice *> results) {
643 Type resultTy = reduction.getDestType();
645 LayoutInfo resLayoutInfo = results[0]->getValue();
647 xegpu::DistributeLayoutAttr consumerLayoutAttr;
// Nothing to do until a consumer has constrained the result.
649 if (!resLayoutInfo.isAssigned())
652 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
655 VectorType sourceTy = reduction.getSourceVectorType();
656 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
// Subgroup layouts additionally need the subgroup count when derivable.
662 if (layoutKind == xegpu::LayoutKind::Subgroup) {
664 if (succeeded(numSgOrErr))
665 numSg = numSgOrErr.value();
674 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, numSg, uArch);
680 requiredResLayoutAttr, reductionDims);
// operand 0 = source vector, operand 1 = accumulator.
682 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
684 propagateIfChanged(operands[1],
685 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Propagates layouts through vector.reduction (to-scalar): meets the derived
// source layout into the vector operand and, when an accumulator is present,
// the required result layout into it. (Derivation lines are truncated here.)
688void LayoutInfoPropagation::visitVectorReductionOp(
689 vector::ReductionOp reduction, ArrayRef<LayoutInfoLattice *> operands,
690 ArrayRef<const LayoutInfoLattice *> results) {
692 VectorType sourceTy = reduction.getSourceVectorType();
697 auto requiredResLayoutAttr =
702 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// The acc operand exists only for reductions with an accumulator.
703 if (reduction.getAcc())
704 propagateIfChanged(operands[1],
705 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Propagates the consumer's result layout backward through vector.broadcast,
// deriving a source layout from the result/source shapes and meeting it into
// the (vector) source operand.
708void LayoutInfoPropagation::visitVectorBroadCastOp(
709 vector::BroadcastOp
broadcast, ArrayRef<LayoutInfoLattice *> operands,
710 ArrayRef<const LayoutInfoLattice *> results) {
712 LayoutInfo resLayoutInfo = results[0]->getValue();
// Wait until the result has a layout from its consumer.
713 if (!resLayoutInfo.isAssigned())
717 VectorType resultTy =
broadcast.getResultVectorType();
// Scalar sources yield a null VectorType here (dyn_cast), handled in the
// truncated lines that follow.
718 VectorType sourceTy = dyn_cast<VectorType>(
broadcast.getSourceType());
723 auto srcShape = sourceTy.getShape();
724 auto resShape = resultTy.getShape();
726 auto resultLayoutAttr =
727 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
729 xegpu::DistributeLayoutAttr srcLayoutAttr =
732 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Propagates the result layout backward through vector.shape_cast by
// deriving an equivalent source layout for the source shape.
735void LayoutInfoPropagation::visitShapeCastOp(
736 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
737 ArrayRef<const LayoutInfoLattice *> results) {
739 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
740 if (!resLayoutInfo.isAssigned())
742 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
743 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
744 auto resultLayoutAttr =
745 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
747 xegpu::DistributeLayoutAttr srcLayoutAttr =
750 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Determines A/B/C-D layouts for xegpu.dpas: existing anchor layouts of the
// right kind are honored (A and B must then also be anchored); otherwise the
// required layouts are computed (for the Subgroup kind, from the consumer's
// result layout plus the subgroup count), recorded on the op, and met into
// the corresponding operands.
754void LayoutInfoPropagation::visitDpasOp(
755 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
756 ArrayRef<const LayoutInfoLattice *> results) {
757 LayoutInfo dpasALayout;
758 LayoutInfo dpasBLayout;
759 LayoutInfo dpasCDLayout;
761 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
// Anchor path: CD anchored implies A and B must be anchored as well.
762 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
763 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
764 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
765 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
766 "Expected anchor layout for DPAS A operand.");
767 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
768 "Expected anchor layout for DPAS B operand.");
769 dpasALayout = LayoutInfo(anchorLayoutA);
770 dpasBLayout = LayoutInfo(anchorLayoutB);
771 dpasCDLayout = LayoutInfo(anchorLayoutCD);
// Computed path: derive required layouts from operand/result types.
776 VectorType aTy = dpas.getLhsType();
777 VectorType bTy = dpas.getRhsType();
778 VectorType cdTy = dpas.getResultType();
780 xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
781 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
// Subgroup kind needs the consumer layout and the subgroup count.
785 if (layoutKind == xegpu::LayoutKind::Subgroup) {
786 LayoutInfo consumerLayout = results[0]->getValue();
787 if (!consumerLayout.isAssigned())
790 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
794 "Unable to determine the number of subgroups for the operation.");
797 numSg = numSgOrErr.value();
800 consumerLayoutAttr, numSg, uArch);
801 if (!layouts.has_value()) {
803 "Failed to determine required layouts for DPAS operands.");
807 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
// Persist the computed layouts as the op's anchors.
809 dpas.setLayoutAAttr(requiredALayout);
810 dpas.setLayoutBAttr(requiredBLayout);
811 dpas.setLayoutCdAttr(requiredCDLayoutAttr);
812 dpasALayout = LayoutInfo(requiredALayout);
813 dpasBLayout = LayoutInfo(requiredBLayout);
814 dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
// operands: 0 = A, 1 = B, 2 = optional accumulator (C).
816 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
817 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
818 if (operands.size() > 2)
819 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
// Same scheme as visitDpasOp but for the scaled xegpu.dpas_mx op: handles
// the extra A-scale/B-scale operands and their layouts. Anchored layouts of
// the right kind are taken as-is; otherwise required layouts are computed,
// stored on the op, and met into the operands in order (A, B, optional acc,
// optional scaleA, optional scaleB — tracked via a running operand index).
825void LayoutInfoPropagation::visitDpasMxOp(
826 xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
827 ArrayRef<const LayoutInfoLattice *> results) {
830 LayoutInfo dpasMxALayout, dpasMxBLayout, dpasMxCDLayout;
831 LayoutInfo dpasMxAScaleLayout, dpasMxBScaleLayout;
834 xegpu::DistributeLayoutAttr anchorLayoutA = dpasMx.getLayoutAAttr();
835 xegpu::DistributeLayoutAttr anchorLayoutB = dpasMx.getLayoutBAttr();
836 xegpu::DistributeLayoutAttr anchorLayoutCD = dpasMx.getLayoutCdAttr();
// Anchor path: all three of A/B/CD must be present and of the right kind.
839 if (anchorLayoutA && anchorLayoutB && anchorLayoutCD &&
840 hasParamsOfLayoutKind(anchorLayoutA) &&
841 hasParamsOfLayoutKind(anchorLayoutB) &&
842 hasParamsOfLayoutKind(anchorLayoutCD)) {
843 dpasMxALayout = LayoutInfo(anchorLayoutA);
844 dpasMxBLayout = LayoutInfo(anchorLayoutB);
845 dpasMxCDLayout = LayoutInfo(anchorLayoutCD);
// Scale layouts are optional even on the anchor path.
848 xegpu::DistributeLayoutAttr anchorLayoutAScale =
849 dpasMx.getLayoutAScaleAttr();
850 xegpu::DistributeLayoutAttr anchorLayoutBScale =
851 dpasMx.getLayoutBScaleAttr();
852 if (anchorLayoutAScale)
853 dpasMxAScaleLayout = LayoutInfo(anchorLayoutAScale);
854 if (anchorLayoutBScale)
855 dpasMxBScaleLayout = LayoutInfo(anchorLayoutBScale);
// Computed path: gather operand types, including optional scale vectors.
862 VectorType aTy = dpasMx.getAType();
863 VectorType bTy = dpasMx.getBType();
864 VectorType cdTy = dpasMx.getResultType();
869 Value scaleA = dpasMx.getScaleA();
870 Value scaleB = dpasMx.getScaleB();
872 aScaleTy = dyn_cast<VectorType>(scaleA.
getType());
874 bScaleTy = dyn_cast<VectorType>(scaleB.
getType());
876 xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
877 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
878 requiredBLayout, requiredAScaleLayout, requiredBScaleLayout;
// Subgroup kind requires the consumer layout and the subgroup count.
881 if (layoutKind == xegpu::LayoutKind::Subgroup) {
882 LayoutInfo consumerLayout = results[0]->getValue();
883 if (!consumerLayout.isAssigned())
886 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
890 "Unable to determine the number of subgroups for the operation.");
893 numSg = numSgOrErr.value();
898 consumerLayoutAttr, numSg, uArch);
899 if (!layouts.has_value()) {
901 "Failed to determine required layouts for DPAS_MX operands.");
905 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr,
906 requiredAScaleLayout, requiredBScaleLayout) = *layouts;
// Persist computed layouts as anchors; scale anchors only when present.
908 dpasMx.setLayoutAAttr(requiredALayout);
909 dpasMx.setLayoutBAttr(requiredBLayout);
910 dpasMx.setLayoutCdAttr(requiredCDLayoutAttr);
911 if (requiredAScaleLayout)
912 dpasMx.setLayoutAScaleAttr(requiredAScaleLayout);
913 if (requiredBScaleLayout)
914 dpasMx.setLayoutBScaleAttr(requiredBScaleLayout);
916 dpasMxALayout = LayoutInfo(requiredALayout);
917 dpasMxBLayout = LayoutInfo(requiredBLayout);
918 dpasMxCDLayout = LayoutInfo(requiredCDLayoutAttr);
919 if (requiredAScaleLayout)
920 dpasMxAScaleLayout = LayoutInfo(requiredAScaleLayout);
921 if (requiredBScaleLayout)
922 dpasMxBScaleLayout = LayoutInfo(requiredBScaleLayout);
// Meet into operands; `idx` advances past optional operands (its
// increment lines fall in the truncated portions).
929 propagateIfChanged(operands[0], operands[0]->meet(dpasMxALayout));
930 propagateIfChanged(operands[1], operands[1]->meet(dpasMxBLayout));
932 if (dpasMx.getAcc()) {
933 propagateIfChanged(operands[idx], operands[idx]->meet(dpasMxCDLayout));
936 if (dpasMx.getScaleA()) {
937 if (dpasMxAScaleLayout.isAssigned())
938 propagateIfChanged(operands[idx],
939 operands[idx]->meet(dpasMxAScaleLayout));
942 if (dpasMx.getScaleB()) {
943 if (dpasMxBScaleLayout.isAssigned())
944 propagateIfChanged(operands[idx],
945 operands[idx]->meet(dpasMxBScaleLayout));
// Determines the layout for xegpu.store_nd: anchor layout of the right kind
// wins; otherwise inst-data is fitted from the uArch 2D-block-store limits,
// then specialized per layout kind (inst-data / lane SIMT block-IO /
// subgroup layout chosen from divisor candidates). The result is recorded
// on the op and met into every operand (value and tensor descriptor).
951void LayoutInfoPropagation::visitStoreNdOp(
952 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
953 ArrayRef<const LayoutInfoLattice *> results) {
954 LayoutInfo storeLayout;
955 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
956 if (hasParamsOfLayoutKind(anchorLayout)) {
957 storeLayout = LayoutInfo(anchorLayout);
// Query uArch limits for the 2D block store instruction.
962 const auto *uArchInstruction =
963 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
965 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
966 VectorType dataTy = store.getValueType();
967 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
968 store.getValueType().getElementType());
970 store.emitWarning(
"No known block params found for the element type.");
971 auto [bWidth, bHeight, bCount] = blockWHC.value();
972 SmallVector<int> instData;
// Fit innermost dim to block width, next-outer dim to block height.
974 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
977 "No suitable instruction multiple found for the given shape.");
978 if (dataTy.getRank() == 1)
979 instData = {instWidth};
982 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
983 if (instHeight == -1)
985 "No suitable instruction multiple found for the given shape.");
986 instData = {instHeight, instWidth};
989 if (layoutKind == xegpu::LayoutKind::InstData)
991 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
992 else if (layoutKind == xegpu::LayoutKind::Lane)
994 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
995 uArchInstruction->getPackedFormatBitSize());
// Subgroup kind: pick the best sg-layout candidate for this shape.
998 auto numSgOrErr =
getNumSg(store, sgSize);
1001 "Unable to determine the number of subgroups for the operation.");
1005 instData, numSgOrErr.value());
1006 if (sgLayouts.empty()) {
1008 "Unable to determine suitable subgroup layout for store value.");
// Candidates are pre-sorted; take the first (most square) one.
1011 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
1012 SmallVector<int> sgData = {
1013 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
1014 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
1015 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
1016 dataTy.getContext(),
// Persist the computed layout on the op.
1022 store.setLayoutAttr(
1023 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
// Both the stored value and the tensor descriptor share the same layout.
1027 for (LayoutInfoLattice *operand : operands)
1028 propagateIfChanged(operand, operand->meet(storeLayout));
// Propagates the result layout of xegpu.load_nd back to its tensor
// descriptor: anchor layout wins; otherwise the consumer's value layout is
// used (transposed if the op carries a transpose, which is unexpected at
// this stage and warned about), recorded on the op, and met into operand 0.
1033void LayoutInfoPropagation::visitLoadNdOp(
1034 xegpu::LoadNdOp
load, ArrayRef<LayoutInfoLattice *> operands,
1035 ArrayRef<const LayoutInfoLattice *> results) {
1036 LayoutInfo loadLayout;
1037 xegpu::DistributeLayoutAttr anchorLayout =
load.getLayoutAttr();
1038 if (hasParamsOfLayoutKind(anchorLayout)) {
1039 loadLayout = LayoutInfo(anchorLayout);
1042 LayoutInfo valueLayout = results[0]->getValue();
// Without a consumer-assigned layout there is nothing to propagate yet.
1044 if (!valueLayout.isAssigned())
1046 loadLayout = valueLayout;
// A transpose on load_nd should have been rewritten away before this
// pass; compensate by transposing the propagated layout, with a warning.
1050 if (
auto transpose =
load.getTranspose()) {
1051 load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
1052 "LayoutInfoPropagation stage.");
1053 loadLayout = valueLayout.transpose(transpose.value());
1055 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
// Meet into the tensor-descriptor operand.
1058 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// xegpu.convert_layout explicitly names its input layout — simply meet that
// anchor into the input operand.
1063void LayoutInfoPropagation::visitConvertLayoutOp(
1064 xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
1065 ArrayRef<const LayoutInfoLattice *> results) {
1066 xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
1067 LayoutInfo convertLayout(anchorLayout);
1069 propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
// Propagates the consumer's result layout backward through vector.transpose
// by deriving the source layout under the op's permutation.
1074void LayoutInfoPropagation::visitTransposeOp(
1075 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
1076 ArrayRef<const LayoutInfoLattice *> results) {
1078 LayoutInfo resultLayout = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1079 if (!resultLayout.isAssigned())
1082 auto consumerLayoutAttr =
1083 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
1085 consumerLayoutAttr, transpose.getPermutation());
1088 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Propagates layouts backward through vector.bitcast, adjusting for the
// element-bitwidth change between source and result element types.
1093void LayoutInfoPropagation::visitVectorBitcastOp(
1094 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
1095 ArrayRef<const LayoutInfoLattice *> results) {
1097 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1098 if (!resLayoutInfo.isAssigned())
1101 auto srcVecType = bitcast.getSourceVectorType();
1102 auto resVecType = bitcast.getResultVectorType();
1104 auto consumerLayoutAttr =
1105 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1110 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
// Bitwidth ratio drives how lane/inst data scales across the cast.
1114 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1115 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
1119 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
1121 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Propagates layouts backward through vector.interleave: both source
// operands receive the same derived source layout.
1127void LayoutInfoPropagation::visitVectorInterleaveOp(
1128 vector::InterleaveOp interleave, ArrayRef<LayoutInfoLattice *> operands,
1129 ArrayRef<const LayoutInfoLattice *> results) {
1131 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1132 if (!resLayoutInfo.isAssigned())
1135 auto srcVecType = interleave.getSourceVectorType();
1136 auto resVecType = interleave.getResultVectorType();
1138 auto consumerLayoutAttr =
1139 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1146 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1151 auto srcLayoutAttr =
// Both interleaved inputs share one layout.
1155 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1156 propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(srcLayoutAttr)));
// Propagates layouts backward through vector.deinterleave into its single
// source operand.
1162void LayoutInfoPropagation::visitVectorDeinterleaveOp(
1163 vector::DeinterleaveOp deinterleave, ArrayRef<LayoutInfoLattice *> operands,
1164 ArrayRef<const LayoutInfoLattice *> results) {
1167 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1168 if (!resLayoutInfo.isAssigned())
1171 auto consumerLayoutAttr =
1172 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1178 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Propagates layouts backward through vector.insert_strided_slice: the
// source (inserted slice) gets a layout derived from the destination's
// layout and the shape pair; the destination operand gets the required
// result layout itself.
1181void LayoutInfoPropagation::visitInsertStridedSliceOp(
1182 vector::InsertStridedSliceOp insertStridedSlice,
1183 ArrayRef<LayoutInfoLattice *> operands,
1184 ArrayRef<const LayoutInfoLattice *> results) {
1186 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1187 if (!resLayoutInfo.isAssigned())
1190 auto srcVecType = insertStridedSlice.getSourceVectorType();
1191 auto resVecType = insertStridedSlice.getDestVectorType();
1193 auto consumerLayoutAttr =
1194 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1195 const uArch *uArch =
1201 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1203 requiredResLayoutAttr);
1206 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
// operand 0 = inserted slice, operand 1 = destination vector.
1207 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1208 propagateIfChanged(operands[1],
1209 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Determines the layout for xegpu.load (gather): anchor layout of the right
// kind wins; otherwise the required layout is derived from the result vector
// type, chunk size, and consumer layout, recorded on the op, and met into
// the source (only when it is a tensor descriptor) plus the offsets/mask
// operands, which use a chunk-adjusted mask layout.
1214void LayoutInfoPropagation::visitLoadGatherOp(
1215 xegpu::LoadGatherOp
load, ArrayRef<LayoutInfoLattice *> operands,
1216 ArrayRef<const LayoutInfoLattice *> results) {
1217 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1218 xegpu::DistributeLayoutAttr anchorLayoutAttr =
load.getLayoutAttr();
1222 VectorType resVecTy =
load.getValueType();
// chunkSize defaults to 1 when the attribute is absent.
1223 int chunkSize =
load.getChunkSize().value_or(1);
1225 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1226 if (!resLayoutInfo.isAssigned())
1228 auto consumerLayoutAttr =
1229 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1231 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1232 requiredAnchorLayoutAttr = anchorLayoutAttr;
// Non-vector payloads are not handled; warn and bail.
1235 load.emitWarning(
"Not propagating, non-vector payload supplied.");
1239 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1240 load.setLayoutAttr(requiredAnchorLayoutAttr);
// chunked gathers are not expected at subgroup granularity.
1243 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1245 requiredAnchorLayoutAttr, chunkSize);
1246 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1247 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// Only tensor-descriptor sources carry a layout; plain memrefs do not.
1250 if (isa<xegpu::TensorDescType>(
load.getSourceType()))
1251 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
// operands 1/2: offsets and mask share the mask layout.
1253 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1254 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
// Determines the layout for xegpu.store (scatter), mirroring the gather
// case: anchor layout wins, otherwise the required layout is derived from
// the value type and chunk size, recorded on the op, and met into the value,
// the destination (only when a tensor descriptor), and the offsets/mask
// operands (chunk-adjusted mask layout).
1259void LayoutInfoPropagation::visitStoreScatterOp(
1260 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1261 ArrayRef<const LayoutInfoLattice *> results) {
1263 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1264 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1268 VectorType srcVecTy = storeScatter.getValueType();
// chunkSize defaults to 1 when the attribute is absent.
1269 int chunkSize = storeScatter.getChunkSize().value_or(1);
1271 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1272 requiredAnchorLayoutAttr = anchorLayoutAttr;
// Non-vector payloads are not handled; warn and bail.
1275 storeScatter.emitWarning(
"Not propagating, non-vector payload supplied.");
1279 layoutKind, srcVecTy, chunkSize, uArch);
1280 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1283 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// chunked scatters are not expected at subgroup granularity.
1284 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1286 requiredAnchorLayoutAttr, chunkSize);
1287 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
// operand 0 = stored value; operand 1 = destination (tdesc only);
// operands 2/3 = offsets and mask.
1290 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1292 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1293 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1295 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1296 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
// Determines the layout for xegpu.load_matrix: when no suitable anchor
// exists, derives the required layout from the result vector type and the
// consumer's layout and records it on the op. (Operand propagation lines are
// truncated in this chunk.)
1299void LayoutInfoPropagation::visitLoadMatrixOp(
1300 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1301 ArrayRef<const LayoutInfoLattice *> results) {
1303 LayoutInfo resLayoutInfo = results[0]->getValue();
// No consumer constraint yet — nothing to propagate.
1304 if (!resLayoutInfo.isAssigned())
1307 auto consumerLayoutAttr =
1308 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1310 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1314 if (!hasParamsOfLayoutKind(anchorLayout)) {
1315 VectorType resVecTy =
1316 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1321 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1322 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
// Determines the layout for xegpu.store_matrix: anchor layout of the right
// kind wins; otherwise a required layout is derived from the stored vector
// type, recorded on the op, and met into the data operand.
1327void LayoutInfoPropagation::visitStoreMatrixOp(
1328 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1329 ArrayRef<const LayoutInfoLattice *> results) {
1330 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1332 if (hasParamsOfLayoutKind(anchorLayout)) {
1333 layout = LayoutInfo(anchorLayout);
1335 VectorType srcVecTy =
1336 llvm::cast<VectorType>(storeMatrix.getData().getType());
1340 auto requiredAnchorLayoutAttr =
1342 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1343 layout = LayoutInfo(requiredAnchorLayoutAttr);
// Meet into the stored-data operand.
1346 propagateIfChanged(operands[0], operands[0]->meet(layout));
// Driver that owns the DataFlowSolver, loads the LayoutInfoPropagation
// analysis into it, and exposes query/printing helpers over the solved
// lattice. (Constructor init-list and solver run are in truncated lines.)
1355class RunLayoutInfoPropagation {
1360 unsigned indexBitWidth)
1362 SymbolTableCollection symbolTable;
1364 solver.
load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
// Fetches the solved layout for `val` (unassigned if none).
1368 LayoutInfo getLayoutInfo(Value val);
// Debug helper: dumps per-function layouts to `os`.
1370 void printAnalysisResult(llvm::raw_ostream &os);
1373 DataFlowSolver solver;
// Looks up the solved lattice state for `val` and returns its LayoutInfo.
// (The null-state guard falls in the truncated lines between 1379 and 1382.)
1378LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1379 auto *state = solver.
lookupState<LayoutInfoLattice>(val);
1382 return state->getValue();
// Dumps the analysis result: for each function (found directly in a ModuleOp
// and inside nested gpu.module ops), prints the layout of every block
// argument and of every op result, skipping pure control-flow ops.
1386void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1387 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1388 os <<
"function: " << funcOp.getName() <<
":\n";
// Function arguments first.
1390 for (BlockArgument arg : funcOp.getArguments()) {
1391 LayoutInfo layout = getLayoutInfo(arg);
1392 os <<
"argument: " << arg <<
"\n";
1398 funcOp.walk([&](Operation *op) {
// Branch-like ops carry no interesting layouts of their own.
1404 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1410 for (
auto [i, r] : llvm::enumerate(op->
getResults())) {
1411 LayoutInfo layout = getLayoutInfo(r);
1412 os <<
"layout for result #" << i <<
": ";
// Collect functions from the top-level module and nested GPU modules.
1419 SmallVector<FunctionOpInterface> funcOps;
1420 if (
auto modOp = dyn_cast<ModuleOp>(
target)) {
1421 for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
1422 funcOps.push_back(funcOp);
1425 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1426 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1427 funcOps.push_back(gpuFuncOp);
1431 for (FunctionOpInterface funcOp : funcOps)
1432 printFunctionResult(funcOp);
1444static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1446 auto definingOp = tdescValue.
getDefiningOp<xegpu::CreateNdDescOp>();
1451 if (
auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1452 auto *parentOp = arg.getOwner()->getParentOp();
1453 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1454 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1456 return getDefiningCreateNdDescOp(tiedInit->
get());
1463struct ResolveLayoutConflicts {
1464 ResolveLayoutConflicts(Operation *parentOp)
1465 : parentOp(parentOp), builder(parentOp->
getContext()) {}
1466 LogicalResult run();
1469 Operation *parentOp;
1471 LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1472 LogicalResult resolveVectorConsumer(OpOperand &operand);
1473 LogicalResult assignResultLayout(OpResult &
result);
1478LogicalResult ResolveLayoutConflicts::run() {
1481 auto r = parentOp->
walk([&](Operation *op) -> WalkResult {
1485 if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
1487 if (
result.getType().isIntOrFloat()) {
1488 auto res = assignResultLayout(
result);
1490 DBGS() <<
"Failed to resolve vector consumer for multi-reduction "
1499 Type operandType = operand.get().getType();
1500 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1501 isa<xegpu::TensorDescType>(operandType)) {
1502 auto res = resolveTensorDescConsumer(operand);
1504 DBGS() <<
"Failed to resolve tensor descriptor consumer: " << *op
1510 if (isa<VectorType>(operandType)) {
1511 auto res = resolveVectorConsumer(operand);
1513 DBGS() <<
"Failed to resolve vector consumer: " << *op <<
"\n";
1522 DBGS() <<
"IR after resolving layout conflicts:\n";
1526 return r.wasInterrupted() ? failure() :
success();
1529LogicalResult ResolveLayoutConflicts::assignResultLayout(OpResult &
result) {
1530 Operation *producerOp =
result.getDefiningOp();
1534 auto convertOp = xegpu::ConvertLayoutOp::create(
1537 result.replaceAllUsesExcept(convertOp.getResult(), convertOp);
1542ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1543 Value vectorValue = operand.
get();
1544 Operation *consumerOp = operand.
getOwner();
1547 if (!producerLayout) {
1548 if (
auto vectorTy = dyn_cast<VectorType>(vectorValue.
getType());
1549 vectorTy && vectorTy.getRank() > 1)
1550 consumerOp->
emitWarning(
"Expected layout for non-1D vectors.");
1555 if (!consumerLayout)
1557 "No consumer layout found for vector operand.");
1560 if (consumerLayout.isEqualTo(producerLayout))
1565 auto convertOp = xegpu::ConvertLayoutOp::create(
1566 builder, consumerOp->
getLoc(), vectorValue.
getType(), vectorValue,
1567 producerLayout, consumerLayout);
1570 operand.
set(convertOp.getResult());
1575ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1576 Operation *consumerOp = operand.
getOwner();
1577 Value tdescValue = operand.
get();
1578 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1579 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.
getType());
1580 assert(anchorOp && currTDescType &&
1581 "Expected anchor layout op and tensor descriptor consumer.");
1582 Attribute currLayout = currTDescType.getLayout();
1583 Attribute expectedLayout = anchorOp.getAnchorLayout();
1586 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1588 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1589 if (!conflictingCreateNdOp) {
1590 DBGS() <<
"Unable to find defining CreateNdDescOp for tensor descriptor: "
1591 << tdescValue <<
"\n";
1596 auto newTensorDescType = xegpu::TensorDescType::get(
1597 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1598 currTDescType.getElementType(), currTDescType.getEncoding(),
1600 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1601 builder, consumerOp->
getLoc(), newTensorDescType,
1602 conflictingCreateNdOp->getOperands(),
1603 conflictingCreateNdOp->getAttrs());
1621 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1628 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1631 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(
result);
1632 if (!layout &&
result.getNumUses() > 0) {
1633 op->
emitWarning(
"op has users but no layout assigned for its result");
1638 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1639 auto typeWithLayout = xegpu::TensorDescType::get(
1640 tensorDescTy.getContext(), tensorDescTy.getShape(),
1641 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1642 result.setType(typeWithLayout);
1676 mlir::RegionBranchTerminatorOpInterface terminator,
1679 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1684 branchOp.getSuccessorOperandInputMapping(mapping,
1686 for (
const auto &[successorOperand, successorInputs] : mapping) {
1687 for (
Value successorInput : successorInputs) {
1688 Type inputType = successorInput.getType();
1690 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1692 xegpu::DistributeLayoutAttr successorInputLayout =
1693 getLayoutOfValue(successorInput);
1694 xegpu::DistributeLayoutAttr successorOperandLayout =
1695 getLayoutOfValue(successorOperand->get());
1698 if (!successorOperandLayout) {
1699 LLVM_DEBUG(
DBGS() <<
"No layout assigned for forwarded operand in "
1700 "branch terminator: "
1701 << successorOperand->get() <<
"\n");
1705 if (successorInputLayout &&
1706 successorInputLayout != successorOperandLayout) {
1707 LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
1708 "operand forwarded as the argument: "
1709 << successorInputLayout <<
" vs "
1710 << successorOperandLayout <<
"\n");
1714 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1715 auto newTdescTy = xegpu::TensorDescType::get(
1716 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1717 tdescTy.getEncoding(), successorOperandLayout);
1718 successorInput.setType(newTdescTy);
1723 if (
auto result = dyn_cast<OpResult>(successorInput))
1732 mlir::FunctionOpInterface funcOp,
1738 if (!isa<FunctionType>(funcOp.getFunctionType()))
1743 Type argType = arg.getType();
1744 newArgTypes.push_back(argType);
1745 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1747 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1749 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
1750 <<
" but got none.\n");
1753 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1754 auto newTdescTy = xegpu::TensorDescType::get(
1755 tensorDescTy.getContext(), tensorDescTy.getShape(),
1756 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1757 arg.setType(newTdescTy);
1758 newArgTypes.back() = newTdescTy;
1763 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1764 funcOp.getResultTypes()));
1769struct XeGPUPropagateLayoutPass final
1770 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1771 XeGPUPropagateLayoutPass() =
default;
1772 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
1773 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
1774 : XeGPUPropagateLayoutBase(std::move(
options)) {}
1775 void runOnOperation()
override;
1782 unsigned indexBitWidth,
bool printOnly) {
1783 RunLayoutInfoPropagation analysis(
target, layoutKind, indexBitWidth);
1786 auto &os = llvm::outs();
1787 analysis.printAnalysisResult(os);
1791 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::DistributeLayoutAttr {
1792 LayoutInfo layout = analysis.getLayoutInfo(val);
1793 if (
auto opResult = dyn_cast<OpResult>(val)) {
1794 Operation *defOp = opResult.getDefiningOp();
1795 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1796 auto anchorLayout = anchorOp.getAnchorLayout();
1797 if (anchorLayout !=
nullptr)
1798 return anchorLayout;
1800 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1802 if (requiredResLayoutAttr !=
nullptr)
1803 return requiredResLayoutAttr;
1805 if (!layout.isAssigned())
1807 xegpu::DistributeLayoutAttr layoutAttr =
1808 cast<xegpu::DistributeLayoutAttr>(layout.get());
1809 if (layout.isSliceLayout())
1810 return cast<xegpu::SliceAttr>(layoutAttr);
1812 return cast<xegpu::LayoutAttr>(layoutAttr);
1820 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1822 getXeGPULayoutForValue);
1824 .Case([&](mlir::FunctionOpInterface funcOp) {
1826 getXeGPULayoutForValue);
1829 r =
updateOp(builder, op, getXeGPULayoutForValue);
1832 op.
emitError(
"Failed to update operation with the layout.");
1838 if (walkResult.wasInterrupted())
1845 ResolveLayoutConflicts resolver(
target);
1846 return resolver.run();
1849void XeGPUPropagateLayoutPass::runOnOperation() {
1854 if (this->layoutKind ==
"lane") {
1856 }
else if (this->layoutKind ==
"inst") {
1858 }
else if (this->layoutKind ==
"subgroup") {
1859 layoutKind = xegpu::LayoutKind::Subgroup;
1861 getOperation()->emitError(
"Unsupported layout kind option: " +
1863 signalPassFailure();
1868 this->indexBitWidth, this->printOnly))) {
1869 signalPassFailure();
1874 signalPassFailure();
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
LogicalResult initializeAndRun(Operation *top, llvm::function_ref< bool(DataFlowAnalysis &)> analysisFilter=nullptr)
Initialize analyses starting from the provided top-level operation and run the analysis until fixpoin...
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const