#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;
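/// Overview: this pass implements XeGPU layout propagation as a sparse
/// backward dataflow analysis. Anchor ops (xegpu.dpas and the block /
/// scatter IO ops) impose layout requirements on their operands, and the
/// analysis propagates those requirements backward to producer ops. A layout
/// is carried as an xegpu::DistributeLayoutAttr, for example (illustrative):
///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
/// distributes the innermost dimension of a 2D value across 16 lanes.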
/// The kind of layout parameters being propagated, selected by the pass
/// option.
enum class LayoutKind { Lane, InstData, Subgroup };

/// LayoutInfo wraps a DistributeLayoutAttr and provides the lattice
/// interface (meet/join) expected by the sparse dataflow framework.
struct LayoutInfo {
  xegpu::DistributeLayoutAttr storage = nullptr;

  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  bool isAssigned() const { return storage != nullptr; }

  xegpu::DistributeLayoutAttr get() const { return storage; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const { return storage.getRank(); }

  SmallVector<int> getLaneLayout() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getSgLayout() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getSgData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getOrder() const {
    if (!isAssigned() || !storage.getOrder())
      return {};
    return llvm::map_to_vector(storage.getOrder().asArrayRef(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  void print(llvm::raw_ostream &os) const {
    if (isAssigned())
      os << storage;
    else
      os << "Not assigned.";
  }
};
/// When two lattice values are merged, keep lhs if it is already assigned,
/// otherwise take rhs.
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
/// Since this is a backward analysis, join is never expected to be invoked.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
/// Construct a new layout with all parameters permuted according to
/// `permutation`.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  // Validate the permutation: every index must appear exactly once and lie
  // within [0, rank).
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return LayoutInfo();
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> sgLayout;
  SmallVector<int32_t> sgData;
  SmallVector<int32_t> order;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
    if (getSgData().size()) {
      sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
      sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
    }
    if (getOrder().size()) {
      order.push_back(static_cast<int32_t>(getOrder()[idx]));
    }
  }
  auto orderAttr = order.size()
                       ? DenseI32ArrayAttr::get(storage.getContext(), order)
                       : DenseI32ArrayAttr();
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  if (getSgData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), sgLayout, sgData,
                                        orderAttr);
  return LayoutInfo(layoutAttr);
}
/// Lattice holding the LayoutInfo of each SSA value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
/// Helper to get the default layout for a 1D or 2D value: the innermost
/// dimension is distributed across the subgroup lanes, one element per lane.
static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}

/// Overload taking the target uArch; forwards its subgroup size.
static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           const xegpu::uArch::uArch *uArch) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  return getDefaultSIMTLayoutInfo(ctx, rank, uArch->getSubgroupSize());
}
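// For example, with a subgroup size of 16 these defaults are:
//   rank 1: #xegpu.layout<lane_layout = [16],    lane_data = [1]>
//   rank 2: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
// i.e. the innermost dimension is split across the 16 lanes and each lane
// owns a single element.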
/// Helper to get the SIMT layout for block IO (2D block load/store/prefetch)
/// operations.
template <typename Ty>
static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
                                               const xegpu::uArch::uArch *uArch,
                                               unsigned packingSize) {
  // Expecting a 1D or 2D vector.
  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(ty.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, use the default 1D layout.
  if (ty.getRank() == 1)
    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
  // Otherwise distribute the innermost dimension across lanes; each lane
  // owns packingSize/bitwidth consecutive elements when the element type is
  // narrower than the packing size.
  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  return LayoutInfo(xegpu::LayoutAttr::get(
      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
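// Worked example: for an f16 payload (bitwidth 16) and packingSize = 32,
// packingFactor = 32 / 16 = 2, so with subgroup size 16 the layout is
// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>: each lane owns
// two consecutive f16 values that fill one packed 32-bit unit.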
/// Helper to get the SIMT layout for scattered (gather/scatter) IO
/// operations, where the outermost dimension is distributed across lanes.
static LayoutInfo
getSIMTLayoutInforForScatterIO(VectorType vectorTy,
                               const xegpu::uArch::uArch *uArch,
                               unsigned packingSize) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, use the default 1D layout.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);

  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {uArch->getSubgroupSize(), 1},
                                           {1, packingFactor}));
}
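// Note the transposed lane_layout relative to block IO: scatter IO assigns
// one row (one address) per lane, so with subgroup size 16 the result is
// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, packingFactor]>.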
/// Helper to get the SIMT layout for a DPAS operand. `operandNum` is 0 for
/// A, 1 for B, and 2 for C/D.
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  // The B operand packs elements narrower than the packing size along the K
  // dimension: lane_layout = [1, subgroupSize], lane_data =
  // [packingSize / bitwidth, 1].
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  // Otherwise fall back to the regular block IO layout.
  return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
}
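// Example (assuming the packing rule sketched above): an i8 B operand with
// packingSize = 32 would get lane_data = [32 / 8, 1] = [4, 1], i.e. each
// lane holds four K-consecutive i8 values packed into one 32-bit unit.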
/// Backward dataflow analysis that propagates layout information (lane,
/// inst or subgroup parameters, depending on `layoutKind`) from
/// layout-anchoring ops to the producers of their vector and tensor
/// descriptor operands.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  LayoutKind layoutKind;

  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);

  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
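// The dataflow framework invokes visitOperation whenever the lattice of one
// of an op's results changes; each visit*Op handler then meets the operand
// lattices with the required layouts, moving information from uses toward
// definitions until a fixed point is reached.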
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // All other ops: propagate each assigned result layout to every vector
      // or tensor descriptor operand.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // Layouts only apply to vector and tensor descriptor operands.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
/// Returns true if `anchorLayout` carries the parameters that matter for the
/// currently propagated layout kind.
bool LayoutInfoPropagation::hasParamsOfLayoutKind(
    xegpu::DistributeLayoutAttr anchorLayout) {
  if (anchorLayout == nullptr) {
    return false;
  }
  if (layoutKind == LayoutKind::InstData) {
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
  } else if (layoutKind == LayoutKind::Lane) {
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
  } else if (layoutKind == LayoutKind::Subgroup) {
    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
             anchorLayout.getEffectiveSgDataAsInt().empty());
  }
  return false;
}
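// E.g. when propagating LayoutKind::Lane, an anchor attribute only counts
// if it defines both lane_layout and lane_data; an attribute carrying only
// inst_data is ignored and the default layout logic below runs instead.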
/// Set the layout for the tensor descriptor operand of PrefetchNdOp.
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo prefetchLayout;
  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    prefetchLayout = LayoutInfo(anchorLayout);
  } else {
    // Derive the layout from the tensor descriptor type and the uArch block
    // parameters.
    auto tdescTy = prefetch.getTensorDescType();
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(prefetch).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));

    auto blockWHC =
        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
    if (!blockWHC)
      prefetch.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (tdescTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        prefetch.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == LayoutKind::InstData)
      prefetchLayout =
          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
    else
      prefetchLayout = getSIMTLayoutInforForBlockIO(
          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());

    prefetch.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
  }

  // Propagate the layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
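// Worked example (illustrative numbers): for a tensor_desc<16x32xf16> whose
// uArch block params allow widths {8, 16, 32} and heights {1, ..., 16},
// getLargestDivisor picks instWidth = 32 and instHeight = 16, so the
// propagated inst_data is [16, 32].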
/// Set the layouts for the operands of vector::MultiDimReductionOp.
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only 2D -> 1D reductions are supported at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(reduction).value_or(""));
  // Given a 1D result, the source operand gets the default 2D layout.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
      reduction->getContext(), 2, uArch->getSubgroupSize());
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
/// Set the layout for the source operand of vector::BroadcastOp.
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  // Only vector sources can carry a layout.
  if (!sourceTy)
    return;
  // If the ranks differ, the source layout is the result layout sliced over
  // the broadcast dimensions.
  if (sourceTy.getRank() != resultTy.getRank()) {
    auto sourceDims = sourceTy.getShape();
    auto resultDims = resultTy.getShape();
    SmallVector<int64_t> bcastDims;
    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
    // New leading dims added by the broadcast are broadcast dims.
    for (int i = 0; i < dimDiff; i++)
      bcastDims.push_back(i);
    // Unit dims that are stretched by the broadcast are broadcast dims too.
    for (size_t i = 0; i < sourceDims.size(); i++)
      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
        bcastDims.push_back(i + dimDiff);
    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
        broadcast->getContext(),
        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
    return;
  }
  // Same-rank broadcasts simply reuse the result layout.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layout for the source operand of vector::ShapeCastOp.
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Only 1D -> 2D shape casts (unit-dim insertion) are supported for now.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  // The source layout is the result layout with the unit dim sliced away.
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
/// Propagate the layout of the result tensor descriptor to the source
/// operand of UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layouts for the A, B and C/D operands of DpasOp.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();

    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(dpas).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

    const unsigned dataALen = aTy.getShape().front();
    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
    const int maxALen = xegpu::getLargestDivisor(
        dataALen, ArrayRef<unsigned>(supportedALen));
    if (maxALen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");

    const unsigned dataBLen = bTy.getShape().back();
    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
    const int maxBLen = xegpu::getLargestDivisor(
        dataBLen, ArrayRef<unsigned>(supportedBLen));
    if (maxBLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataA = {maxALen, subgroupSize};
    SmallVector<int> instDataB = {subgroupSize, maxBLen};

    if (layoutKind == LayoutKind::InstData) {
      dpasALayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
      dpasBLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
    } else {
      dpasALayout = getSIMTLayoutInfoForDPASOperand(
          aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
      dpasBLayout = getSIMTLayoutInfoForDPASOperand(
          bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
    }

    if (operands.size() > 2) {
      VectorType cTy = dpas.getAccType();
      if (layoutKind == LayoutKind::InstData) {
        const unsigned dataCLen = bTy.getShape().back();
        auto supportedCLen =
            uArchInstruction->getSupportedN(bTy.getElementType());
        const int maxCLen = xegpu::getLargestDivisor(
            dataCLen, ArrayRef<unsigned>(supportedCLen));
        if (maxCLen == -1)
          dpas.emitWarning(
              "No suitable instruction multiple found for the given shape.");
        SmallVector<int> instDataC = {maxALen, maxCLen};
        dpasCDLayout =
            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
      } else {
        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
      }
      dpas.setLayoutCdAttr(
          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
    }
    dpas.setLayoutAAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
    dpas.setLayoutBAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
  }

  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2) {
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
  }
}
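// Example (illustrative numbers): for a 32x16 A tile with supported M values
// {1, 2, 4, 8} and subgroup size 16, maxALen = 8 and A's inst_data becomes
// [8, 16]; B analogously gets inst_data = [16, maxBLen].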
/// Set the layouts for the value and tensor descriptor operands of
/// StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(store).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC)
      store.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else
      storeLayout = getSIMTLayoutInforForBlockIO(
          store.getValueType(), uArch,
          uArchInstruction->getPackedFormatBitSize());
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }

  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
/// Propagate the layout of the loaded value to the tensor descriptor operand
/// of LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {
    LayoutInfo valueLayout = results[0]->getValue();
    // Need the layout of the value to propagate to the tensor descriptor.
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    // LoadNdOp can carry a transpose effect; it is not expected at this
    // stage, so transpose the layout to compensate.
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
/// For vector::TransposeOp, propagate the permuted layout of the result to
/// the source operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the permuted layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
/// Set the layout for the source operand of vector::BitCastOp.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bit widths match, the layout is unchanged.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  // Check that the result type can be evenly distributed with its layout.
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  // A bitcast either narrows or widens the element type.
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  // The lane layout does not change; only the innermost lane data scales.
  SmallVector<int> sourceLaneLayout = resultLayout.getLaneLayout();
  SmallVector<int> outData = resultLayout.getLaneData();

  // For narrowing bitcasts, each lane's innermost data must cover at least
  // one full input element; otherwise lanes would need to exchange data.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // In every dimension except the innermost, each lane must own exactly one
  // element.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than the innermost dimension.");
    return;
  }
  // Scale the innermost lane data by the bitcast ratio.
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
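// Worked example: for a widening bitcast vector<8x16xi16> -> vector<8x8xi32>
// with result lane_data = [1, 1]: bitCastRatio = 32 / 16 = 2 and the source
// gets lane_data = [1, 2] with the same lane_layout, so each lane keeps the
// two i16 halves of the single i32 it owns.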
/// Propagate the layouts to the source, mask and offset operands of
/// LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
    maskLayout = loadLayout;
  } else {
    VectorType payloadTy = dyn_cast<VectorType>(load.getValueType());
    if (!payloadTy) {
      load.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(load).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    SmallVector<int> instData{subgroupSize};
    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
      instData.push_back(chunkSize);
    else if (auto srcTdescTy =
                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
      if (srcTdescTy.getChunkSizeAsInt() > 1)
        instData.push_back(srcTdescTy.getChunkSizeAsInt());
    }
    if (layoutKind == LayoutKind::InstData)
      loadLayout =
          LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
    else
      loadLayout = getSIMTLayoutInforForScatterIO(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
    // The mask operand gets the default 1D layout.
    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the layout to the source operand if it is a tensor descriptor.
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
  // Propagate the mask layout to the mask and, if present, offset operands.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
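// Example: with subgroup size 16 and chunk_size = 4, inst_data = [16, 4]:
// one instruction services sixteen addresses and each lane moves a
// contiguous 4-element chunk.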
/// Propagate the layout to the offset operand of CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(createDesc).value_or(""));
  // The offset operand gets the default 1D layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch->getSubgroupSize());
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
/// Set the layouts for the value, destination, mask and offset operands of
/// StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo payloadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    payloadLayout = LayoutInfo(anchorLayout);
    maskLayout = payloadLayout;
  } else {
    VectorType payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
    if (!payloadTy) {
      storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(storeScatter).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();

    if (layoutKind == LayoutKind::InstData) {
      SmallVector<int> instData{subgroupSize};
      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
          chunkSize > 1)
        instData.push_back(chunkSize);
      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
                   storeScatter.getDestType())) {
        if (dstTdescTy.getChunkSizeAsInt() > 1)
          instData.push_back(dstTdescTy.getChunkSizeAsInt());
      }
      payloadLayout = LayoutInfo(
          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
    } else {
      auto payloadShape = payloadTy.getShape();
      if (payloadShape.size() > 1)
        assert(payloadShape[0] == subgroupSize &&
               "Expected the first dimension of 2D tensor descriptor to be "
               "equal to subgroup size.");
      payloadLayout = getSIMTLayoutInforForScatterIO(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
    }
    maskLayout =
        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
    storeScatter.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
  }
  // Propagate the payload operand layout.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the destination operand layout if it is a tensor descriptor.
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // Propagate the mask layout to the mask and, if present, offset operands.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
/// Driver for running LayoutInfoPropagation on an operation and querying or
/// printing the resulting layout assignments.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind)
      : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};
LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the layout of each function argument.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Print the layout of each op result.
    funcOp.walk([&](Operation *op) {
      // Skip ops without results and control-flow ops.
      if (op->getResults().empty())
        return;
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      os << "op    : ";
      op->print(os);
      os << "\n";
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    // Collect all function ops inside nested GPU modules.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

/// Update an operation with the layout of its results. Tensor descriptor
/// results get the layout folded into their type; vector results get it as
/// an attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are handled by the branch terminator path.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor results.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has users but no layout, warn and move on.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // For tensor descriptor results, rebuild the type with the layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // For vector results, attach the layout attribute to the result.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// their types must be updated to match the layouts of the operands that are
/// forwarded into them.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
  if (!branchOp)
    return success();
  // Collect the mapping from forwarded operands to successor inputs.
  RegionBranchSuccessorMapping mapping;
  branchOp.getSuccessorOperandInputMapping(mapping, /*...*/);
  for (const auto &[successorOperand, successorInputs] : mapping) {
    for (Value successorInput : successorInputs) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand->get());

      // If the forwarded operand has no layout, nothing can be done.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand->get() << "\n");
        continue;
      }
      // Detect conflicting layouts between the region argument and the
      // operand forwarded as that argument.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        continue;
      }
      // For tensor descriptor inputs, rebuild the type with the layout.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // For vector inputs that are op results, attach the layout attribute.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
/// Update the function arguments and results with the layouts.
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      continue;
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
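// After these updates, a tensor descriptor value changes, for example, from
//   !xegpu.tensor_desc<8x16xf16>
// to (illustrative)
//   !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16],
//                                              lane_data = [1, 1]>>
// while plain vector values carry their layout as an attribute on the
// defining op instead of in the type.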
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
void XeGPUPropagateLayoutPass::runOnOperation() {
  // Map the pass option string to a LayoutKind.
  LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = LayoutKind::InstData;
  } else if (this->layoutKind == "subgroup") {
    layoutKind = LayoutKind::Subgroup;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
  // Print the analysis result and exit (for testing purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert the analysis result to an xegpu::DistributeLayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  // Visit ops in reverse order, matching the backward analysis.
  auto walkResult = op->walk([&](Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}