#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
enum class LayoutKind { Lane, InstData };
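// LayoutKind::Lane propagates lane-level (SIMT) layouts, i.e. lane_layout /
// lane_data pairs describing how a value is distributed across the lanes of a
// subgroup. LayoutKind::InstData propagates inst_data shapes, i.e. the tile
// consumed or produced by a single hardware instruction.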
  xegpu::DistributeLayoutAttr storage = nullptr;

  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  bool isAssigned() const { return storage != nullptr; }

  bool isSliceLayout() const {
    return isa<xegpu::SliceAttr>(storage);
  }
  int64_t getRank() const { return storage.getRank(); }
  SmallVector<int> getLaneLayout() const {
    assert(storage.getEffectiveLaneLayoutAsInt().size() &&
           "Expected lane layout to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    assert(storage.getEffectiveLaneDataAsInt().size() &&
           "Expected lane data to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    return llvm::map_to_vector(
        storage.getEffectiveInstDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  // Print the layout, or "Not assigned." when no layout has been set.
  void print(raw_ostream &os) const {
    if (!isAssigned()) {
      os << "Not assigned.";
      return;
    }
    os << storage;
  }
};
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  // Check that the permutation is valid: no duplicates, all indices in range.
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });
  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return LayoutInfo();
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
  }
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  return LayoutInfo(layoutAttr);
}
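// Example (illustrative): transposing lane_layout = [1, 16],
// lane_data = [1, 1] with permutation [1, 0] yields lane_layout = [16, 1],
// lane_data = [1, 1]; each layout dimension is permuted together with the
// vector dimension it describes.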
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           const xegpu::uArch::uArch *uArch) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
  return LayoutInfo(
      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}

// Same as above, but taking the subgroup size directly.
static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                             {uArch->getSubgroupSize(), 1},
                                             {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, uArch->getSubgroupSize()},
                                           {1, packingFactor}));
}
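// Worked example (illustrative, assuming subgroup size 16 and
// packingSize = 16): vector<16x16xf16> has bitwidth 16, so packingFactor = 1
// and the default layout is lane_layout = [1, 16], lane_data = [1, 1].
// vector<16x32xi8> gives packingFactor = 2 and lane_data = [1, 2], i.e. each
// lane owns two consecutive 8-bit elements along the innermost dimension.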
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D TensorDesc.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D TensorDesc.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  int subgroupSize = uArch->getSubgroupSize();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
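// With isScattered = true the subgroup dimension moves to the outer axis
// (lane_layout = [subgroupSize, 1]), matching gather/scatter accesses where
// each lane addresses its own row; the non-scattered default keeps lanes
// along the innermost axis (lane_layout = [1, subgroupSize]).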
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  // B operand: pack elements along the K dim when narrower than packingSize.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
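// Example (illustrative): for an f16 B operand with a 32-bit packed format,
// packingSize / bitwidth = 2, so each lane keeps two consecutive elements
// along the K dimension (lane_data = [2, 1]); A and the accumulator fall back
// to the default SIMT layout above.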
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
  LayoutKind layoutKind;

  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction, ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBroadCastOp(vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}
  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  }

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // All other ops: forward an already-assigned result layout to every
      // vector or tensor-descriptor operand.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
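// Note that this is a *backward* analysis: layouts are seeded at anchor ops
// (DPAS, stores, prefetch) and flow from results to operands. The default
// case above simply propagates an already-assigned result layout to every
// vector or tensor-descriptor operand of an unhandled (e.g. elementwise) op.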
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Assign a layout to the tensor descriptor operand of the prefetch.
  auto tdescTy = prefetch.getTensorDescType();
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
  auto blockWHC =
      uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
  if (!blockWHC)
    prefetch.emitWarning("No known block params found for the element type.");
  auto [bWidth, bHeight, bCount] = blockWHC.value();
  SmallVector<int> instData;
  int instWidth = xegpu::getLargestDivisor(
      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
  if (instWidth == -1)
    prefetch.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  if (tdescTy.getRank() == 1)
    instData = {instWidth};
  else {
    int instHeight = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    instData = {instHeight, instWidth};
  }
  LayoutInfo prefetchLayout;
  if (layoutKind == LayoutKind::InstData)
    prefetchLayout =
        LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
  else
    prefetchLayout = getDefaultSIMTLayoutInfo(
        tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
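// Example (illustrative, assuming the element type supports block width 16
// and block height 8): for a tensor_desc<32x32xf16>, the largest divisors
// give instWidth = 16 and instHeight = 8, so under LayoutKind::InstData the
// descriptor operand receives #xegpu.layout<inst_data = [8, 16]>.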
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only 2D -> 1D reductions are handled here.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  // The source operand gets the default 2D layout; the accumulator gets the
  // same layout as the result.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy) {
    broadcast.emitWarning("Expecting source type to be a vector type.");
    return;
  }
  // Only nD -> nD broadcasts with a single broadcasted unit dim are handled.
  if (sourceTy.getRank() != resultTy.getRank()) {
    broadcast.emitWarning("Expecting source and result to have same rank.");
    return;
  }
  auto broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
  if (broadcastUnitDims.size() != 1) {
    broadcast.emitWarning("Expecting source type to be nD vector only with "
                          "one broadcasted dimension.");
    return;
  }
  // The source and the result share the same layout.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Only 1D -> 2D shape casts (unit dim insertion) are handled.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  // The source gets the result layout sliced along the unit dimension.
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
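// Example (illustrative): for a shape cast vector<16xf16> -> vector<1x16xf16>
// the unit dimension is dim 0, so the 1D source receives the result layout
// sliced along that dim, e.g.
//   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
//                dims = [0]>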
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the result layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();

  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
          xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

  // Find the largest supported instruction multiples for the M and K dims.
  const unsigned dataALen = aTy.getShape().front();
  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
  const int maxALen = xegpu::getLargestDivisor(dataALen, supportedALen);
  if (maxALen == -1)
    dpas.emitWarning(
        "No suitable instruction multiple found for the given shape.");

  const unsigned dataBLen = bTy.getShape().back();
  auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
  const int maxBLen = xegpu::getLargestDivisor(dataBLen, supportedBLen);
  if (maxBLen == -1)
    dpas.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  SmallVector<int> instDataA = {maxALen, subgroupSize};
  SmallVector<int> instDataB = {subgroupSize, maxBLen};

  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCLayout;

  if (layoutKind == LayoutKind::InstData) {
    dpasALayout =
        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
    dpasBLayout =
        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
  } else {
    dpasALayout = getSIMTLayoutInfoForDPASOperand(
        aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
    dpasBLayout = getSIMTLayoutInfoForDPASOperand(
        bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
  }

  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    const unsigned dataCLen = bTy.getShape().back();
    auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
    const int maxCLen = xegpu::getLargestDivisor(dataCLen, supportedCLen);
    if (maxCLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataC = {maxALen, maxCLen};

    if (layoutKind == LayoutKind::InstData)
      dpasCLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
    else
      dpasCLayout = getSIMTLayoutInfoForDPASOperand(
          cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());

    propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
  }
}
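// Example (illustrative, subgroup size 16, supported M/K/N multiples
// {8}/{16}/{16}): for an 8x16 times 16x16 -> 8x16 DPAS, LayoutKind::InstData
// assigns inst_data = [8, 16] to A, [16, 16] to B and [8, 16] to the
// accumulator, while LayoutKind::Lane uses the packed SIMT layouts from
// getSIMTLayoutInfoForDPASOperand.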
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
  VectorType dataTy = store.getValueType();
  auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
      store.getValueType().getElementType());
  if (!blockWHC)
    store.emitWarning("No known block params found for the element type.");
  auto [bWidth, bHeight, bCount] = blockWHC.value();
  SmallVector<int> instData;
  int instWidth = xegpu::getLargestDivisor(
      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
  if (instWidth == -1)
    store.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  if (dataTy.getRank() == 1)
    instData = {instWidth};
  else {
    int instHeight = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    instData = {instHeight, instWidth};
  }
  LayoutInfo storeLayout;
  if (layoutKind == LayoutKind::InstData)
    storeLayout =
        LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
  else
    storeLayout =
        getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
                                 uArchInstruction->getPackedFormatBitSize());
  // Both operands (value and tensor descriptor) get the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // A transpose effect is not expected at this stage of the analysis; emit a
  // warning and transpose the layout accordingly.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.transpose(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bit widths are the same, the layout does not change.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  // The result must be evenly distributable across lanes with its layout.
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout =
      resultLayout.getLaneLayout(); // Lane layout does not change.
  SmallVector<int> outData = resultLayout.getLaneData();
  // Each lane must own a whole number of source elements in the innermost
  // dimension; otherwise the bitcast needs cross-lane communication.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // Only the innermost dimension of lane data changes; all outer dims must
  // remain 1.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than "
                        "the innermost dimension.");
    return;
  }
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
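// Worked example (illustrative): bitcasting vector<16x16xf32> to
// vector<16x32xf16> is narrowing with bitCastRatio = 2. A result lane_data of
// [1, 2] carries 2 * 16 = 32 bits per lane in the innermost dim, exactly one
// f32, so the source gets lane_data = [1, 1]. A result lane_data of [1, 1]
// (only 16 bits per lane) would need cross-lane data movement and is rejected
// above.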
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The payload must be a vector type.
  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
  if (!payloadTy) {
    load.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  LayoutInfo layout;
  SmallVector<int> instData{subgroupSize};
  if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
    instData.push_back(chunkSize);
  else if (auto srcTdescTy =
               dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
    if (srcTdescTy.getChunkSizeAsInt() > 1)
      instData.push_back(chunkSize);
  }
  if (layoutKind == LayoutKind::InstData)
    layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
  else
    layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
                                      uArch->getGeneralPackedFormatBitSize(),
                                      /*isScattered=*/true);
  // Mask and offset operands are 1D, with one element per lane.
  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
  // Propagate the payload layout to the tensor descriptor operand (if any).
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(layout));
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
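// Example (illustrative, subgroup size 16): a load_gather with chunk_size = 4
// receives inst_data = [16, 4] under LayoutKind::InstData, while the mask
// (and offsets, if present) always receive the 1D layout
// lane_layout = [16], lane_data = [1].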
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // The offsets operand gets the default 1D layout (one offset per lane).
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The payload must be a vector type.
  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
  if (!payloadTy) {
    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  LayoutInfo payloadLayout;
  // If the op already carries a layout attribute, honor it; otherwise derive
  // one from the layout kind.
  if (auto layout = storeScatter.getLayoutAttr()) {
    payloadLayout = LayoutInfo(layout);
  } else {
    if (layoutKind == LayoutKind::InstData) {
      SmallVector<int> instData{subgroupSize};
      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
          chunkSize > 1)
        instData.push_back(chunkSize);
      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
                   storeScatter.getDestType())) {
        if (dstTdescTy.getChunkSizeAsInt() > 1)
          instData.push_back(chunkSize);
      }
      payloadLayout = LayoutInfo(
          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
    } else {
      // For 2D payloads the first dimension must match the subgroup size.
      auto payloadShape = payloadTy.getShape();
      if (payloadShape.size() > 1)
        assert(payloadShape[0] == subgroupSize &&
               "Expected the first dimension of 2D tensor descriptor to be "
               "equal to subgroup size.");
      payloadLayout = getDefaultSIMTLayoutInfo(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
          /*isScattered=*/true);
    }
  }
  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
  // Propagate the payload layout.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the payload layout to the tensor descriptor (if present).
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // Propagate the mask layout.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
class RunLayoutInfoPropagation {
public:
  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      layout.print(os);
      os << "\n";
    }
    // Function body ops.
    funcOp.walk([&](Operation *op) {
      // Skip branch-like ops; their layouts are handled via their regions.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      // Print the layout of each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    // Collect functions nested inside gpu.module ops as well.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
// Helper to pull the layout of a value out of the analysis result.
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

// Update an operation with the layout of its results.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are handled via their terminators below.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  // Iterate over all the results of the op.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are only relevant for vector and tensor descriptor results.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // For tensor descriptor results, the layout is carried in the type.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
// Region ops like scf.for need special handling because they have blocks
// inside; the layouts of their region arguments are driven by the operands
// forwarded by their terminators.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only terminators inside region-branch ops (e.g. scf.yield) are relevant.
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();
  SmallVector<RegionSuccessor> successors;
  SmallVector<Attribute> operands(terminator->getNumOperands(), nullptr);
  terminator.getSuccessorRegions(operands, successors);
  for (RegionSuccessor &successor : successors) {
    ValueRange successorInputs = successor.getSuccessorInputs();
    OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      // Only vector and tensor descriptor types need layouts.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
      }
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
      }
      // Tensor descriptor types carry the layout in the type itself.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // For vector types, attach the layout to the producer of the region
      // argument when it is an OpResult.
      if (auto result = dyn_cast<OpResult>(successorInput))
// Update the function arguments and results with the layouts.
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
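// Example (illustrative) of the rewrite performed here: a function argument
// of type !xegpu.tensor_desc<8x16xf16> whose propagated layout is
// lane_layout = [1, 16], lane_data = [1, 1] becomes
//   !xegpu.tensor_desc<8x16xf16,
//     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// and the enclosing FunctionType is updated to match.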
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
void XeGPUPropagateLayoutPass::runOnOperation() {
  LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = LayoutKind::InstData;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
  // Print the analysis result and exit (for testing purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert the analysis result to a DistributeLayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  // Walk the IR and attach the propagated layouts to ops, region arguments
  // and function signatures.
  WalkResult walkResult = op->walk([&](Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}