32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Layout attribute wrapped by this object. It defaults to nullptr, which
// isAssigned() reports as "no layout assigned".
83 xegpu::DistributeLayoutAttr storage =
nullptr;
// Default construction leaves `storage` at its in-class initializer
// (nullptr), so the resulting object reports isAssigned() == false.
86 LayoutInfo() =
default;
// Wraps an existing layout attribute. The constructor is not `explicit`, so
// a DistributeLayoutAttr converts implicitly to a LayoutInfo.
87 LayoutInfo(
const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
91 bool operator==(
const LayoutInfo &other)
const {
92 return this->isAssigned() == other.isAssigned();
95 static LayoutInfo meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
97 static LayoutInfo
join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
// Returns true when a non-null layout attribute is held in `storage`.
101 bool isAssigned()
const {
return storage !=
nullptr; }
117 bool isSliceLayout()
const {
120 return isa<xegpu::SliceAttr>(storage);
126 return storage.getRank();
// Overwrites the stored layout with `layout`. No validation is performed,
// so passing a null attribute puts the object back in the unassigned state.
130 void set(
const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](
int64_t val) { return static_cast<int>(val); });
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](
int64_t val) { return static_cast<int>(val); });
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](
int64_t val) { return static_cast<int>(val); });
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](
int64_t val) { return static_cast<int>(val); });
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](
int64_t val) { return static_cast<int>(val); });
169 if (!isAssigned() || !storage.getOrder())
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](
int64_t val) { return static_cast<int>(val); });
179 os <<
"Not assigned.";
183LayoutInfo LayoutInfo::meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
184 if (!
lhs.isAssigned())
190LayoutInfo LayoutInfo::join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
191 llvm_unreachable(
"Join should not be triggered by layout propagation.");
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](
int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
206 if (!withinRange || hasDuplicates) {
207 assert(
false &&
"Invalid permutation for transpose.");
218 for (
int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(
static_cast<int32_t
>(getLaneLayout()[idx]));
221 laneData.push_back(
static_cast<int32_t
>(getLaneData()[idx]));
223 if (getInstData().size())
224 instData.push_back(
static_cast<int32_t
>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(
static_cast<int32_t
>(getSgLayout()[idx]));
227 sgData.push_back(
static_cast<int32_t
>(getSgData()[idx]));
229 if (getOrder().size()) {
230 order.push_back(
static_cast<int32_t
>(getOrder()[idx]));
233 auto orderAttr = order.size()
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
249 return LayoutInfo(layoutAttr);
257struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
273 assert(rank >= 1 &&
"Expected at least 1D vector.");
283 return LayoutInfo(xegpu::LayoutAttr::get(ctx, laneLayout, laneData));
288template <
typename Ty>
289static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
291 unsigned packingSize) {
293 assert(ty.getRank() >= 1 &&
"Expected at least 1D vector.");
295 assert(ty.getElementType().isIntOrFloat() &&
296 "Expected int or float element type.");
298 if (ty.getRank() == 1)
299 return getDefaultSIMTLayoutInfo(ty.getContext(), 1,
uArch);
301 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
302 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
304 unsigned rank = ty.getRank();
308 laneData[rank - 1] = packingFactor;
310 xegpu::LayoutAttr::get(ty.getContext(), laneLayout, laneData));
322class LayoutInfoPropagation
329 unsigned indexBitWidth;
333 void visitDpasMxOp(xegpu::DpasMxOp dpasMx,
337 void visitStoreNdOp(xegpu::StoreNdOp store,
341 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
345 void visitLoadNdOp(xegpu::LoadNdOp
load,
349 void visitLoadGatherOp(xegpu::LoadGatherOp
load,
353 void visitTransposeOp(vector::TransposeOp transpose,
357 void visitVectorBitcastOp(vector::BitCastOp bitcast,
361 void visitVectorInterleaveOp(vector::InterleaveOp interleave,
365 void visitVectorDeinterleaveOp(vector::DeinterleaveOp deinterleave,
369 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
373 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
377 void visitVectorReductionOp(vector::ReductionOp reduction,
381 void visitVectorBroadCastOp(vector::BroadcastOp
broadcast,
384 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
388 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
392 void visitLoadMatrixOp(xegpu::LoadMatrixOp
load,
396 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
400 void visitLoadGatherOp(xegpu::LoadMatrixOp
load,
404 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
408 void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
412 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
419 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
// Deliberately empty override: this analysis does nothing for branch
// operands. NOTE(review): presumably they need no layout handling in this
// pass -- confirm against the base-class contract.
426 void visitBranchOperand(
OpOperand &operand)
override {};
// Deliberately empty override: this analysis does nothing for call
// operands. NOTE(review): presumably they need no layout handling in this
// pass -- confirm against the base-class contract.
428 void visitCallOperand(
OpOperand &operand)
override {};
434 void visitExternalCall(CallOpInterface call,
439 void setToExitState(LayoutInfoLattice *lattice)
override {
440 (
void)lattice->meet(LayoutInfo());
445LogicalResult LayoutInfoPropagation::visitOperation(
446 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
447 ArrayRef<const LayoutInfoLattice *> results) {
450 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
451 .Case([&](xegpu::DpasMxOp dpasMxOp) {
452 visitDpasMxOp(dpasMxOp, operands, results);
454 .Case([&](xegpu::StoreNdOp storeNdOp) {
455 visitStoreNdOp(storeNdOp, operands, results);
457 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
458 visitStoreScatterOp(storeScatterOp, operands, results);
460 .Case([&](xegpu::LoadNdOp loadNdOp) {
461 visitLoadNdOp(loadNdOp, operands, results);
463 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
464 visitLoadGatherOp(loadGatherOp, operands, results);
466 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
467 visitPrefetchNdOp(prefetchNdOp, operands, results);
469 .Case([&](vector::TransposeOp transposeOp) {
470 visitTransposeOp(transposeOp, operands, results);
472 .Case([&](vector::BitCastOp bitcastOp) {
473 visitVectorBitcastOp(bitcastOp, operands, results);
475 .Case([&](vector::InterleaveOp interleaveOp) {
476 visitVectorInterleaveOp(interleaveOp, operands, results);
478 .Case([&](vector::DeinterleaveOp deinterleaveOp) {
479 visitVectorDeinterleaveOp(deinterleaveOp, operands, results);
481 .Case([&](vector::MultiDimReductionOp reductionOp) {
482 visitVectorMultiReductionOp(reductionOp, operands, results);
484 .Case([&](vector::ReductionOp reductionOp) {
485 visitVectorReductionOp(reductionOp, operands, results);
487 .Case([&](vector::BroadcastOp broadcastOp) {
488 visitVectorBroadCastOp(broadcastOp, operands, results);
490 .Case([&](vector::ShapeCastOp shapeCastOp) {
491 visitShapeCastOp(shapeCastOp, operands, results);
493 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
494 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
496 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
497 visitLoadMatrixOp(loadMatrixOp, operands, results);
499 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
500 visitStoreMatrixOp(storeMatrixOp, operands, results);
502 .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
503 visitConvertLayoutOp(convertLayoutOp, operands, results);
506 .Default([&](Operation *op) {
507 for (
const LayoutInfoLattice *resultInfo : results) {
508 if (!resultInfo->getValue().isAssigned())
510 for (
auto [operandInfo, operand] :
514 if (!isa<xegpu::TensorDescType, VectorType>(
515 operand.get().getType()))
518 meet(operandInfo, *resultInfo);
526bool LayoutInfoPropagation::hasParamsOfLayoutKind(
527 xegpu::DistributeLayoutAttr anchorLayout) {
528 if (anchorLayout ==
nullptr) {
531 if (layoutKind == xegpu::LayoutKind::InstData) {
532 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
534 if (layoutKind == xegpu::LayoutKind::Lane) {
535 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
536 anchorLayout.getEffectiveLaneDataAsInt().empty());
538 if (layoutKind == xegpu::LayoutKind::Subgroup) {
539 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
540 anchorLayout.getEffectiveSgDataAsInt().empty());
556 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
557 if (sgCount % sgLayout0)
559 int sgLayout1 = sgCount / sgLayout0;
560 int sgData0 = wgShape[0] / sgLayout0;
561 int sgData1 = wgShape[1] / sgLayout1;
562 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
563 (sgData0 % instData[0] || sgData1 % instData[1]))
565 candidates.emplace_back(sgLayout0, sgLayout1);
570 llvm::sort(candidates, [](
const std::pair<int, int> &
lhs,
571 const std::pair<int, int> &
rhs) {
572 int diffLhs = std::abs(
lhs.first -
lhs.second);
573 int diffRhs = std::abs(
rhs.first -
rhs.second);
574 if (diffLhs != diffRhs)
575 return diffLhs < diffRhs;
576 return lhs.first <
rhs.first;
586 auto knownBlockSize = gpuFunc.getKnownBlockSize();
587 if (!knownBlockSize.has_value())
589 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
590 return flatBlockSize / sgSize;
593void LayoutInfoPropagation::visitPrefetchNdOp(
594 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
595 ArrayRef<const LayoutInfoLattice *> results) {
597 LayoutInfo prefetchLayout;
598 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
599 if (hasParamsOfLayoutKind(anchorLayout)) {
600 prefetchLayout = LayoutInfo(anchorLayout);
604 auto tdescTy = prefetch.getTensorDescType();
609 const auto *uArchInstruction =
610 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
612 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
615 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
617 prefetch.emitWarning(
"No known block params found for the element type.");
618 auto [bWidth, bHeight, bCount] = blockWHC.value();
619 SmallVector<int> instData;
621 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
623 prefetch.emitWarning(
624 "No suitable instruction multiple found for the given shape.");
625 if (tdescTy.getRank() == 1)
626 instData = {instWidth};
629 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
630 if (instHeight == -1)
631 prefetch.emitWarning(
632 "No suitable instruction multiple found for the given shape.");
633 instData = {instHeight, instWidth};
636 if (layoutKind == xegpu::LayoutKind::InstData)
638 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
640 prefetchLayout = getSIMTLayoutInfoBlockIO(
641 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
643 prefetch.setLayoutAttr(
644 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
647 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
650void LayoutInfoPropagation::visitVectorMultiReductionOp(
651 vector::MultiDimReductionOp reduction,
652 ArrayRef<LayoutInfoLattice *> operands,
653 ArrayRef<const LayoutInfoLattice *> results) {
654 Type resultTy = reduction.getDestType();
656 LayoutInfo resLayoutInfo = results[0]->getValue();
658 xegpu::DistributeLayoutAttr consumerLayoutAttr;
660 if (!resLayoutInfo.isAssigned())
663 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
666 VectorType sourceTy = reduction.getSourceVectorType();
667 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
673 if (layoutKind == xegpu::LayoutKind::Subgroup) {
675 if (succeeded(numSgOrErr))
676 numSg = numSgOrErr.value();
685 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, numSg, uArch);
691 requiredResLayoutAttr, reductionDims);
693 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
695 propagateIfChanged(operands[1],
696 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
699void LayoutInfoPropagation::visitVectorReductionOp(
700 vector::ReductionOp reduction, ArrayRef<LayoutInfoLattice *> operands,
701 ArrayRef<const LayoutInfoLattice *> results) {
703 VectorType sourceTy = reduction.getSourceVectorType();
708 auto requiredResLayoutAttr =
713 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
714 if (reduction.getAcc())
715 propagateIfChanged(operands[1],
716 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
719void LayoutInfoPropagation::visitVectorBroadCastOp(
720 vector::BroadcastOp
broadcast, ArrayRef<LayoutInfoLattice *> operands,
721 ArrayRef<const LayoutInfoLattice *> results) {
723 LayoutInfo resLayoutInfo = results[0]->getValue();
724 if (!resLayoutInfo.isAssigned())
728 VectorType resultTy =
broadcast.getResultVectorType();
729 VectorType sourceTy = dyn_cast<VectorType>(
broadcast.getSourceType());
734 auto srcShape = sourceTy.getShape();
735 auto resShape = resultTy.getShape();
737 auto resultLayoutAttr =
738 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
740 xegpu::DistributeLayoutAttr srcLayoutAttr =
743 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
746void LayoutInfoPropagation::visitShapeCastOp(
747 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
748 ArrayRef<const LayoutInfoLattice *> results) {
750 LayoutInfo resLayoutInfo = results[0]->getValue();
751 if (!resLayoutInfo.isAssigned())
753 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
754 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
755 auto resultLayoutAttr =
756 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
758 xegpu::DistributeLayoutAttr srcLayoutAttr =
761 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
765void LayoutInfoPropagation::visitDpasOp(
766 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
767 ArrayRef<const LayoutInfoLattice *> results) {
768 LayoutInfo dpasALayout;
769 LayoutInfo dpasBLayout;
770 LayoutInfo dpasCDLayout;
772 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
773 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
774 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
775 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
776 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
777 "Expected anchor layout for DPAS A operand.");
778 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
779 "Expected anchor layout for DPAS B operand.");
780 dpasALayout = LayoutInfo(anchorLayoutA);
781 dpasBLayout = LayoutInfo(anchorLayoutB);
782 dpasCDLayout = LayoutInfo(anchorLayoutCD);
787 VectorType aTy = dpas.getLhsType();
788 VectorType bTy = dpas.getRhsType();
789 VectorType cdTy = dpas.getResultType();
791 xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
792 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
796 if (layoutKind == xegpu::LayoutKind::Subgroup) {
797 LayoutInfo consumerLayout = results[0]->getValue();
798 if (!consumerLayout.isAssigned())
801 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
805 "Unable to determine the number of subgroups for the operation.");
808 numSg = numSgOrErr.value();
811 consumerLayoutAttr, numSg, uArch);
812 if (!layouts.has_value()) {
814 "Failed to determine required layouts for DPAS operands.");
818 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
820 dpas.setLayoutAAttr(requiredALayout);
821 dpas.setLayoutBAttr(requiredBLayout);
822 dpas.setLayoutCdAttr(requiredCDLayoutAttr);
823 dpasALayout = LayoutInfo(requiredALayout);
824 dpasBLayout = LayoutInfo(requiredBLayout);
825 dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
827 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
828 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
829 if (operands.size() > 2)
830 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
836void LayoutInfoPropagation::visitDpasMxOp(
837 xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
838 ArrayRef<const LayoutInfoLattice *> results) {
841 LayoutInfo dpasMxALayout, dpasMxBLayout, dpasMxCDLayout;
842 LayoutInfo dpasMxAScaleLayout, dpasMxBScaleLayout;
845 xegpu::DistributeLayoutAttr anchorLayoutA = dpasMx.getLayoutAAttr();
846 xegpu::DistributeLayoutAttr anchorLayoutB = dpasMx.getLayoutBAttr();
847 xegpu::DistributeLayoutAttr anchorLayoutCD = dpasMx.getLayoutCdAttr();
850 if (anchorLayoutA && anchorLayoutB && anchorLayoutCD &&
851 hasParamsOfLayoutKind(anchorLayoutA) &&
852 hasParamsOfLayoutKind(anchorLayoutB) &&
853 hasParamsOfLayoutKind(anchorLayoutCD)) {
854 dpasMxALayout = LayoutInfo(anchorLayoutA);
855 dpasMxBLayout = LayoutInfo(anchorLayoutB);
856 dpasMxCDLayout = LayoutInfo(anchorLayoutCD);
859 xegpu::DistributeLayoutAttr anchorLayoutAScale =
860 dpasMx.getLayoutAScaleAttr();
861 xegpu::DistributeLayoutAttr anchorLayoutBScale =
862 dpasMx.getLayoutBScaleAttr();
863 if (anchorLayoutAScale)
864 dpasMxAScaleLayout = LayoutInfo(anchorLayoutAScale);
865 if (anchorLayoutBScale)
866 dpasMxBScaleLayout = LayoutInfo(anchorLayoutBScale);
873 VectorType aTy = dpasMx.getAType();
874 VectorType bTy = dpasMx.getBType();
875 VectorType cdTy = dpasMx.getResultType();
880 Value scaleA = dpasMx.getScaleA();
881 Value scaleB = dpasMx.getScaleB();
883 aScaleTy = dyn_cast<VectorType>(scaleA.
getType());
885 bScaleTy = dyn_cast<VectorType>(scaleB.
getType());
887 xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
888 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
889 requiredBLayout, requiredAScaleLayout, requiredBScaleLayout;
892 if (layoutKind == xegpu::LayoutKind::Subgroup) {
893 LayoutInfo consumerLayout = results[0]->getValue();
894 if (!consumerLayout.isAssigned())
897 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
901 "Unable to determine the number of subgroups for the operation.");
904 numSg = numSgOrErr.value();
909 consumerLayoutAttr, numSg, uArch);
910 if (!layouts.has_value()) {
912 "Failed to determine required layouts for DPAS_MX operands.");
916 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr,
917 requiredAScaleLayout, requiredBScaleLayout) = *layouts;
919 dpasMx.setLayoutAAttr(requiredALayout);
920 dpasMx.setLayoutBAttr(requiredBLayout);
921 dpasMx.setLayoutCdAttr(requiredCDLayoutAttr);
922 if (requiredAScaleLayout)
923 dpasMx.setLayoutAScaleAttr(requiredAScaleLayout);
924 if (requiredBScaleLayout)
925 dpasMx.setLayoutBScaleAttr(requiredBScaleLayout);
927 dpasMxALayout = LayoutInfo(requiredALayout);
928 dpasMxBLayout = LayoutInfo(requiredBLayout);
929 dpasMxCDLayout = LayoutInfo(requiredCDLayoutAttr);
930 if (requiredAScaleLayout)
931 dpasMxAScaleLayout = LayoutInfo(requiredAScaleLayout);
932 if (requiredBScaleLayout)
933 dpasMxBScaleLayout = LayoutInfo(requiredBScaleLayout);
940 propagateIfChanged(operands[0], operands[0]->meet(dpasMxALayout));
941 propagateIfChanged(operands[1], operands[1]->meet(dpasMxBLayout));
943 if (dpasMx.getAcc()) {
944 propagateIfChanged(operands[idx], operands[idx]->meet(dpasMxCDLayout));
947 if (dpasMx.getScaleA()) {
948 if (dpasMxAScaleLayout.isAssigned())
949 propagateIfChanged(operands[idx],
950 operands[idx]->meet(dpasMxAScaleLayout));
953 if (dpasMx.getScaleB()) {
954 if (dpasMxBScaleLayout.isAssigned())
955 propagateIfChanged(operands[idx],
956 operands[idx]->meet(dpasMxBScaleLayout));
962void LayoutInfoPropagation::visitStoreNdOp(
963 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
964 ArrayRef<const LayoutInfoLattice *> results) {
965 LayoutInfo storeLayout;
966 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
967 if (hasParamsOfLayoutKind(anchorLayout)) {
968 storeLayout = LayoutInfo(anchorLayout);
973 const auto *uArchInstruction =
974 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
976 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
977 VectorType dataTy = store.getValueType();
978 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
979 store.getValueType().getElementType());
981 store.emitWarning(
"No known block params found for the element type.");
982 auto [bWidth, bHeight, bCount] = blockWHC.value();
985 SmallVector<int> instData(dataTy.getRank(), 1);
987 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
990 "No suitable instruction multiple found for the given shape.");
991 if (dataTy.getRank() == 1) {
992 instData = {instWidth};
995 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
996 if (instHeight == -1)
998 "No suitable instruction multiple found for the given shape.");
999 instData[dataTy.getRank() - 2] = instHeight;
1000 instData[dataTy.getRank() - 1] = instWidth;
1003 if (layoutKind == xegpu::LayoutKind::InstData)
1005 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
1006 else if (layoutKind == xegpu::LayoutKind::Lane)
1008 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
1009 uArchInstruction->getPackedFormatBitSize());
1012 auto numSgOrErr =
getNumSg(store, sgSize);
1013 if (
failed(numSgOrErr)) {
1015 "Unable to determine the number of subgroups for the operation.");
1019 instData, numSgOrErr.value());
1020 if (sgLayouts.empty()) {
1022 "Unable to determine suitable subgroup layout for store value.");
1025 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
1026 SmallVector<int> sgData = {
1027 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
1028 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
1029 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
1030 dataTy.getContext(),
1036 store.setLayoutAttr(
1037 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
1041 for (LayoutInfoLattice *operand : operands)
1042 propagateIfChanged(operand, operand->meet(storeLayout));
1047void LayoutInfoPropagation::visitLoadNdOp(
1048 xegpu::LoadNdOp
load, ArrayRef<LayoutInfoLattice *> operands,
1049 ArrayRef<const LayoutInfoLattice *> results) {
1050 LayoutInfo loadLayout;
1051 xegpu::DistributeLayoutAttr anchorLayout =
load.getLayoutAttr();
1052 if (hasParamsOfLayoutKind(anchorLayout)) {
1053 loadLayout = LayoutInfo(anchorLayout);
1056 LayoutInfo valueLayout = results[0]->getValue();
1058 if (!valueLayout.isAssigned())
1060 loadLayout = valueLayout;
1064 if (
auto transpose =
load.getTranspose()) {
1065 load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
1066 "LayoutInfoPropagation stage.");
1067 loadLayout = valueLayout.transpose(transpose.value());
1069 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
1072 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
1077void LayoutInfoPropagation::visitConvertLayoutOp(
1078 xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
1079 ArrayRef<const LayoutInfoLattice *> results) {
1080 xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
1081 LayoutInfo convertLayout(anchorLayout);
1083 propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
1088void LayoutInfoPropagation::visitTransposeOp(
1089 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
1090 ArrayRef<const LayoutInfoLattice *> results) {
1092 LayoutInfo resultLayout = results[0]->getValue();
1093 if (!resultLayout.isAssigned())
1096 auto consumerLayoutAttr =
1097 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
1099 consumerLayoutAttr, transpose.getPermutation());
1102 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1107void LayoutInfoPropagation::visitVectorBitcastOp(
1108 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
1109 ArrayRef<const LayoutInfoLattice *> results) {
1111 LayoutInfo resLayoutInfo = results[0]->getValue();
1112 if (!resLayoutInfo.isAssigned())
1115 auto srcVecType = bitcast.getSourceVectorType();
1116 auto resVecType = bitcast.getResultVectorType();
1118 auto consumerLayoutAttr =
1119 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1124 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1128 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1129 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
1133 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
1135 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1141void LayoutInfoPropagation::visitVectorInterleaveOp(
1142 vector::InterleaveOp interleave, ArrayRef<LayoutInfoLattice *> operands,
1143 ArrayRef<const LayoutInfoLattice *> results) {
1145 LayoutInfo resLayoutInfo = results[0]->getValue();
1146 if (!resLayoutInfo.isAssigned())
1149 auto srcVecType = interleave.getSourceVectorType();
1150 auto resVecType = interleave.getResultVectorType();
1152 auto consumerLayoutAttr =
1153 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1160 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1165 auto srcLayoutAttr =
1169 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1170 propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(srcLayoutAttr)));
1176void LayoutInfoPropagation::visitVectorDeinterleaveOp(
1177 vector::DeinterleaveOp deinterleave, ArrayRef<LayoutInfoLattice *> operands,
1178 ArrayRef<const LayoutInfoLattice *> results) {
1181 LayoutInfo resLayoutInfo = results[0]->getValue();
1182 if (!resLayoutInfo.isAssigned())
1185 auto consumerLayoutAttr =
1186 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1192 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1195void LayoutInfoPropagation::visitInsertStridedSliceOp(
1196 vector::InsertStridedSliceOp insertStridedSlice,
1197 ArrayRef<LayoutInfoLattice *> operands,
1198 ArrayRef<const LayoutInfoLattice *> results) {
1200 LayoutInfo resLayoutInfo = results[0]->getValue();
1201 if (!resLayoutInfo.isAssigned())
1204 auto srcVecType = insertStridedSlice.getSourceVectorType();
1205 auto resVecType = insertStridedSlice.getDestVectorType();
1207 auto consumerLayoutAttr =
1208 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1209 const uArch *uArch =
1215 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1217 requiredResLayoutAttr);
1220 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
1221 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1222 propagateIfChanged(operands[1],
1223 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
1228void LayoutInfoPropagation::visitLoadGatherOp(
1229 xegpu::LoadGatherOp
load, ArrayRef<LayoutInfoLattice *> operands,
1230 ArrayRef<const LayoutInfoLattice *> results) {
1231 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1232 xegpu::DistributeLayoutAttr anchorLayoutAttr =
load.getLayoutAttr();
1236 VectorType resVecTy =
load.getValueType();
1237 int chunkSize =
load.getChunkSize().value_or(1);
1239 LayoutInfo resLayoutInfo = results[0]->getValue();
1240 if (!resLayoutInfo.isAssigned())
1242 auto consumerLayoutAttr =
1243 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1245 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1246 requiredAnchorLayoutAttr = anchorLayoutAttr;
1249 load.emitWarning(
"Not propagating, non-vector payload supplied.");
1253 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1254 load.setLayoutAttr(requiredAnchorLayoutAttr);
1257 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1259 requiredAnchorLayoutAttr, chunkSize);
1260 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1261 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1264 if (isa<xegpu::TensorDescType>(
load.getSourceType()))
1265 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1267 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1268 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1273void LayoutInfoPropagation::visitStoreScatterOp(
1274 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1275 ArrayRef<const LayoutInfoLattice *> results) {
1277 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1278 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1282 VectorType srcVecTy = storeScatter.getValueType();
1283 int chunkSize = storeScatter.getChunkSize().value_or(1);
1285 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1286 requiredAnchorLayoutAttr = anchorLayoutAttr;
1289 storeScatter.emitWarning(
"Not propagating, non-vector payload supplied.");
1293 layoutKind, srcVecTy, chunkSize, uArch);
1294 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1297 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1298 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1300 requiredAnchorLayoutAttr, chunkSize);
1301 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1304 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1306 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1307 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1309 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1310 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1313void LayoutInfoPropagation::visitLoadMatrixOp(
1314 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1315 ArrayRef<const LayoutInfoLattice *> results) {
1317 LayoutInfo resLayoutInfo = results[0]->getValue();
1318 if (!resLayoutInfo.isAssigned())
1321 auto consumerLayoutAttr =
1322 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1324 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1328 if (!hasParamsOfLayoutKind(anchorLayout)) {
1329 VectorType resVecTy =
1330 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1335 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1336 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1340void LayoutInfoPropagation::visitStoreMatrixOp(
1341 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1342 ArrayRef<const LayoutInfoLattice *> results) {
1343 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1345 if (hasParamsOfLayoutKind(anchorLayout)) {
1346 layout = LayoutInfo(anchorLayout);
1348 VectorType srcVecTy =
1349 llvm::cast<VectorType>(storeMatrix.getData().getType());
1353 auto requiredAnchorLayoutAttr =
1355 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1356 layout = LayoutInfo(requiredAnchorLayoutAttr);
1359 propagateIfChanged(operands[0], operands[0]->meet(layout));
1368class RunLayoutInfoPropagation {
1373 unsigned indexBitWidth)
1375 SymbolTableCollection symbolTable;
1377 solver.
load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
1381 LayoutInfo getLayoutInfo(Value val);
1383 void printAnalysisResult(llvm::raw_ostream &os);
1386 DataFlowSolver solver;
1391LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1392 auto *state = solver.
lookupState<LayoutInfoLattice>(val);
1395 return state->getValue();
1399void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1400 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1401 os <<
"function: " << funcOp.getName() <<
":\n";
1403 for (BlockArgument arg : funcOp.getArguments()) {
1404 LayoutInfo layout = getLayoutInfo(arg);
1405 os <<
"argument: " << arg <<
"\n";
1411 funcOp.walk([&](Operation *op) {
1417 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1423 for (
auto [i, r] : llvm::enumerate(op->
getResults())) {
1424 LayoutInfo layout = getLayoutInfo(r);
1425 os <<
"layout for result #" << i <<
": ";
1432 SmallVector<FunctionOpInterface> funcOps;
1433 if (
auto modOp = dyn_cast<ModuleOp>(
target)) {
1434 for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
1435 funcOps.push_back(funcOp);
1438 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1439 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1440 funcOps.push_back(gpuFuncOp);
1444 for (FunctionOpInterface funcOp : funcOps)
1445 printFunctionResult(funcOp);
1457static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1459 auto definingOp = tdescValue.
getDefiningOp<xegpu::CreateNdDescOp>();
1464 if (
auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1465 auto *parentOp = arg.getOwner()->getParentOp();
1466 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1467 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1469 return getDefiningCreateNdDescOp(tiedInit->
get());
1476struct ResolveLayoutConflicts {
1477 ResolveLayoutConflicts(Operation *parentOp)
1478 : parentOp(parentOp), builder(parentOp->
getContext()) {}
1479 LogicalResult run();
1482 Operation *parentOp;
1484 LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1485 LogicalResult resolveVectorConsumer(OpOperand &operand);
1486 LogicalResult assignResultLayout(OpResult &
result);
1491LogicalResult ResolveLayoutConflicts::run() {
1494 auto r = parentOp->
walk([&](Operation *op) -> WalkResult {
1498 if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
1500 if (
result.getType().isIntOrFloat()) {
1501 auto res = assignResultLayout(
result);
1503 DBGS() <<
"Failed to resolve vector consumer for multi-reduction "
1512 Type operandType = operand.get().getType();
1513 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1514 isa<xegpu::TensorDescType>(operandType)) {
1515 auto res = resolveTensorDescConsumer(operand);
1517 DBGS() <<
"Failed to resolve tensor descriptor consumer: " << *op
1523 if (isa<VectorType>(operandType)) {
1524 auto res = resolveVectorConsumer(operand);
1526 DBGS() <<
"Failed to resolve vector consumer: " << *op <<
"\n";
1535 DBGS() <<
"IR after resolving layout conflicts:\n";
1539 return r.wasInterrupted() ? failure() :
success();
1542LogicalResult ResolveLayoutConflicts::assignResultLayout(OpResult &
result) {
1543 Operation *producerOp =
result.getDefiningOp();
1547 auto convertOp = xegpu::ConvertLayoutOp::create(
1550 result.replaceAllUsesExcept(convertOp.getResult(), convertOp);
1555ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1556 Value vectorValue = operand.
get();
1557 Operation *consumerOp = operand.
getOwner();
1560 if (!producerLayout) {
1561 if (
auto vectorTy = dyn_cast<VectorType>(vectorValue.
getType());
1562 vectorTy && vectorTy.getRank() > 1)
1563 consumerOp->
emitWarning(
"Expected layout for non-1D vectors.");
1570 if (isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(
1575 if (!consumerLayout)
1577 "No consumer layout found for vector operand.");
1580 if (consumerLayout.isEqualTo(producerLayout))
1591 isa<OpResult>(vectorValue) &&
1594 Operation *
clone = builder.
clone(*producerOp);
1599 operand.
set(cloneResult);
1605 auto convertOp = xegpu::ConvertLayoutOp::create(
1606 builder, consumerOp->
getLoc(), vectorValue.
getType(), vectorValue,
1607 producerLayout, consumerLayout);
1610 operand.
set(convertOp.getResult());
1615ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1616 Operation *consumerOp = operand.
getOwner();
1617 Value tdescValue = operand.
get();
1618 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1619 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.
getType());
1620 assert(anchorOp && currTDescType &&
1621 "Expected anchor layout op and tensor descriptor consumer.");
1622 Attribute currLayout = currTDescType.getLayout();
1623 Attribute expectedLayout = anchorOp.getAnchorLayout();
1626 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1628 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1629 if (!conflictingCreateNdOp) {
1630 DBGS() <<
"Unable to find defining CreateNdDescOp for tensor descriptor: "
1631 << tdescValue <<
"\n";
1636 auto newTensorDescType = xegpu::TensorDescType::get(
1637 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1638 currTDescType.getElementType(), currTDescType.getEncoding(),
1640 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1641 builder, consumerOp->
getLoc(), newTensorDescType,
1642 conflictingCreateNdOp->getOperands(),
1643 conflictingCreateNdOp->getAttrs());
1661 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1668 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1671 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(
result);
1672 if (!layout &&
result.getNumUses() > 0) {
1673 op->
emitWarning(
"op has users but no layout assigned for its result");
1678 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1679 auto typeWithLayout = xegpu::TensorDescType::get(
1680 tensorDescTy.getContext(), tensorDescTy.getShape(),
1681 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1682 result.setType(typeWithLayout);
1716 mlir::RegionBranchTerminatorOpInterface terminator,
1719 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1724 branchOp.getSuccessorOperandInputMapping(mapping,
1726 for (
const auto &[successorOperand, successorInputs] : mapping) {
1727 for (
Value successorInput : successorInputs) {
1728 Type inputType = successorInput.getType();
1730 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1732 xegpu::DistributeLayoutAttr successorOperandLayout =
1733 getLayoutOfValue(successorOperand->get());
1736 if (!successorOperandLayout) {
1737 LLVM_DEBUG(
DBGS() <<
"No layout assigned for forwarded operand in "
1738 "branch terminator: "
1739 << successorOperand->get() <<
"\n");
1743 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1744 auto newTdescTy = xegpu::TensorDescType::get(
1745 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1746 tdescTy.getEncoding(), successorOperandLayout);
1747 successorInput.setType(newTdescTy);
1752 if (
auto result = dyn_cast<OpResult>(successorInput))
1761 mlir::FunctionOpInterface funcOp,
1767 if (!isa<FunctionType>(funcOp.getFunctionType()))
1772 Type argType = arg.getType();
1773 newArgTypes.push_back(argType);
1774 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1776 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1778 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
1779 <<
" but got none.\n");
1782 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1783 auto newTdescTy = xegpu::TensorDescType::get(
1784 tensorDescTy.getContext(), tensorDescTy.getShape(),
1785 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1786 arg.setType(newTdescTy);
1787 newArgTypes.back() = newTdescTy;
1792 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1793 funcOp.getResultTypes()));
1798struct XeGPUPropagateLayoutPass final
1799 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1800 XeGPUPropagateLayoutPass() =
default;
1801 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
1802 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
1803 : XeGPUPropagateLayoutBase(std::move(
options)) {}
1804 void runOnOperation()
override;
1811 unsigned indexBitWidth,
bool printOnly) {
1812 RunLayoutInfoPropagation analysis(
target, layoutKind, indexBitWidth);
1815 auto &os = llvm::outs();
1816 analysis.printAnalysisResult(os);
1820 auto getLayoutFromPropagation =
1821 [&](
Value val) -> xegpu::DistributeLayoutAttr {
1822 LayoutInfo layout = analysis.getLayoutInfo(val);
1823 if (
auto opResult = dyn_cast<OpResult>(val)) {
1824 Operation *defOp = opResult.getDefiningOp();
1825 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1826 auto anchorLayout = anchorOp.getAnchorLayout();
1827 if (anchorLayout !=
nullptr)
1828 return anchorLayout;
1830 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1832 if (requiredResLayoutAttr !=
nullptr)
1833 return requiredResLayoutAttr;
1835 if (!layout.isAssigned())
1837 xegpu::DistributeLayoutAttr layoutAttr =
1838 cast<xegpu::DistributeLayoutAttr>(layout.get());
1839 if (layout.isSliceLayout())
1840 return cast<xegpu::SliceAttr>(layoutAttr);
1842 return cast<xegpu::LayoutAttr>(layoutAttr);
1850 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1852 getLayoutFromPropagation);
1854 .Case([&](mlir::RegionBranchOpInterface branchOp) {
1856 getLayoutFromPropagation);
1858 .Case([&](mlir::FunctionOpInterface funcOp) {
1860 getLayoutFromPropagation);
1863 r =
updateOp(builder, op, getLayoutFromPropagation);
1866 op.
emitError(
"Failed to update operation with the layout.");
1872 if (walkResult.wasInterrupted())
1879 ResolveLayoutConflicts resolver(
target);
1880 return resolver.run();
1883void XeGPUPropagateLayoutPass::runOnOperation() {
1888 if (this->layoutKind ==
"lane") {
1890 }
else if (this->layoutKind ==
"inst") {
1892 }
else if (this->layoutKind ==
"subgroup") {
1893 layoutKind = xegpu::LayoutKind::Subgroup;
1895 getOperation()->emitError(
"Unsupported layout kind option: " +
1897 signalPassFailure();
1902 this->indexBitWidth, this->printOnly))) {
1903 signalPassFailure();
1908 signalPassFailure();
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
LogicalResult initializeAndRun(Operation *top, llvm::function_ref< bool(DataFlowAnalysis &)> analysisFilter=nullptr)
Initialize analyses starting from the provided top-level operation and run the analysis until fixpoint.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by implementing transfer functions for operations.
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run properly.
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
LogicalResult propagateRegionArgsToInits(RegionBranchOpInterface regionOp, GetLayoutFnTy getLayoutOfValue)
Propagate layouts from a region branch op's region entry block arguments back to its init operands.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from being chosen when not desirable.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
llvm::function_ref< Fn > function_ref
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const