#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;
/// Lattice value tracking the layout assigned to a value. The layout stays
/// unassigned (null storage) until propagation reaches the value.
struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values compare equal if they are in the same state
  // (assigned or unassigned).
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const { return storage != nullptr; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  xegpu::DistributeLayoutAttr get() const { return storage; }

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    assert(isAssigned() && "Layout is not assigned.");
    return storage.getRank();
  }

  SmallVector<int> getLaneLayout() const {
    assert(storage.getEffectiveLaneLayoutAsInt().size() &&
           "Expected lane layout to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    assert(storage.getEffectiveLaneDataAsInt().size() &&
           "Expected lane data to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    return llvm::map_to_vector(
        storage.getEffectiveInstDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }
};
void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned())
    os << storage;
  else
    os << "Not assigned.";
}
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
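// Note: an unassigned LayoutInfo acts as the top element of this lattice;
// meet() adopts the first assigned layout it sees and is a no-op afterwards.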
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
/// Construct a new layout with the transposed lane layout, lane data, and
/// instruction data.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  // Check that the permutation is valid: no duplicates, all indices within
  // range.
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return {};
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  for (int64_t idx : permutation) {
    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
  }
  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
                                           laneLayout, laneData));
}
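// Illustrative example (attribute syntax abbreviated): transposing a layout
// with lane_layout = [1, 16], lane_data = [1, 1], inst_data = [8, 16] by
// permutation [1, 0] yields lane_layout = [16, 1], lane_data = [1, 1],
// inst_data = [16, 8].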
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
  using Lattice::Lattice;
};
/// Returns the largest value of `candidate * multiple` that evenly divides
/// `dim`, or -1 if no such value exists.
template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
                      ArrayRef<T> candidateMultiples = {}) {
  static_assert(std::is_integral<T>::value, "T must be an integer type");
  int largest = -1;
  SmallVector<T> multiples = {1};
  if (!candidateMultiples.empty())
    multiples =
        SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
  for (T candidate : candidates) {
    for (T multiple : multiples) {
      int value = static_cast<int>(candidate * multiple);
      if (value != 0 && dim % value == 0 && value > largest)
        largest = value;
    }
  }
  return largest;
}
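// Worked example: getLargestDivisor(24, {8, 16}) == 8, since 8 divides 24
// but 16 does not. With candidateMultiples = {1, 2} the tried values are
// {8, 16, 32}, and the largest divisor of 24 among them is still 8. If no
// candidate divides dim evenly, the result is -1.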
/// Helper to get the default layout for uniform values like constants.
/// For a 1D vector, lane_layout is [subgroupSize] and lane_data is [1];
/// for a 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank,
                                           const xegpu::uArch::uArch *uArch,
                                           ArrayRef<int> instData) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, instData,
                                             {uArch->getSubgroupSize()}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank, int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           const xegpu::uArch::uArch *uArch,
                                           ArrayRef<int> instData,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
                                             {uArch->getSubgroupSize(), 1},
                                             {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
                                           {1, uArch->getSubgroupSize()},
                                           {1, packingFactor}));
}
/// Helper to get the default layout for a tensor descriptor type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           const xegpu::uArch::uArch *uArch,
                                           ArrayRef<int> instData,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D TensorDesc.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D TensorDesc.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  int subgroupSize = uArch->getSubgroupSize();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered)
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));
}
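// Illustrative result, assuming a subgroup size of 16: for a 2D i8
// TensorDesc with packingSize = 16, packingFactor = 16 / 8 = 2, so the block
// case yields lane_layout = [1, 16] with lane_data = [1, 2], and the
// scattered case yields lane_layout = [16, 1] with lane_data = [1, 2].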
/// Helper to get the expected layouts for DPAS operands. The B operand must
/// be packed in units of `packingSize` bits; the other operands get the
/// default layout for their vector type.
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                ArrayRef<int> instData, unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  // For the B operand (operand 1), sub-`packingSize` elements are packed
  // vertically: lane_data is [packingFactor, 1].
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
}
/// Backward data-flow analysis that propagates layout information from the
/// results of an operation to its operands.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {}
  void visitCallOperand(OpOperand &operand) override {}
  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  }

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // For all other ops, propagate the result layouts to the operands.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // Layouts are only needed for vector and tensor descriptor
            // operands.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
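// Example of the overall backward flow (op syntax abbreviated, shapes
// illustrative):
//   %a = xegpu.load_nd %ta ... -> vector<8x16xf16>
//   %b = xegpu.load_nd %tb ... -> vector<16x16xf16>
//   %c = xegpu.dpas %a, %b ... -> vector<8x16xf32>
// visitDpasOp anchors the layouts of %a and %b; since the analysis runs
// backward, visitLoadNdOp then carries those layouts onto the tensor
// descriptor operands %ta and %tb of the two loads.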
/// Assign the default layout to the tensor descriptor operand of PrefetchNdOp.
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  auto tdescTy = prefetch.getTensorDescType();
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(prefetch).value_or(""));
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
  auto blockWHC =
      uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
  if (!blockWHC)
    prefetch.emitWarning("No known block params found for the element type.");
  auto [bWidth, bHeight, bCount] = blockWHC.value();
  SmallVector<int> instData;
  int instWidth = getLargestDivisor(
      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
      bCount);
  if (instWidth == -1)
    prefetch.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  if (tdescTy.getRank() == 1)
    instData = {instWidth};
  else {
    int instHeight = getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    instData = {instHeight, instWidth};
  }
  auto prefetchLayout = getDefaultSIMTLayoutInfo(
      tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
  // Propagate the layout to the source tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(reduction).value_or(""));
  // Given that the result is 1D, the operand should be 2D with the default
  // layout.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
      reduction->getContext(), 2, uArch->getSubgroupSize());
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy) {
    broadcast.emitWarning("Expecting source type to be a vector type.");
    return;
  }
  // Only vector-to-vector broadcasts of the same rank are supported.
  if (sourceTy.getRank() != resultTy.getRank()) {
    broadcast.emitWarning("Expecting source and result to have same rank.");
    return;
  }
  // Only a single broadcasted unit dimension is supported.
  auto broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
  if (broadcastUnitDims.size() != 1) {
    broadcast.emitWarning("Expecting source type to be nD vector only with "
                          "one broadcasted dimension.");
    return;
  }
  // Propagate the result layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Only 1D -> 2D shape casts with a unit leading or trailing dimension are
  // supported.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the result layout to the source tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layouts of the DPAS operands; this is the anchor point of the
/// propagation.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();

  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(dpas).value_or(""));
  const int subgroupSize = uArch->getSubgroupSize();
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAccInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

  const unsigned dataALen = aTy.getShape().front();
  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
  const int maxALen =
      getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
  if (maxALen == -1)
    dpas.emitWarning(
        "No suitable instruction multiple found for the given shape.");

  const unsigned dataBLen = bTy.getShape().back();
  auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
  const int maxBLen =
      getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
  if (maxBLen == -1)
    dpas.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  SmallVector<int> instDataA = {maxALen, subgroupSize};
  SmallVector<int> instDataB = {subgroupSize, maxBLen};

  propagateIfChanged(operands[0],
                     operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
                         aTy, 0, uArch, instDataA,
                         uArchInstruction->getPackedFormatBitSizeA())));
  propagateIfChanged(operands[1],
                     operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
                         bTy, 1, uArch, instDataB,
                         uArchInstruction->getPackedFormatBitSizeB())));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    const unsigned dataCLen = bTy.getShape().back();
    auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
    const int maxCLen =
        getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
    if (maxCLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataC = {maxALen, maxCLen};
    propagateIfChanged(operands[2],
                       operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
                           cTy, 2, uArch, instDataC,
                           uArchInstruction->getPackedFormatBitSizeB())));
  }
}
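// Numeric sketch: with subgroupSize = 16, maxALen = 8, and maxBLen = 16
// (the actual values come from the uArch tables), instDataA = {8, 16} and
// instDataB = {16, 16}.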
/// Set the layout for both operands of StoreNdOp based on the stored value
/// type.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(store).value_or(""));
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
          uArch->getInstruction(
              xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
  VectorType dataTy = store.getValueType();
  auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
      store.getValueType().getElementType());
  if (!blockWHC)
    store.emitWarning("No known block params found for the element type.");
  auto [bWidth, bHeight, bCount] = blockWHC.value();
  SmallVector<int> instData;
  int instWidth = getLargestDivisor(
      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
      bCount);
  if (instWidth == -1)
    store.emitWarning(
        "No suitable instruction multiple found for the given shape.");
  if (dataTy.getRank() == 1)
    instData = {instWidth};
  else {
    int instHeight = getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    instData = {instHeight, instWidth};
  }
  LayoutInfo storeLayout =
      getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
                               uArchInstruction->getPackedFormatBitSize());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
/// Propagate the layout of the loaded value to the tensor descriptor operand
/// of LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has the transpose effect. However, at this stage of the analysis
  // the effect is not expected and should be abstracted away; emit a warning.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.transpose(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
/// For vector::TransposeOp, propagate the permuted result layout to the
/// source operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operand.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
/// For vector::BitCastOp, adjust the innermost lane data according to the
/// ratio of the element bitwidths and propagate the layout to the source.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operand.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bitwidths are the same, the layout does not change.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  // Check that the result layout is valid, i.e. the result vector can be
  // distributed across the lanes.
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout = resultLayout.getLaneLayout();
  SmallVector<int> outData = resultLayout.getLaneData();
  // For a narrowing bitcast, each lane must own all the narrow elements that
  // make up one wide element; otherwise data would cross lanes.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // Lanes must not own multiple elements in any dimension other than the
  // innermost one.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than "
                        "the innermost dimension.");
    return;
  }
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);
  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
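// Worked example: bitcasting vector<16xf32> to vector<32xf16> is narrowing
// (32 -> 16 bits, bitCastRatio = 2). If the result lane data is [2] (two
// f16, i.e. 32 bits per lane), the source lane data becomes [2 / 2] = [1]
// and no cross-lane communication is required.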
/// Set the layouts of the LoadGatherOp operands based on the payload type.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout is strictly determined by the payload type.
  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
  if (!payloadTy) {
    load.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(load).value_or(""));
  const int subgroupSize = uArch->getSubgroupSize();
  SmallVector<int> instData{subgroupSize};
  if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
    instData.push_back(chunkSize);
  else if (auto srcTdescTy =
               dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
    if (srcTdescTy.getChunkSizeAsInt() > 1)
      instData.push_back(srcTdescTy.getChunkSizeAsInt());
  }
  LayoutInfo layout =
      getDefaultSIMTLayoutInfo(payloadTy, uArch, instData,
                               uArch->getGeneralPackedFormatBitSize(),
                               /*isScattered=*/true);
  // The mask operand should have the default 1D layout.
  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
  // Propagate the new layout to the tensor descriptor operand.
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(layout));
  // Propagate the new layout to the mask and the optional offsets operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
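// Illustrative shapes, assuming a subgroup size of 16: a gather with
// chunkSize = 4 yields instData = {16, 4} and a scattered payload layout
// with lane_layout = [16, 1]; the mask (and offsets, if present) get the 1D
// default layout lane_layout = [16], lane_data = [1].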
/// Set the layout of the offsets operand of CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(createDesc).value_or(""));
  // For the offsets operand, propagate the default 1D layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch->getSubgroupSize());
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
/// Set the layouts of the StoreScatterOp operands based on the payload type.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Currently, for 2D StoreScatterOp we expect that the height dimension of
  // the tensor descriptor is equal to the subgroup size. This is ensured by
  // the op verifier.
  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
  if (!payloadTy) {
    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(storeScatter).value_or(""));
  const int subgroupSize = uArch->getSubgroupSize();

  auto payloadShape = payloadTy.getShape();
  if (payloadShape.size() > 1)
    assert(payloadShape[0] == subgroupSize &&
           "Expected the first dimension of 2D tensor descriptor to be equal "
           "to subgroup size.");

  SmallVector<int> instData{subgroupSize};
  if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
    instData.push_back(chunkSize);
  else if (auto dstTdescTy =
               dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
    if (dstTdescTy.getChunkSizeAsInt() > 1)
      instData.push_back(dstTdescTy.getChunkSizeAsInt());
  }
  LayoutInfo payloadLayout =
      getDefaultSIMTLayoutInfo(payloadTy, uArch, instData,
                               uArch->getGeneralPackedFormatBitSize(),
                               /*isScattered=*/true);

  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
  // Propagate the payload operand layout.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the destination operand layout if it is a tensor descriptor.
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // Propagate the mask and the optional offsets operand layout.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
/// Driver for running the LayoutInfoPropagation analysis and querying its
/// results.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the layout of each function argument.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Print the layout of each op result.
    funcOp.walk([&](Operation *op) {
      if (op->getResults().empty())
        return;
      // Skip branch and region ops; their layouts are handled via region
      // successors.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      os << "op    : ";
      op->print(os);
      os << "\n";
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    // Collect functions nested inside GPU modules as well.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

/// Update an operation with the layout of its results.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are handled by updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results of the operation.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are only needed for vector and tensor descriptor results.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has users but no layout was assigned, emit a warning.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // For tensor descriptor results, embed the layout into the type.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // Otherwise, attach the layout to the op result.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// those need to be updated with the layout attribute.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only process terminators inside a region branch op.
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();

  llvm::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
                                              nullptr);
  terminator.getSuccessorRegions(operands, successors);

  for (mlir::RegionSuccessor &successor : successors) {
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);

      // If the forwarded operand has no layout assigned, bail out.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
        return failure();
      }
      // The region argument layout and the forwarded operand layout must
      // agree.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // For tensor descriptor types, embed the layout into the type.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // For vector types, attach the layout to the defining op result if the
      // region input is an op result.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
/// Update the function arguments and results with the layouts.
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types; result types are
  // left unchanged.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
void XeGPUPropagateLayoutPass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and return (used for testing).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert the analysis result to a DistributeLayoutAttr.
  auto getXeGPULayoutForValue =
      [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (this->layoutKind == "lane")
      layoutAttr = layoutAttr.dropInstData();
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *top = getOperation();
  auto walkResult = top->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *defaultOp) {
            r = updateOp(builder, defaultOp, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}
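// Usage sketch (assuming the registered pass name matches DEBUG_TYPE and the
// option names used above):
//   mlir-opt --xegpu-propagate-layout input.mlir
//   mlir-opt --xegpu-propagate-layout="print-analysis-only=true" input.mlir
// The second form dumps the per-value layouts via printAnalysisResult
// instead of rewriting types and attributes.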