29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/LogicalResult.h"
37#include "llvm/Support/raw_ostream.h"
43#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
48#define DEBUG_TYPE "xegpu-propagate-layout"
49#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
enum class LayoutKind { Lane, InstData };
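/// The lattice value tracked for each SSA value by the propagation analysis.
/// It wraps a DistributeLayoutAttr; a null storage means "layout not yet
/// assigned" and acts as the lattice bottom.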
struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }
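  /// Lattice hooks required by the sparse dataflow framework. This backward
  /// analysis only ever applies `meet`; `join` is unreachable.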
  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
  bool isAssigned() const { return storage != nullptr; }

  xegpu::DistributeLayoutAttr get() const { return storage; }
  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    assert(isAssigned() && "Expected layout to be assigned.");
    return storage.getRank();
  }
  SmallVector<int> getLaneLayout() const {
    assert(storage.getEffectiveLaneLayoutAsInt().size() &&
           "Expected lane layout to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    assert(storage.getEffectiveLaneDataAsInt().size() &&
           "Expected lane data to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    return llvm::map_to_vector(
        storage.getEffectiveInstDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }
  void print(raw_ostream &os) const {
    if (isAssigned())
      os << storage;
    else
      os << "Not assigned.";
  }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
};
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return LayoutInfo();
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
  }
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  return LayoutInfo(layoutAttr);
}
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
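/// Helpers to get the default SIMT layout for a 1D or 2D uniform value: lanes
/// span the only dimension in 1D and the innermost dimension in 2D, one
/// element per lane.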
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank,
                                           const xegpu::uArch::uArch *uArch) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
  return LayoutInfo(
      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}

static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank, int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
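// For example, with a subgroup size of 16 these helpers produce
//   1D: #xegpu.layout<lane_layout = [16], lane_data = [1]>
//   2D: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>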
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                             {uArch->getSubgroupSize(), 1},
                                             {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, uArch->getSubgroupSize()},
                                           {1, packingFactor}));
}
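/// TensorDescType variant of the default-layout helper. Scattered
/// descriptors distribute lanes along the outer dimension ([subgroupSize, 1])
/// instead of the inner one.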
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  int subgroupSize = uArch->getSubgroupSize();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
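/// Layout for a DPAS operand. For the B operand (operandNum == 1), elements
/// narrower than the packed format are packed along the K dimension; other
/// operands fall back to the default layout.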
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
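/// Backward sparse dataflow analysis that propagates layout requirements
/// from layout-anchoring ops (dpas, loads/stores, prefetch) toward producers.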
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  LayoutKind layoutKind;

  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};
  void visitCallOperand(OpOperand &operand) override {};
  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  }

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
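/// Dispatch to the op-specific visitors; for all other ops, each assigned
/// result layout is forwarded to every vector or tensor-desc operand (the
/// default case covers elementwise ops).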
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
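/// Returns true if `anchorLayout` carries the parameters relevant to the
/// propagated LayoutKind: inst_data for InstData, lane_layout and lane_data
/// for Lane.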
bool LayoutInfoPropagation::hasParamsOfLayoutKind(
    xegpu::DistributeLayoutAttr anchorLayout) {
  if (anchorLayout == nullptr)
    return false;
  if (layoutKind == LayoutKind::InstData)
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
  if (layoutKind == LayoutKind::Lane)
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
  return false;
}
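/// Prefetch expects the same layout as an equivalent load; the inst-data
/// shape is derived from the hardware 2D-block prefetch width and height.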
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo prefetchLayout;
  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    prefetchLayout = LayoutInfo(anchorLayout);
  } else {
    auto tdescTy = prefetch.getTensorDescType();
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(prefetch).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
    auto blockWHC =
        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
    if (!blockWHC)
      prefetch.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (tdescTy.getRank() == 1) {
      instData = {instWidth};
    } else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        prefetch.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == LayoutKind::InstData)
      prefetchLayout =
          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
    else
      prefetchLayout = getDefaultSIMTLayoutInfo(
          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());

    prefetch.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
  }
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
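/// For vector.multi_reduction with a 1D result: the source operand gets the
/// default 2D layout and the accumulator gets the result layout.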
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(reduction).value_or(""));
  LayoutInfo operandLayout =
      getDefaultSIMTLayoutInfo(reduction->getContext(), 2, uArch);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
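/// For vector.broadcast: a rank-expanding broadcast maps the result layout to
/// the source through a SliceAttr over the broadcasted dims; a same-rank
/// broadcast only adjusts the unit-dim lane data.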
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy) {
    broadcast.emitWarning("Expecting source type to be a vector type.");
    return;
  }
  if (sourceTy.getRank() != resultTy.getRank()) {
    auto sourceDims = sourceTy.getShape();
    auto resultDims = resultTy.getShape();
    SmallVector<int64_t> bcastDims;
    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
    for (int i = 0; i < dimDiff; i++)
      bcastDims.push_back(i);
    for (size_t i = 0; i < sourceDims.size(); i++)
      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
        bcastDims.push_back(i + dimDiff);
    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
        broadcast->getContext(),
        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
    return;
  }
  llvm::SetVector<int64_t> broadcastUnitDims =
      broadcast.computeBroadcastedUnitDims();
  resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
                     .setUnitDimData(broadcastUnitDims);
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
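/// For 1D -> 2D vector.shape_cast, the source layout is the result layout
/// sliced along the unit dimension.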
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
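/// Set the layouts for DPAS A, B, and C/D operands, either from user-supplied
/// anchor attributes or from uArch-derived defaults.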
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();

    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(dpas).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

    const unsigned dataALen = aTy.getShape().front();
    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
    const int maxALen =
        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
    if (maxALen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");

    const unsigned dataBLen = bTy.getShape().back();
    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
    const int maxBLen =
        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
    if (maxBLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataA = {maxALen, subgroupSize};
    SmallVector<int> instDataB = {subgroupSize, maxBLen};

    if (layoutKind == LayoutKind::InstData) {
      dpasALayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
      dpasBLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
    } else {
      dpasALayout = getSIMTLayoutInfoForDPASOperand(
          aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
      dpasBLayout = getSIMTLayoutInfoForDPASOperand(
          bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
    }

    if (operands.size() > 2) {
      VectorType cTy = dpas.getAccType();
      if (layoutKind == LayoutKind::InstData) {
        const unsigned dataCLen = bTy.getShape().back();
        auto supportedCLen =
            uArchInstruction->getSupportedN(bTy.getElementType());
        const int maxCLen = xegpu::getLargestDivisor(
            dataCLen, ArrayRef<unsigned>(supportedCLen));
        if (maxCLen == -1)
          dpas.emitWarning(
              "No suitable instruction multiple found for the given shape.");
        SmallVector<int> instDataC = {maxALen, maxCLen};
        dpasCDLayout =
            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
      } else {
        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
      }
      dpas.setLayoutCdAttr(
          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
    }
    dpas.setLayoutAAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
    dpas.setLayoutBAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
  }

  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2) {
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
  }
}
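/// Both StoreNdOp operands (value and tensor descriptor) get the same layout,
/// sized from the hardware 2D-block store parameters.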
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(store).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC)
      store.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1) {
      instData = {instWidth};
    } else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else
      storeLayout =
          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
                                   uArchInstruction->getPackedFormatBitSize());
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
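/// Propagate the LoadNdOp result layout to the tensor descriptor, undoing the
/// (unexpected at this stage) transpose effect if present.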
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {
    LayoutInfo valueLayout = results[0]->getValue();
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
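/// For vector.transpose, the source needs the result layout permuted by the
/// same permutation.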
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
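/// For vector.bitcast, the lane layout is preserved and the innermost lane
/// data is rescaled by the bitwidth ratio of the cast, provided no cross-lane
/// data movement is required.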
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }

  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }

  int64_t rank = bitcast.getSourceVectorType().getRank();
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout = resultLayout.getLaneLayout();
  SmallVector<int> outData = resultLayout.getLaneData();

  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than "
                        "the innermost dimension.");
    return;
  }
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
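/// For gather loads, the payload gets a scattered layout (or an inst-data
/// shape of [subgroupSize, chunkSize]); mask and offsets always get the
/// default 1D layout.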
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
    maskLayout = loadLayout;
  } else {
    VectorType payloadTy = load.getValueType();
    if (!payloadTy) {
      load.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(load).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    SmallVector<int> instData{subgroupSize};
    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
      instData.push_back(chunkSize);
    else if (auto srcTdescTy =
                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
      if (srcTdescTy.getChunkSizeAsInt() > 1)
        instData.push_back(static_cast<int>(srcTdescTy.getChunkSizeAsInt()));
    }
    if (layoutKind == LayoutKind::InstData)
      loadLayout =
          LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
    else
      loadLayout =
          getDefaultSIMTLayoutInfo(payloadTy, uArch,
                                   uArch->getGeneralPackedFormatBitSize(),
                                   /*isScattered=*/true);
    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
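/// For CreateDescOp, once the descriptor result has a layout, the offsets
/// operand (index 1) gets the default 1D layout.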
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  if (!descLayout.isAssigned())
    return;
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(createDesc).value_or(""));
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch->getSubgroupSize());
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
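/// Mirror of visitLoadGatherOp for scattered stores: the payload (and the
/// tensor-desc destination, if any) get the payload layout; mask and offsets
/// get the default 1D layout.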
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo payloadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    payloadLayout = LayoutInfo(anchorLayout);
    maskLayout = payloadLayout;
  } else {
    VectorType payloadTy = storeScatter.getValueType();
    if (!payloadTy) {
      storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(storeScatter).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    if (layoutKind == LayoutKind::InstData) {
      SmallVector<int> instData{subgroupSize};
      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
          chunkSize > 1)
        instData.push_back(chunkSize);
      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
                   storeScatter.getDestType())) {
        if (dstTdescTy.getChunkSizeAsInt() > 1)
          instData.push_back(static_cast<int>(dstTdescTy.getChunkSizeAsInt()));
      }
      payloadLayout = LayoutInfo(
          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
    } else {
      auto payloadShape = payloadTy.getShape();
      if (payloadShape.size() > 1)
        assert(payloadShape[0] == subgroupSize &&
               "Expected the first dimension of 2D tensor descriptor to be "
               "equal to subgroup size.");
      payloadLayout =
          getDefaultSIMTLayoutInfo(payloadTy, uArch,
                                   uArch->getGeneralPackedFormatBitSize(),
                                   /*isScattered=*/true);
    }
    maskLayout =
        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
    storeScatter.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
  }
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
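/// Driver that loads the analysis into a DataFlowSolver, runs it over
/// `target`, and exposes per-value layout lookups.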
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
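/// Print per-function analysis results (arguments and op results); used by
/// the print-analysis-only testing mode.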
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    funcOp.walk([&](Operation *op) {
      if (op->getResults().empty())
        return;
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      os << "op    : ";
      op->print(os);
      os << "\n";
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
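/// Update an operation with the layout of its results: tensor-desc results
/// fold the layout into their type, vector results get a layout attribute.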
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
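/// Region ops like scf.for need special handling because they have blocks
/// inside: a layout forwarded by a region branch terminator must agree with
/// the layout of the region argument it feeds.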
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();
  llvm::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
                                              nullptr);
  terminator.getSuccessorRegions(operands, successors);
  for (mlir::RegionSuccessor &successor : successors) {
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
        return failure();
      }
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
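/// Update the function arguments and results with the layouts, folding
/// layouts into tensor-desc argument types and refreshing the function type.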
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
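/// Pass driver: run the analysis, then walk every block bottom-up and attach
/// the computed layouts to ops, region branch terminators, and function
/// signatures, failing the pass on any inconsistency.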
void XeGPUPropagateLayoutPass::runOnOperation() {
  LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = LayoutKind::InstData;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  auto walkResult = op->walk([&](Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}
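// Example invocation (assuming the pass and its layout-kind option are
// registered under the usual names; check Passes.td for the exact spelling):
//   mlir-opt -xegpu-propagate-layout="layout-kind=lane" input.mlir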