32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
83 xegpu::DistributeLayoutAttr storage =
nullptr;
86 LayoutInfo() =
default;
87 LayoutInfo(
const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
91 bool operator==(
const LayoutInfo &other)
const {
92 return this->isAssigned() == other.isAssigned();
95 static LayoutInfo meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
97 static LayoutInfo
join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
101 bool isAssigned()
const {
return storage !=
nullptr; }
117 bool isSliceLayout()
const {
120 return isa<xegpu::SliceAttr>(storage);
126 return storage.getRank();
130 void set(
const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](
int64_t val) { return static_cast<int>(val); });
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](
int64_t val) { return static_cast<int>(val); });
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](
int64_t val) { return static_cast<int>(val); });
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](
int64_t val) { return static_cast<int>(val); });
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](
int64_t val) { return static_cast<int>(val); });
169 if (!isAssigned() || !storage.getOrder())
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](
int64_t val) { return static_cast<int>(val); });
179 os <<
"Not assigned.";
183LayoutInfo LayoutInfo::meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
184 if (!
lhs.isAssigned())
190LayoutInfo LayoutInfo::join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
191 llvm_unreachable(
"Join should not be triggered by layout propagation.");
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](
int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
206 if (!withinRange || hasDuplicates) {
207 assert(
false &&
"Invalid permutation for transpose.");
218 for (
int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(
static_cast<int32_t
>(getLaneLayout()[idx]));
221 laneData.push_back(
static_cast<int32_t
>(getLaneData()[idx]));
223 if (getInstData().size())
224 instData.push_back(
static_cast<int32_t
>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(
static_cast<int32_t
>(getSgLayout()[idx]));
227 sgData.push_back(
static_cast<int32_t
>(getSgData()[idx]));
229 if (getOrder().size()) {
230 order.push_back(
static_cast<int32_t
>(getOrder()[idx]));
233 auto orderAttr = order.size()
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
249 return LayoutInfo(layoutAttr);
257struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
272 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
282 unsigned rank,
int subgroupSize) {
283 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
291template <
typename Ty>
292static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
294 unsigned packingSize) {
296 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297 "Expected 1D or 2D vector.");
299 assert(ty.getElementType().isIntOrFloat() &&
300 "Expected int or float element type.");
302 if (ty.getRank() == 1)
303 return getDefaultSIMTLayoutInfo(ty.getContext(), 1,
uArch);
305 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307 return LayoutInfo(xegpu::LayoutAttr::get(
308 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
320class LayoutInfoPropagation
324 unsigned indexBitWidth;
328 void visitStoreNdOp(xegpu::StoreNdOp store,
332 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
336 void visitLoadNdOp(xegpu::LoadNdOp
load,
340 void visitLoadGatherOp(xegpu::LoadGatherOp
load,
344 void visitTransposeOp(vector::TransposeOp transpose,
348 void visitVectorBitcastOp(vector::BitCastOp bitcast,
352 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
356 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
360 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
364 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
368 void visitVectorBroadCastOp(vector::BroadcastOp
broadcast,
371 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
375 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
379 void visitLoadMatrixOp(xegpu::LoadMatrixOp
load,
383 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
387 void visitLoadGatherOp(xegpu::LoadMatrixOp
load,
391 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
395 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
402 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
409 void visitBranchOperand(
OpOperand &operand)
override {};
411 void visitCallOperand(
OpOperand &operand)
override {};
417 void visitExternalCall(CallOpInterface call,
422 void setToExitState(LayoutInfoLattice *lattice)
override {
423 (
void)lattice->meet(LayoutInfo());
428LogicalResult LayoutInfoPropagation::visitOperation(
429 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
430 ArrayRef<const LayoutInfoLattice *> results) {
433 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
434 .Case([&](xegpu::StoreNdOp storeNdOp) {
435 visitStoreNdOp(storeNdOp, operands, results);
437 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
438 visitStoreScatterOp(storeScatterOp, operands, results);
440 .Case([&](xegpu::LoadNdOp loadNdOp) {
441 visitLoadNdOp(loadNdOp, operands, results);
443 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
444 visitLoadGatherOp(loadGatherOp, operands, results);
446 .Case([&](xegpu::CreateDescOp createDescOp) {
447 visitCreateDescOp(createDescOp, operands, results);
449 .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
450 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
452 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
453 visitPrefetchNdOp(prefetchNdOp, operands, results);
455 .Case([&](vector::TransposeOp transposeOp) {
456 visitTransposeOp(transposeOp, operands, results);
458 .Case([&](vector::BitCastOp bitcastOp) {
459 visitVectorBitcastOp(bitcastOp, operands, results);
461 .Case([&](vector::MultiDimReductionOp reductionOp) {
462 visitVectorMultiReductionOp(reductionOp, operands, results);
464 .Case([&](vector::BroadcastOp broadcastOp) {
465 visitVectorBroadCastOp(broadcastOp, operands, results);
467 .Case([&](vector::ShapeCastOp shapeCastOp) {
468 visitShapeCastOp(shapeCastOp, operands, results);
470 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
471 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
473 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
474 visitLoadMatrixOp(loadMatrixOp, operands, results);
476 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
477 visitStoreMatrixOp(storeMatrixOp, operands, results);
480 .Default([&](Operation *op) {
481 for (
const LayoutInfoLattice *resultInfo : results) {
482 if (!resultInfo->getValue().isAssigned())
484 for (
auto [operandInfo, operand] :
488 if (!isa<xegpu::TensorDescType, VectorType>(
489 operand.get().getType()))
492 meet(operandInfo, *resultInfo);
500bool LayoutInfoPropagation::hasParamsOfLayoutKind(
501 xegpu::DistributeLayoutAttr anchorLayout) {
502 if (anchorLayout ==
nullptr) {
505 if (layoutKind == xegpu::LayoutKind::InstData) {
506 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
508 if (layoutKind == xegpu::LayoutKind::Lane) {
509 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
510 anchorLayout.getEffectiveLaneDataAsInt().empty());
512 if (layoutKind == xegpu::LayoutKind::Subgroup) {
513 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
514 anchorLayout.getEffectiveSgDataAsInt().empty());
530 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
531 if (sgCount % sgLayout0)
533 int sgLayout1 = sgCount / sgLayout0;
534 int sgData0 = wgShape[0] / sgLayout0;
535 int sgData1 = wgShape[1] / sgLayout1;
536 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
537 (sgData0 % instData[0] || sgData1 % instData[1]))
539 candidates.emplace_back(sgLayout0, sgLayout1);
544 llvm::sort(candidates, [](
const std::pair<int, int> &
lhs,
545 const std::pair<int, int> &
rhs) {
546 int diffLhs = std::abs(
lhs.first -
lhs.second);
547 int diffRhs = std::abs(
rhs.first -
rhs.second);
548 if (diffLhs != diffRhs)
549 return diffLhs < diffRhs;
550 return lhs.first <
rhs.first;
560 auto knownBlockSize = gpuFunc.getKnownBlockSize();
561 if (!knownBlockSize.has_value())
563 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
564 return flatBlockSize / sgSize;
567void LayoutInfoPropagation::visitPrefetchNdOp(
568 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
569 ArrayRef<const LayoutInfoLattice *> results) {
571 LayoutInfo prefetchLayout;
572 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
573 if (hasParamsOfLayoutKind(anchorLayout)) {
574 prefetchLayout = LayoutInfo(anchorLayout);
578 auto tdescTy = prefetch.getTensorDescType();
583 const auto *uArchInstruction =
584 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
586 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
589 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
591 prefetch.emitWarning(
"No known block params found for the element type.");
592 auto [bWidth, bHeight, bCount] = blockWHC.value();
593 SmallVector<int> instData;
595 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
597 prefetch.emitWarning(
598 "No suitable instruction multiple found for the given shape.");
599 if (tdescTy.getRank() == 1)
600 instData = {instWidth};
603 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
604 if (instHeight == -1)
605 prefetch.emitWarning(
606 "No suitable instruction multiple found for the given shape.");
607 instData = {instHeight, instWidth};
610 if (layoutKind == xegpu::LayoutKind::InstData)
612 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
614 prefetchLayout = getSIMTLayoutInfoBlockIO(
615 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
617 prefetch.setLayoutAttr(
618 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
621 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
624void LayoutInfoPropagation::visitVectorMultiReductionOp(
625 vector::MultiDimReductionOp reduction,
626 ArrayRef<LayoutInfoLattice *> operands,
627 ArrayRef<const LayoutInfoLattice *> results) {
629 LayoutInfo resLayoutInfo = results[0]->getValue();
630 if (!resLayoutInfo.isAssigned())
633 VectorType sourceTy = reduction.getSourceVectorType();
634 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
639 auto consumerLayoutAttr =
640 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
648 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
654 requiredResLayoutAttr, reductionDims);
656 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
658 propagateIfChanged(operands[1],
659 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
662void LayoutInfoPropagation::visitVectorBroadCastOp(
663 vector::BroadcastOp
broadcast, ArrayRef<LayoutInfoLattice *> operands,
664 ArrayRef<const LayoutInfoLattice *> results) {
666 LayoutInfo resLayoutInfo = results[0]->getValue();
667 if (!resLayoutInfo.isAssigned())
671 VectorType resultTy =
broadcast.getResultVectorType();
672 VectorType sourceTy = dyn_cast<VectorType>(
broadcast.getSourceType());
677 auto srcShape = sourceTy.getShape();
678 auto resShape = resultTy.getShape();
680 auto resultLayoutAttr =
681 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
683 xegpu::DistributeLayoutAttr srcLayoutAttr =
686 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
689void LayoutInfoPropagation::visitShapeCastOp(
690 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
691 ArrayRef<const LayoutInfoLattice *> results) {
693 LayoutInfo resLayoutInfo = results[0]->getValue();
694 if (!resLayoutInfo.isAssigned())
696 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
697 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
698 auto resultLayoutAttr =
699 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
701 xegpu::DistributeLayoutAttr srcLayoutAttr =
704 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
709void LayoutInfoPropagation::visitUpdateNdOffsetOp(
710 xegpu::UpdateNdOffsetOp updateNdOffset,
711 ArrayRef<LayoutInfoLattice *> operands,
712 ArrayRef<const LayoutInfoLattice *> results) {
714 LayoutInfo resultLayout = results[0]->getValue();
715 if (!resultLayout.isAssigned())
718 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
722void LayoutInfoPropagation::visitDpasOp(
723 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
724 ArrayRef<const LayoutInfoLattice *> results) {
725 LayoutInfo dpasALayout;
726 LayoutInfo dpasBLayout;
727 LayoutInfo dpasCDLayout;
729 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
730 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
731 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
732 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
733 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
734 "Expected anchor layout for DPAS A operand.");
735 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
736 "Expected anchor layout for DPAS B operand.");
737 dpasALayout = LayoutInfo(anchorLayoutA);
738 dpasBLayout = LayoutInfo(anchorLayoutB);
739 dpasCDLayout = LayoutInfo(anchorLayoutCD);
744 VectorType aTy = dpas.getLhsType();
745 VectorType bTy = dpas.getRhsType();
746 VectorType cdTy = dpas.getResultType();
748 xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
749 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
753 if (layoutKind == xegpu::LayoutKind::Subgroup) {
754 LayoutInfo consumerLayout = results[0]->getValue();
755 if (!consumerLayout.isAssigned())
758 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
762 "Unable to determine the number of subgroups for the operation.");
765 numSg = numSgOrErr.value();
768 consumerLayoutAttr, uArch, numSg);
769 if (!layouts.has_value()) {
771 "Failed to determine required layouts for DPAS operands.");
775 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
777 dpas.setLayoutAAttr(requiredALayout);
778 dpas.setLayoutBAttr(requiredBLayout);
779 dpas.setLayoutCdAttr(requiredCDLayoutAttr);
780 dpasALayout = LayoutInfo(requiredALayout);
781 dpasBLayout = LayoutInfo(requiredBLayout);
782 dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
784 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
785 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
786 if (operands.size() > 2)
787 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
791void LayoutInfoPropagation::visitStoreNdOp(
792 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
793 ArrayRef<const LayoutInfoLattice *> results) {
794 LayoutInfo storeLayout;
795 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
796 if (hasParamsOfLayoutKind(anchorLayout)) {
797 storeLayout = LayoutInfo(anchorLayout);
802 const auto *uArchInstruction =
803 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
805 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
806 VectorType dataTy = store.getValueType();
807 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
808 store.getValueType().getElementType());
810 store.emitWarning(
"No known block params found for the element type.");
811 auto [bWidth, bHeight, bCount] = blockWHC.value();
812 SmallVector<int> instData;
814 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
817 "No suitable instruction multiple found for the given shape.");
818 if (dataTy.getRank() == 1)
819 instData = {instWidth};
822 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
823 if (instHeight == -1)
825 "No suitable instruction multiple found for the given shape.");
826 instData = {instHeight, instWidth};
829 if (layoutKind == xegpu::LayoutKind::InstData)
831 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
832 else if (layoutKind == xegpu::LayoutKind::Lane)
834 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
835 uArchInstruction->getPackedFormatBitSize());
838 auto numSgOrErr =
getNumSg(store, sgSize);
841 "Unable to determine the number of subgroups for the operation.");
845 instData, numSgOrErr.value());
846 if (sgLayouts.empty()) {
848 "Unable to determine suitable subgroup layout for store value.");
851 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
852 SmallVector<int> sgData = {
853 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
854 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
855 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
863 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
867 for (LayoutInfoLattice *operand : operands)
868 propagateIfChanged(operand, operand->meet(storeLayout));
873void LayoutInfoPropagation::visitLoadNdOp(
874 xegpu::LoadNdOp
load, ArrayRef<LayoutInfoLattice *> operands,
875 ArrayRef<const LayoutInfoLattice *> results) {
876 LayoutInfo loadLayout;
877 xegpu::DistributeLayoutAttr anchorLayout =
load.getLayoutAttr();
878 if (hasParamsOfLayoutKind(anchorLayout)) {
879 loadLayout = LayoutInfo(anchorLayout);
882 LayoutInfo valueLayout = results[0]->getValue();
884 if (!valueLayout.isAssigned())
886 loadLayout = valueLayout;
890 if (
auto transpose =
load.getTranspose()) {
891 load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
892 "LayoutInfoPropagation stage.");
893 loadLayout = valueLayout.transpose(transpose.value());
895 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
898 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
903void LayoutInfoPropagation::visitTransposeOp(
904 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
905 ArrayRef<const LayoutInfoLattice *> results) {
907 LayoutInfo resultLayout = results[0]->getValue();
908 if (!resultLayout.isAssigned())
910 auto consumerLayoutAttr =
911 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
913 consumerLayoutAttr, transpose.getPermutation());
915 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
920void LayoutInfoPropagation::visitVectorBitcastOp(
921 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
922 ArrayRef<const LayoutInfoLattice *> results) {
924 LayoutInfo resLayoutInfo = results[0]->getValue();
925 if (!resLayoutInfo.isAssigned())
928 auto srcVecType = bitcast.getSourceVectorType();
929 auto resVecType = bitcast.getResultVectorType();
931 auto consumerLayoutAttr =
932 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
937 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
941 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
942 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
946 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
948 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
951void LayoutInfoPropagation::visitInsertStridedSliceOp(
952 vector::InsertStridedSliceOp insertStridedSlice,
953 ArrayRef<LayoutInfoLattice *> operands,
954 ArrayRef<const LayoutInfoLattice *> results) {
956 LayoutInfo resLayoutInfo = results[0]->getValue();
957 if (!resLayoutInfo.isAssigned())
960 auto srcVecType = insertStridedSlice.getSourceVectorType();
961 auto resVecType = insertStridedSlice.getDestVectorType();
963 auto consumerLayoutAttr =
964 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
971 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
973 requiredResLayoutAttr);
976 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
977 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
978 propagateIfChanged(operands[1],
979 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
984void LayoutInfoPropagation::visitLoadGatherOp(
985 xegpu::LoadGatherOp
load, ArrayRef<LayoutInfoLattice *> operands,
986 ArrayRef<const LayoutInfoLattice *> results) {
987 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
988 xegpu::DistributeLayoutAttr anchorLayoutAttr =
load.getLayoutAttr();
993 VectorType resVecTy =
load.getValueType();
994 int chunkSize =
load.getChunkSize().value_or(1);
996 LayoutInfo resLayoutInfo = results[0]->getValue();
997 if (!resLayoutInfo.isAssigned())
999 auto consumerLayoutAttr =
1000 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1002 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1003 requiredAnchorLayoutAttr = anchorLayoutAttr;
1006 load.emitWarning(
"Not propagating, non-vector payload supplied.");
1010 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1011 load.setLayoutAttr(requiredAnchorLayoutAttr);
1014 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1017 if (chunkSize > 1) {
1018 if (layoutKind == xegpu::LayoutKind::InstData)
1020 xegpu::LayoutAttr::get(
load->getContext(), {subgroupSize});
1021 else if (layoutKind == xegpu::LayoutKind::Lane)
1023 xegpu::LayoutAttr::get(
load->getContext(), {subgroupSize}, {1});
1026 "chunked StoreScatterOp should not be used at workgroup level");
1029 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1030 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1033 if (isa<xegpu::TensorDescType>(
load.getSourceType()))
1034 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1036 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1037 if (
load.getOffsets())
1038 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1043void LayoutInfoPropagation::visitCreateDescOp(
1044 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1045 ArrayRef<const LayoutInfoLattice *> results) {
1046 LayoutInfo descLayout = results[0]->getValue();
1048 if (!descLayout.isAssigned())
1054 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1056 propagateIfChanged(operands[1], operands[1]->meet(layout));
1061void LayoutInfoPropagation::visitStoreScatterOp(
1062 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1063 ArrayRef<const LayoutInfoLattice *> results) {
1065 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1066 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1071 VectorType srcVecTy = storeScatter.getValueType();
1072 int chunkSize = storeScatter.getChunkSize().value_or(1);
1074 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1075 requiredAnchorLayoutAttr = anchorLayoutAttr;
1078 storeScatter.emitWarning(
"Not propagating, non-vector payload supplied.");
1082 layoutKind, srcVecTy, chunkSize, uArch);
1083 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1086 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1087 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1090 if (chunkSize > 1) {
1091 if (layoutKind == xegpu::LayoutKind::InstData)
1093 xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
1094 else if (layoutKind == xegpu::LayoutKind::Lane)
1095 maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
1096 {subgroupSize}, {1});
1099 "chunked StoreScatterOp should not be used at workgroup level");
1102 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1105 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1107 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1108 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1110 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1111 if (storeScatter.getOffsets())
1112 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1115void LayoutInfoPropagation::visitLoadMatrixOp(
1116 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1117 ArrayRef<const LayoutInfoLattice *> results) {
1119 LayoutInfo resLayoutInfo = results[0]->getValue();
1120 if (!resLayoutInfo.isAssigned())
1123 auto consumerLayoutAttr =
1124 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1126 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1130 if (!hasParamsOfLayoutKind(anchorLayout)) {
1131 VectorType resVecTy =
1132 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1137 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1138 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1143void LayoutInfoPropagation::visitStoreMatrixOp(
1144 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1145 ArrayRef<const LayoutInfoLattice *> results) {
1146 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1148 if (hasParamsOfLayoutKind(anchorLayout)) {
1149 layout = LayoutInfo(anchorLayout);
1151 VectorType srcVecTy =
1152 llvm::cast<VectorType>(storeMatrix.getData().getType());
1156 auto requiredAnchorLayoutAttr =
1158 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1159 layout = LayoutInfo(requiredAnchorLayoutAttr);
1162 propagateIfChanged(operands[0], operands[0]->meet(layout));
1171class RunLayoutInfoPropagation {
1176 unsigned indexBitWidth)
1178 SymbolTableCollection symbolTable;
1180 solver.
load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
1184 LayoutInfo getLayoutInfo(Value val);
1186 void printAnalysisResult(llvm::raw_ostream &os);
1189 DataFlowSolver solver;
1194LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1195 auto *state = solver.
lookupState<LayoutInfoLattice>(val);
1198 return state->getValue();
1202void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1203 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1204 os <<
"function: " << funcOp.getName() <<
":\n";
1206 for (BlockArgument arg : funcOp.getArguments()) {
1207 LayoutInfo layout = getLayoutInfo(arg);
1208 os <<
"argument: " << arg <<
"\n";
1214 funcOp.walk([&](Operation *op) {
1220 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1226 for (
auto [i, r] : llvm::enumerate(op->
getResults())) {
1227 LayoutInfo layout = getLayoutInfo(r);
1228 os <<
"layout for result #" << i <<
": ";
1235 SmallVector<FunctionOpInterface> funcOps;
1236 if (
auto modOp = dyn_cast<ModuleOp>(
target)) {
1237 for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
1238 funcOps.push_back(funcOp);
1241 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1242 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1243 funcOps.push_back(gpuFuncOp);
1247 for (FunctionOpInterface funcOp : funcOps)
1248 printFunctionResult(funcOp);
1260static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1262 auto definingOp = tdescValue.
getDefiningOp<xegpu::CreateNdDescOp>();
1267 if (
auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1268 auto *parentOp = arg.getOwner()->getParentOp();
1269 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1270 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1272 return getDefiningCreateNdDescOp(tiedInit->
get());
1279struct ResolveLayoutConflicts {
1280 ResolveLayoutConflicts(Operation *parentOp)
1281 : parentOp(parentOp), builder(parentOp->
getContext()) {}
1282 LogicalResult run();
1285 Operation *parentOp;
1287 LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1288 LogicalResult resolveVectorConsumer(OpOperand &operand);
1293LogicalResult ResolveLayoutConflicts::run() {
1296 auto r = parentOp->
walk([&](Operation *op) -> WalkResult {
1299 Type operandType = operand.get().getType();
1300 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1301 isa<xegpu::TensorDescType>(operandType)) {
1302 auto res = resolveTensorDescConsumer(operand);
1304 DBGS() <<
"Failed to resolve tensor descriptor consumer: " << *op
1310 if (isa<VectorType>(operandType)) {
1311 auto res = resolveVectorConsumer(operand);
1313 DBGS() <<
"Failed to resolve vector consumer: " << *op <<
"\n";
1321 return r.wasInterrupted() ? failure() :
success();
1325ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1326 Value vectorValue = operand.
get();
1327 Operation *consumerOp = operand.
getOwner();
1330 if (!producerLayout) {
1331 if (
auto vectorTy = dyn_cast<VectorType>(vectorValue.
getType());
1332 vectorTy && vectorTy.getRank() > 1)
1333 consumerOp->
emitWarning(
"Expected layout for non-1D vectors.");
1338 if (!consumerLayout)
1340 "No consumer layout found for vector operand.");
1343 if (consumerLayout.isEqualTo(producerLayout))
1348 auto convertOp = xegpu::ConvertLayoutOp::create(
1349 builder, consumerOp->
getLoc(), vectorValue.
getType(), vectorValue,
1350 producerLayout, consumerLayout);
1353 operand.
set(convertOp.getResult());
1358ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1359 Operation *consumerOp = operand.
getOwner();
1360 Value tdescValue = operand.
get();
1361 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1362 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.
getType());
1363 assert(anchorOp && currTDescType &&
1364 "Expected anchor layout op and tensor descriptor consumer.");
1366 if (currTDescType.isScattered()) {
1367 DBGS() <<
"Scattered tensor descriptor not supported: " << tdescValue
1371 Attribute currLayout = currTDescType.getLayout();
1372 Attribute expectedLayout = anchorOp.getAnchorLayout();
1375 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1377 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1378 if (!conflictingCreateNdOp) {
1379 DBGS() <<
"Unable to find defining CreateNdDescOp for tensor descriptor: "
1380 << tdescValue <<
"\n";
1385 auto newTensorDescType = xegpu::TensorDescType::get(
1386 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1387 currTDescType.getElementType(), currTDescType.getEncoding(),
1389 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1390 builder, consumerOp->
getLoc(), newTensorDescType,
1391 conflictingCreateNdOp->getOperands(),
1392 conflictingCreateNdOp->getAttrs());
1410 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1417 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1420 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(
result);
1421 if (!layout &&
result.getNumUses() > 0) {
1422 op->
emitWarning(
"op has users but no layout assigned for its result");
1427 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1428 auto typeWithLayout = xegpu::TensorDescType::get(
1429 tensorDescTy.getContext(), tensorDescTy.getShape(),
1430 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1431 result.setType(typeWithLayout);
1465 mlir::RegionBranchTerminatorOpInterface terminator,
1468 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1473 branchOp.getSuccessorOperandInputMapping(mapping,
1475 for (
const auto &[successorOperand, successorInputs] : mapping) {
1476 for (
Value successorInput : successorInputs) {
1477 Type inputType = successorInput.getType();
1479 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1481 xegpu::DistributeLayoutAttr successorInputLayout =
1482 getLayoutOfValue(successorInput);
1483 xegpu::DistributeLayoutAttr successorOperandLayout =
1484 getLayoutOfValue(successorOperand->get());
1487 if (!successorOperandLayout) {
1488 LLVM_DEBUG(
DBGS() <<
"No layout assigned for forwarded operand in "
1489 "branch terminator: "
1490 << successorOperand->get() <<
"\n");
1494 if (successorInputLayout &&
1495 successorInputLayout != successorOperandLayout) {
1496 LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
1497 "operand forwarded as the argument: "
1498 << successorInputLayout <<
" vs "
1499 << successorOperandLayout <<
"\n");
1503 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1504 auto newTdescTy = xegpu::TensorDescType::get(
1505 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1506 tdescTy.getEncoding(), successorOperandLayout);
1507 successorInput.setType(newTdescTy);
1512 if (
auto result = dyn_cast<OpResult>(successorInput))
1521 mlir::FunctionOpInterface funcOp,
1527 if (!isa<FunctionType>(funcOp.getFunctionType()))
1532 Type argType = arg.getType();
1533 newArgTypes.push_back(argType);
1534 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1536 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1538 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
1539 <<
" but got none.\n");
1542 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1543 auto newTdescTy = xegpu::TensorDescType::get(
1544 tensorDescTy.getContext(), tensorDescTy.getShape(),
1545 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1546 arg.setType(newTdescTy);
1547 newArgTypes.back() = newTdescTy;
1552 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1553 funcOp.getResultTypes()));
1558struct XeGPUPropagateLayoutPass final
1559 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1560 XeGPUPropagateLayoutPass() =
default;
1561 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
1562 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
1563 : XeGPUPropagateLayoutBase(std::move(
options)) {}
1564 void runOnOperation()
override;
1571 unsigned indexBitWidth,
bool printOnly) {
1572 RunLayoutInfoPropagation analysis(
target, layoutKind, indexBitWidth);
1575 auto &os = llvm::outs();
1576 analysis.printAnalysisResult(os);
1580 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::DistributeLayoutAttr {
1581 LayoutInfo layout = analysis.getLayoutInfo(val);
1582 if (!layout.isAssigned())
1584 if (
auto opResult = dyn_cast<OpResult>(val)) {
1586 Operation *defOp = opResult.getDefiningOp();
1587 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1588 auto anchorLayout = anchorOp.getAnchorLayout();
1589 if (anchorLayout !=
nullptr)
1590 return anchorLayout;
1592 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1594 if (requiredResLayoutAttr !=
nullptr)
1595 return requiredResLayoutAttr;
1597 xegpu::DistributeLayoutAttr layoutAttr =
1598 cast<xegpu::DistributeLayoutAttr>(layout.get());
1599 if (layout.isSliceLayout())
1600 return cast<xegpu::SliceAttr>(layoutAttr);
1602 return cast<xegpu::LayoutAttr>(layoutAttr);
1610 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1612 getXeGPULayoutForValue);
1614 .Case([&](mlir::FunctionOpInterface funcOp) {
1616 getXeGPULayoutForValue);
1619 r =
updateOp(builder, op, getXeGPULayoutForValue);
1622 op.
emitError(
"Failed to update operation with the layout.");
1628 if (walkResult.wasInterrupted())
1635 ResolveLayoutConflicts resolver(
target);
1636 return resolver.run();
1639void XeGPUPropagateLayoutPass::runOnOperation() {
1641 if (this->layoutKind ==
"lane") {
1643 }
else if (this->layoutKind ==
"inst") {
1645 }
else if (this->layoutKind ==
"subgroup") {
1646 layoutKind = xegpu::LayoutKind::Subgroup;
1648 getOperation()->emitError(
"Unsupported layout kind option: " +
1650 signalPassFailure();
1655 this->indexBitWidth, this->printOnly))) {
1656 signalPassFailure();
1661 signalPassFailure();
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const