#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
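/// Layout flavor produced by the analysis: lane-level (SIMT) layouts or
/// per-instruction data block shapes.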
enum class LayoutKind { Lane, InstData };
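/// Lattice value tracked for each SSA value: a thin wrapper over a
/// DistributeLayoutAttr, where a null attribute means "layout not assigned
/// yet".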
struct LayoutInfo {
  xegpu::DistributeLayoutAttr storage = nullptr;

  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  bool isAssigned() const { return storage != nullptr; }

  bool isSliceLayout() const { return isa<xegpu::SliceAttr>(storage); }

  int64_t getRank() const { return storage.getRank(); }

  SmallVector<int> getLaneLayout() const {
    assert(storage.getEffectiveLaneLayoutAsInt().size() &&
           "Expected lane layout to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    assert(storage.getEffectiveLaneDataAsInt().size() &&
           "Expected lane data to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    return llvm::map_to_vector(
        storage.getEffectiveInstDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  void print(llvm::raw_ostream &os) const {
    if (isAssigned())
      os << storage;
    else
      os << "Not assigned.";
  }
};
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return LayoutInfo();
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
  }
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  return LayoutInfo(layoutAttr);
}
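/// Lattice wrapper that plugs LayoutInfo into the sparse dataflow framework.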
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};

  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");

static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                             {uArch->getSubgroupSize(), 1},
                                             {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, uArch->getSubgroupSize()},
                                           {1, packingFactor}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  int subgroupSize = uArch->getSubgroupSize();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
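/// Backward dataflow analysis that propagates layout requirements from
/// layout-anchoring operations (DPAS, loads, stores, ...) to the producers of
/// their vector and tensor-descriptor operands.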
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  LayoutKind layoutKind;

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  }

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
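/// Dispatch to the op-specific visitors; any other op simply forwards its
/// result layouts to operands of vector or tensor-descriptor type.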
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
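/// Propagate a layout to the tensor descriptor of PrefetchNdOp: either an
/// instruction-data shape derived from the uArch 2D block prefetch parameters
/// or the default SIMT layout.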
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  auto tdescTy = prefetch.getTensorDescType();
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
  auto blockWHC =
      uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
  if (!blockWHC)
    prefetch.emitWarning("No known block params found for the element type.");
  auto [bWidth, bHeight, bCount] = blockWHC.value();
  SmallVector<int> instData;
  int instWidth = xegpu::getLargestDivisor(
      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
      bCount);
  if (instWidth == -1)
    prefetch.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  if (tdescTy.getRank() == 1)
    instData = {instWidth};
  else {
    int instHeight = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    instData = {instHeight, instWidth};
  }
  LayoutInfo prefetchLayout;
  if (layoutKind == LayoutKind::InstData)
    prefetchLayout =
        LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
  else
    prefetchLayout = getDefaultSIMTLayoutInfo(
        tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
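/// A multi-reduction with a 1D result gets a default 2D layout for its vector
/// source, while the accumulator reuses the result layout.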
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
      reduction->getContext(), 2, subgroupSize);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
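/// Forward the result layout of a broadcast to its source, but only for
/// same-rank broadcasts with exactly one broadcasted unit dimension.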
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy) {
    broadcast.emitWarning("Expecting source type to be a vector type.");
    return;
  }
  if (sourceTy.getRank() != resultTy.getRank()) {
    broadcast.emitWarning("Expecting source and result to have same rank.");
    return;
  }
  auto broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
  if (broadcastUnitDims.size() != 1) {
    broadcast.emitWarning("Expecting source type to be nD vector only with "
                          "one broadcasted dimension.");
    return;
  }
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
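/// For a 1D -> 2D shape cast, build a SliceAttr for the 1D source by slicing
/// the unit dimension out of the 2D result layout.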
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
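/// Assign layouts to the DPAS A/B (and optional accumulator) operands: either
/// instruction-data block shapes sized from the uArch matrix-multiply
/// instruction, or the corresponding SIMT operand layouts.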
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
          xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

  const unsigned dataALen = aTy.getShape().front();
  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
  const int maxALen = xegpu::getLargestDivisor(dataALen, supportedALen);
  if (maxALen == -1)
    dpas.emitWarning(
        "No suitable instruction multiple found for the given shape.");

  const unsigned dataBLen = bTy.getShape().back();
  auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
  const int maxBLen = xegpu::getLargestDivisor(dataBLen, supportedBLen);
  if (maxBLen == -1)
    dpas.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  SmallVector<int> instDataA = {maxALen, subgroupSize};
  SmallVector<int> instDataB = {subgroupSize, maxBLen};

  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCLayout;

  if (layoutKind == LayoutKind::InstData) {
    dpasALayout =
        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
    dpasBLayout =
        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
  } else {
    dpasALayout = getSIMTLayoutInfoForDPASOperand(
        aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
    dpasBLayout = getSIMTLayoutInfoForDPASOperand(
        bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
  }

  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    const unsigned dataCLen = bTy.getShape().back();
    auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
    const int maxCLen = xegpu::getLargestDivisor(dataCLen, supportedCLen);
    if (maxCLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataC = {maxALen, maxCLen};

    if (layoutKind == LayoutKind::InstData)
      dpasCLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
    else
      dpasCLayout = getSIMTLayoutInfoForDPASOperand(
          cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
    propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
  }
}
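/// Both the stored value and the tensor descriptor of StoreNdOp receive the
/// same layout, derived from the uArch 2D block store parameters or the
/// default SIMT layout.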
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
  VectorType dataTy = store.getValueType();
  auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
      store.getValueType().getElementType());
  if (!blockWHC)
    store.emitWarning("No known block params found for the element type.");
  auto [bWidth, bHeight, bCount] = blockWHC.value();
  SmallVector<int> instData;
  int instWidth = xegpu::getLargestDivisor(
      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
      bCount);
  if (instWidth == -1)
    store.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  if (dataTy.getRank() == 1)
    instData = {instWidth};
  else {
    int instHeight = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    instData = {instHeight, instWidth};
  }
  LayoutInfo storeLayout;
  if (layoutKind == LayoutKind::InstData)
    storeLayout =
        LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
  else
    storeLayout =
        getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
                                 uArchInstruction->getPackedFormatBitSize());
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
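/// Propagate the value layout of LoadNdOp back to its tensor descriptor,
/// applying the transpose permutation if the load carries one.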
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.transpose(transpose.value());
  }
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
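/// Derive the bitcast source layout from the result layout by rescaling the
/// innermost lane data by the bitwidth ratio; bitcasts that would require
/// cross-lane communication are rejected.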
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout = resultLayout.getLaneLayout();
  SmallVector<int> outData = resultLayout.getLaneData();
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than the innermost dimension.");
    return;
  }
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);
  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
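/// LoadGatherOp: the gathered payload (and source descriptor) gets a scattered
/// or instruction-data layout, while the mask and offsets get the default 1D
/// layout.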
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
  if (!payloadTy) {
    load.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  SmallVector<int> instData{subgroupSize};
  if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
    instData.push_back(chunkSize);
  else if (auto srcTdescTy =
               dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
    if (srcTdescTy.getChunkSizeAsInt() > 1)
      instData.push_back(chunkSize);
  }
  LayoutInfo layout;
  if (layoutKind == LayoutKind::InstData)
    layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
  else
    layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
                                      uArch->getGeneralPackedFormatBitSize(),
                                      /*isScattered=*/true);
  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(layout));
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
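/// Once the descriptor produced by CreateDescOp has a layout, its offsets
/// operand gets the default 1D layout.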
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  if (!descLayout.isAssigned())
    return;
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               subgroupSize);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
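/// StoreScatterOp: the payload (and destination descriptor) takes an explicit,
/// instruction-data, or default scattered layout; the mask and offsets take
/// the default 1D layout.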
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
  if (!payloadTy) {
    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  LayoutInfo payloadLayout;
  if (auto layout = storeScatter.getLayoutAttr()) {
    payloadLayout = LayoutInfo(layout);
  } else {
    if (layoutKind == LayoutKind::InstData) {
      SmallVector<int> instData{subgroupSize};
      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
          chunkSize > 1)
        instData.push_back(chunkSize);
      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
                   storeScatter.getDestType())) {
        if (dstTdescTy.getChunkSizeAsInt() > 1)
          instData.push_back(chunkSize);
      }
      payloadLayout = LayoutInfo(
          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
    } else {
      auto payloadShape = payloadTy.getShape();
      if (payloadShape.size() > 1)
        assert(payloadShape[0] == subgroupSize &&
               "Expected the first dimension of 2D tensor descriptor to be "
               "equal to subgroup size.");
      payloadLayout = getDefaultSIMTLayoutInfo(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
          /*isScattered=*/true);
    }
  }
  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
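/// Driver that loads the propagation analysis into a DataFlowSolver, runs it
/// over the target operation, and exposes per-value layout queries.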
class RunLayoutInfoPropagation {
public:
  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
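/// Print the propagated layout of every function argument and op result,
/// mainly for testing and debugging.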
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      layout.print(os);
      os << "\n";
    }

    funcOp.walk([&](Operation *op) {
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
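/// Update an operation with the layouts computed for its results: tensor
/// descriptor results are rebuilt with the layout folded into their type,
/// other results get the layout attached as an attribute.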
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
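/// Region ops like scf.for need special handling because they have blocks
/// inside: the layouts of operands forwarded by a region-branch terminator are
/// mirrored onto the corresponding successor inputs.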
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();
  SmallVector<mlir::RegionSuccessor> successors;
  SmallVector<mlir::Attribute> operands(terminator->getNumOperands(), nullptr);
  terminator.getSuccessorRegions(operands, successors);
  for (mlir::RegionSuccessor &successor : successors) {
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
        continue;
      }
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
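/// Update the function arguments (and the function type) with the propagated
/// layouts; tensor-descriptor arguments are rebuilt with the layout in their
/// type.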
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      continue;
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
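/// Entry point of the pass: parse the layout-kind option, run the propagation
/// analysis, and rewrite the IR with the computed layouts. A typical
/// invocation looks roughly like the following (sketch only; the exact pass
/// option names are defined in the pass TableGen file and are assumed here):
///   mlir-opt -xegpu-propagate-layout="layout-kind=lane" input.mlir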
void XeGPUPropagateLayoutPass::runOnOperation() {
  LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = LayoutKind::InstData;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
  if (this->printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}