#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;
/// Lattice value for layout propagation: the distribute layout (if any)
/// currently assigned to an MLIR value.
struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if both are assigned (or both unassigned);
  // the layout contents do not take part in the comparison.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  bool isAssigned() const { return storage != nullptr; }

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    assert(isAssigned() && "Layout is not assigned.");
    return storage.getRank();
  }

  xegpu::DistributeLayoutAttr get() const { return storage; }

  SmallVector<int> getLaneLayout() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(
        storage.getEffectiveLaneLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(
        storage.getEffectiveLaneDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(
        storage.getEffectiveInstDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getSgLayout() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(
        storage.getEffectiveSgLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getSgData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(
        storage.getEffectiveSgDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getOrder() const {
    if (!isAssigned() || !storage.getOrder())
      return {};
    return llvm::map_to_vector(
        storage.getOrder().asArrayRef(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  void print(raw_ostream &os) const {
    if (isAssigned())
      os << storage;
    else
      os << "Not assigned.";
  }
};
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, join is never expected to be invoked.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
/// Construct a new layout with the same parameter kinds, permuted according
/// to `permutation`.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return LayoutInfo();
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> sgLayout;
  SmallVector<int32_t> sgData;
  SmallVector<int32_t> order;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
    if (getSgData().size()) {
      sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
      sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
    }
    if (getOrder().size()) {
      order.push_back(static_cast<int32_t>(getOrder()[idx]));
    }
  }
  auto orderAttr = order.size()
                       ? DenseI32ArrayAttr::get(storage.getContext(), order)
                       : DenseI32ArrayAttr();
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  if (getSgData().size())
    // NOTE: the exact builder overload here is a reconstruction.
    layoutAttr = xegpu::LayoutAttr::get(
        storage.getContext(),
        DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
        DenseI32ArrayAttr::get(storage.getContext(), sgData),
        /*inst_data=*/nullptr, /*lane_layout=*/nullptr,
        /*lane_data=*/nullptr, orderAttr);
  return LayoutInfo(layoutAttr);
}
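// Worked example for transpose (illustrative): permuting a row layout into a
// column layout; `ctx` is an assumed MLIRContext.
//
//   LayoutInfo row(xegpu::LayoutAttr::get(ctx, {1, 16}, {1, 1}));
//   LayoutInfo col = row.transpose({1, 0});
//   col.getLaneLayout(); // {16, 1}
//   col.getLaneData();   // {1, 1}
//
// sg_layout/sg_data, inst_data, and order entries, when present, are permuted
// the same way.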
/// Lattice wrapper over LayoutInfo for the sparse dataflow framework.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
/// Helper to get the default layout for uniform values (e.g., masks) of a
/// given rank, distributed over the subgroup lanes.
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank, int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
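// Example (illustrative): with subgroupSize = 16,
//   getDefaultSIMTLayoutInfo(ctx, /*rank=*/1, 16) // lane_layout = [16],
//                                                 // lane_data   = [1]
//   getDefaultSIMTLayoutInfo(ctx, /*rank=*/2, 16) // lane_layout = [1, 16],
//                                                 // lane_data   = [1, 1]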
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1,
                                    uArch->getSubgroupSize());
  // The packing factor is determined by the element type bitwidth.
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                             {uArch->getSubgroupSize(), 1},
                                             {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, uArch->getSubgroupSize()},
                                           {1, packingFactor}));
}
/// Helper to get the default layout for a TensorDesc type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D TensorDesc.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1,
                                    uArch->getSubgroupSize());
  // The packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  int subgroupSize = uArch->getSubgroupSize();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
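// Worked example of the packing factor (illustrative; packingSize = 32 and
// subgroupSize = 16 are assumed values): for an f16 TensorDesc (bitwidth 16),
// packingFactor = 32 / 16 = 2, so the block (non-scattered) case yields
// lane_layout = [1, 16], lane_data = [1, 2]; for f32 the factor is 1 and
// lane_data = [1, 1]. The scattered case distributes the outer dimension
// instead: lane_layout = [16, 1].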
/// Helper to get the expected layout for a DPAS operand. `operandNum` is the
/// DPAS operand number (0 -> A, 1 -> B, 2 -> C/D).
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  // The B operand must be packed to `packingSize` bits: columns are
  // distributed across lanes and consecutive rows are packed per lane.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
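// Example (illustrative, based on the reconstruction above): a bf16 B operand
// (operandNum == 1, bitwidth 16) with packingSize = 32 and subgroupSize = 16
// gets lane_layout = [1, 16], lane_data = [2, 1], i.e. two consecutive rows
// packed per lane; A and C/D operands fall through to the default layout.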
/// The kind of layout parameters the pass propagates.
enum class LayoutKind { Lane, InstData, Subgroup };

/// Backward dataflow analysis that propagates layout information from anchor
/// operations (e.g., dpas, store) to the producers of their operands.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  LayoutKind layoutKind;

  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};
  void visitCallOperand(OpOperand &operand) override {};
  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // All other ops: propagate each result's layout to any vector or
      // tensor descriptor operand.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // Propagate only to vector and tensor descriptor operands.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
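// Example of the default rule above (illustrative): given
//   %r = arith.addf %a, %b : vector<8x16xf32>
// an assigned layout on %r is met into the lattices of both %a and %b, since
// elementwise ops preserve layouts; scalar (non-vector, non-tensor-desc)
// operands are skipped.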
bool LayoutInfoPropagation::hasParamsOfLayoutKind(
    xegpu::DistributeLayoutAttr anchorLayout) {
  if (anchorLayout == nullptr) {
    return false;
  }
  if (layoutKind == LayoutKind::InstData) {
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
  } else if (layoutKind == LayoutKind::Lane) {
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
  } else if (layoutKind == LayoutKind::Subgroup) {
    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
             anchorLayout.getEffectiveSgDataAsInt().empty());
  }
  return false;
}
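// Example (illustrative): with layoutKind == LayoutKind::Lane, an anchor such
// as #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> satisfies the
// check, while #xegpu.layout<inst_data = [8, 16]> does not, so the visitor
// would compute a lane-level layout instead of reusing the anchor.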
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo prefetchLayout;
  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    prefetchLayout = LayoutInfo(anchorLayout);
  } else {
    auto tdescTy = prefetch.getTensorDescType();
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(prefetch).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));

    auto blockWHC =
        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
    if (!blockWHC) {
      prefetch.emitWarning("No known block params found for the element type.");
      return;
    }
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (tdescTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        prefetch.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == LayoutKind::InstData)
      prefetchLayout =
          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
    else
      prefetchLayout = getDefaultSIMTLayoutInfo(
          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());

    prefetch.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
  }
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(reduction).value_or(""));
  // The source operand gets the default 2D layout.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
      reduction->getContext(), 2, uArch->getSubgroupSize());
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator operand gets the layout of the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy)
    return;

  // If the ranks differ, the source layout is the result layout sliced along
  // the broadcasted dimensions.
  if (sourceTy.getRank() != resultTy.getRank()) {
    auto sourceDims = sourceTy.getShape();
    auto resultDims = resultTy.getShape();
    SmallVector<int64_t> bcastDims;
    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
    // Leading dimensions added by the broadcast.
    for (int i = 0; i < dimDiff; i++)
      bcastDims.push_back(i);
    // Unit dimensions stretched by the broadcast.
    for (size_t i = 0; i < sourceDims.size(); i++)
      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
        bcastDims.push_back(i + dimDiff);

    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
        broadcast->getContext(),
        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
    return;
  }
  // Same-rank broadcasts keep the result layout on the source.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
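// Example (illustrative): broadcasting vector<16xf16> to vector<8x16xf16>
// gives dimDiff = 1 and bcastDims = {0}, so the source operand receives the
// result layout sliced along dim 0, e.g.
//   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
//                dims = [0]>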
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Only 1D -> 2D shape casts (adding a unit dimension) are supported for now.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
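// Example (illustrative): for vector<16xf16> -> vector<1x16xf16> the unit
// dimension is dim 0, so slicedDim = 0 and the 1D source gets the 2D result
// layout sliced along dim 0; for vector<16xf16> -> vector<16x1xf16>,
// slicedDim = 1.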
/// Propagate the layout of the result tensor descriptor to the source of
/// UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layouts for DPAS A, B, and C/D operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();

    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(dpas).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

    const unsigned dataALen = aTy.getShape().front();
    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
    const int maxALen =
        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
    if (maxALen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");

    const unsigned dataBLen = bTy.getShape().back();
    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
    const int maxBLen =
        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
    if (maxBLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataA = {maxALen, subgroupSize};
    SmallVector<int> instDataB = {subgroupSize, maxBLen};

    if (layoutKind == LayoutKind::InstData) {
      dpasALayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
      dpasBLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
    } else {
      dpasALayout = getSIMTLayoutInfoForDPASOperand(
          aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
      dpasBLayout = getSIMTLayoutInfoForDPASOperand(
          bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
    }

    if (operands.size() > 2) {
      VectorType cTy = dpas.getAccType();
      if (layoutKind == LayoutKind::InstData) {
        const unsigned dataCLen = bTy.getShape().back();
        auto supportedCLen =
            uArchInstruction->getSupportedN(bTy.getElementType());
        const int maxCLen = xegpu::getLargestDivisor(
            dataCLen, ArrayRef<unsigned>(supportedCLen));
        if (maxCLen == -1)
          dpas.emitWarning(
              "No suitable instruction multiple found for the given shape.");
        SmallVector<int> instDataC = {maxALen, maxCLen};
        dpasCDLayout =
            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
      } else {
        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
      }
      dpas.setLayoutCdAttr(
          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
    }
    dpas.setLayoutAAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
    dpas.setLayoutBAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
  }

  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2) {
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
  }
}
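// Example (illustrative; the supported M/N multiples are assumptions about
// the target): for A = vector<8x16xbf16>, B = vector<16x16xbf16>, and
// subgroupSize = 16, LayoutKind::InstData would yield inst_data = [8, 16]
// for A and [16, 16] for B, while lane propagation defers to
// getSIMTLayoutInfoForDPASOperand (which packs B rows per lane).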
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(store).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC) {
      store.emitWarning("No known block params found for the element type.");
      return;
    }
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else
      storeLayout =
          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
                                   uArchInstruction->getPackedFormatBitSize());
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }
  // Both operands (value and tensor descriptor) get the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
/// Propagate the layout of the value to the tensor descriptor operand of
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {
    LayoutInfo valueLayout = results[0]->getValue();
    // Need the layout of the value to propagate to the tensor descriptor.
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    // LoadNdOp has the transpose effect, which is not expected at this stage;
    // compensate by transposing the propagated layout.
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
/// For vector::TransposeOp, propagate the layout of the result to the source
/// with the permutation applied.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
/// For vector::BitCastOp, propagate the layout of the result to the source,
/// adjusting the innermost lane data for the bitwidth change.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bitwidths are equal, the layout carries over unchanged.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  // The result must be evenly distributable across lanes with this layout.
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout = resultLayout.getLaneLayout();
  SmallVector<int> outData = resultLayout.getLaneData();

  // A narrowing bitcast must not require data exchange across lanes: each
  // lane must own at least one full source element's worth of bits.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // Lane data in all non-innermost dimensions must be 1.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than the innermost dimension.");
    return;
  }
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
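// Worked example (illustrative): for vector<8x16xi32> -> vector<8x32xf16>,
// inElemTyBitWidth = 32 and outElemTyBitWidth = 16, so isNarrowing is true
// and bitCastRatio = 2. With a result layout of lane_layout = [1, 16],
// lane_data = [1, 2], outInnerBitsPerLane = 2 * 16 = 32 >= 32, so no
// cross-lane exchange is needed and the source lane_data becomes
// [1, 2 / 2] = [1, 1].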
/// Propagate the layout of the result to the tensor descriptor, mask, and
/// offset operands of LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
    maskLayout = loadLayout;
  } else {
    // The layout is strictly determined by the payload type.
    VectorType payloadTy = dyn_cast<VectorType>(load.getValueType());
    if (!payloadTy) {
      load.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(load).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    SmallVector<int> instData{subgroupSize};
    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
      instData.push_back(chunkSize);
    else if (auto srcTdescTy =
                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
      if (srcTdescTy.getChunkSizeAsInt() > 1)
        instData.push_back(static_cast<int>(srcTdescTy.getChunkSizeAsInt()));
    }
    if (layoutKind == LayoutKind::InstData)
      loadLayout =
          LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
    else
      loadLayout = getDefaultSIMTLayoutInfo(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
          /*isScattered=*/true);

    // The mask operand gets the default 1D layout.
    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the payload layout to the tensor descriptor (if any).
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
  // Propagate the mask layout to the mask and (if present) offset operands.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
/// For CreateDescOp, propagate the default 1D layout to the offset operand;
/// the descriptor result layout must already be available.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(createDesc).value_or(""));
  // The offset operand gets the default 1D layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch->getSubgroupSize());
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
/// Propagate the layout of the value to the tensor descriptor, mask, and
/// offset operands of StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo payloadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    payloadLayout = LayoutInfo(anchorLayout);
    maskLayout = payloadLayout;
  } else {
    VectorType payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
    if (!payloadTy) {
      storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(storeScatter).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();

    if (layoutKind == LayoutKind::InstData) {
      SmallVector<int> instData{subgroupSize};
      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
          chunkSize > 1)
        instData.push_back(chunkSize);
      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
                   storeScatter.getDestType())) {
        if (dstTdescTy.getChunkSizeAsInt() > 1)
          instData.push_back(static_cast<int>(dstTdescTy.getChunkSizeAsInt()));
      }
      payloadLayout = LayoutInfo(
          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
    } else {
      auto payloadShape = payloadTy.getShape();
      if (payloadShape.size() > 1)
        assert(payloadShape[0] == subgroupSize &&
               "Expected the first dimension of 2D tensor descriptor to be "
               "equal to subgroup size.");
      payloadLayout = getDefaultSIMTLayoutInfo(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
          /*isScattered=*/true);
    }
    maskLayout =
        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
    storeScatter.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
  }
  // Propagate the payload layout to the payload operand.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the payload layout to the destination tensor descriptor (if any).
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // Propagate the mask layout to the mask and (if present) offset operands.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
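// Example (illustrative): storing a vector<16x8xf32> payload with
// chunkSize = 8 and subgroupSize = 16 under LayoutKind::InstData gives
// instData = {16} -> {16, 8}, i.e. inst_data = [16, 8]; under lane
// propagation the scattered default applies instead, with
// lane_layout = [16, 1].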
/// Driver for running LayoutInfoPropagation on an operation.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the analysis result for the function arguments.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Print the analysis result for each op result, skipping control-flow ops.
    funcOp.walk([&](Operation *op) {
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      if (op->getResults().empty())
        return;
      os << "op    : ";
      op->print(os);
      os << "\n";
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    // Collect all GPU kernels inside the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
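// Sample of the output produced above (illustrative values):
//
//   function: test_kernel:
//   argument: <block argument> of type 'memref<128x128xf16>' at index: 0
//   layout  : Not assigned.
//   op    : %0 = xegpu.create_nd_tdesc ...
//   layout for result #0: #xegpu.layout<lane_layout = [1, 16],
//                                       lane_data = [1, 1]>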
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

/// Update an operation with the layouts of its results.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are handled by updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor results.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // For tensor descriptor results, embed the layout in the type itself.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // For vector results, attach the layout as an attribute.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
/// Region ops like scf.for need special handling because they have blocks
/// inside; block arguments of tensor descriptor type must be updated to match
/// the layouts of the operands forwarded to them.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
  if (!branchOp)
    return success();
  // Map each successor operand of the terminator to the successor inputs
  // (e.g., iter_args or op results) it is forwarded to.
  RegionBranchSuccessorMapping mapping;
  // NOTE: the second argument to this interface call is a reconstruction.
  branchOp.getSuccessorOperandInputMapping(mapping, terminator);
  for (const auto &[successorOperand, successorInputs] : mapping) {
    for (Value successorInput : successorInputs) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand->get());

      // If the forwarded operand has no layout, we cannot proceed.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand->get() << "\n");
        return failure();
      }

      // The layouts of the region argument and the forwarded operand are
      // expected to match.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }

      // For tensor descriptors, embed the layout in the type.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }

      // For vector-typed op results, attach the layout as an attribute.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
/// Update the function arguments and results with the layouts.
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types; result types are
  // left unchanged.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
void XeGPUPropagateLayoutPass::runOnOperation() {
  // Map the pass option string to a LayoutKind.
  LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = LayoutKind::InstData;
  } else if (this->layoutKind == "subgroup") {
    layoutKind = LayoutKind::Subgroup;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
  // Print the analysis result and exit (for debugging purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert the analysis result for a value into an xegpu layout
  // attribute (either a plain layout or a slice layout).
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  // Visit ops bottom-up within each block so consumers are updated first.
  auto walkResult = op->walk([&](Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}
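// Usage sketch (illustrative; the option spellings are assumptions derived
// from the strings checked above):
//
//   mlir-opt --xegpu-propagate-layout="layout-kind=lane" input.mlir
//
// runs the pass with LayoutKind::Lane, annotating vector and tensor-desc
// values with #xegpu.layout / #xegpu.slice attributes for later distribution.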