#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Support/FormatVariadic.h"
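
// Utilities for assigning, propagating, and recovering XeGPU
// DistributeLayoutAttr annotations on vector- and tensor_desc-typed values.
// (Excerpted listing: elided lines are marked with "// ...".)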
  out.reserve(attrs.size());
  for (auto attr : attrs) {
    if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
      auto newLayout = dist.dropSgLayoutAndData();
      out.emplace_back(attr.getName(), newLayout);
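// ...

// dropInstDataOnAttrs: same pattern as above, but strips the inst_data
// component from any DistributeLayoutAttr in the attribute list.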
  out.reserve(attrs.size());
  for (auto attr : attrs) {
    if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
      auto newLayout = dist.dropInstData();
      out.emplace_back(attr.getName(), newLayout);
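// ...

// setTensorDescLayout: rebuilds a TensorDescType value with `layout`
// attached, unless the value is not a tensor_desc or already carries one.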
  auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(val.getType());
  if (!tensorDescTy || tensorDescTy.getLayoutAttr())
    return;
  auto typeWithLayout = xegpu::TensorDescType::get(
      tensorDescTy.getContext(), tensorDescTy.getShape(),
      tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
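// ...

// walkRegionBackward: visits the operations of a region bottom-up (reverse
// post-order over blocks, reverse order within each block), recursing into
// nested regions, so consumers are visited before their producers.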
  llvm::ReversePostOrderTraversal<Region *> rpot(&region);
  // ...
  for (Block *block : llvm::reverse(blocks)) {
    for (Operation &op : llvm::reverse(*block)) {
      // ...
      for (Region &nested : op.getRegions())
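// ...

// propagateResultsToRegularOperands: mirrors the layout recovered for an
// op's result onto its vector-typed operands. Multi-result ops other than
// vector.deinterleave are skipped.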
  xegpu::DistributeLayoutAttr layout = nullptr;
  // ...
  if (op->getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
    return;
  // ...
  if (isa<xegpu::TensorDescType>(resultType))
    // ...
  if (isa<VectorType>(resultType) || isa<vector::MultiDimReductionOp>(op))
    // ...
  xegpu::DistributeLayoutAttr operandLayout = /* ... */;
  if (isa<VectorType>(opr.get().getType()) && operandLayout)
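// ...

// propagateRegionResultsToYieldOperands: for a region terminator (e.g.
// scf.yield), copies the layout attached to each successor input (the
// enclosing op's results or the successor region's arguments) back onto the
// corresponding yielded operand.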
static void propagateRegionResultsToYieldOperands(
    mlir::RegionBranchTerminatorOpInterface yieldOp) {
  auto regionBranchOp =
      dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
  // ...
  yieldOp.getSuccessorRegions(operandAttrs, successors);
  // ...
  OperandRange succOps = yieldOp.getSuccessorOperands(successor);
  // ...
  ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
  unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
  for (unsigned i = 0; i < count; ++i) {
    xegpu::DistributeLayoutAttr layout;
    if (successor.isParent()) {
      auto regionResult = regionBranchOp->getResult(i);
      // ...
      if (isa<xegpu::TensorDescType>(regionResult.getType()))
        // ...
    }
    // ...
    auto operandType = succOps[i].getType();
    if (isa<VectorType>(operandType) ||
        isa<xegpu::TensorDescType>(operandType))
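// ...

// propagateRegionArgsToInits: walks each region of a RegionBranchOpInterface
// op and copies the layout of every region argument back to the operation
// operand (init value) that feeds it.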
  for (Region &region : regionOp->getRegions()) {
    // ...
    ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
    for (auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
      // ...
      if (isa<xegpu::TensorDescType>(regionArg.getType()))
        // ...
      regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
      for (Value predVal : predValues) {
        for (OpOperand &operand : regionOp->getOpOperands()) {
          if (operand.get() == predVal)
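// ...

// recoverTemporaryLayouts: walks every func.func and gpu.func under rootOp
// backward, dispatching on whether the visited op is a region-branch op, a
// region terminator, or a regular non-anchor op.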
  auto processFunc = [&](Region &body, StringRef funcName) {
    walkRegionBackward(body, [&](Operation *op) {
      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
        // ...
      } else if (auto yieldOp =
                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
        // ...
      } else if (!isa<xegpu::AnchorLayoutInterface>(op)) {
        // ...
      }
    });
  };

  rootOp->walk([&](func::FuncOp func) {
    processFunc(func.getBody(), func.getSymName());
  });
  rootOp->walk([&](gpu::GPUFuncOp func) {
    processFunc(func.getBody(), func.getName());
  });
// ...

template <typename T, typename>
void setTemporaryLayout(const T &operandOrResult,
                        const DistributeLayoutAttr layout) {
  Operation *owner = operandOrResult.getOwner();
// ...
  for (auto namedAttr : nestOp->getAttrs()) {
    if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
      attrsToRemove.push_back(namedAttr.getName());
  }
  for (auto attrName : attrsToRemove)
    // ...

// ...
  if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
    attrsToRemove.push_back(namedAttr.getName());
  // ...
  for (auto attrName : attrsToRemove)
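// ...

// inferBroadcastSourceLayout: derives the broadcast source's layout from the
// result layout. Source dims of size 1 that were expanded get unit data
// sizes, and leading result dims absent from the source are sliced away.
// E.g. (illustrative): broadcasting vector<1x16xf32> to vector<8x16xf32>
// forces unit data on dim 0; broadcasting vector<16xf32> slices dim 0 away.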
xegpu::DistributeLayoutAttr
xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                  ArrayRef<int64_t> resShape,
                                  ArrayRef<int64_t> srcShape) {
  SmallVector<int64_t> bcastDims;
  size_t dimDiff = resShape.size() - srcShape.size();
  auto bcastSourceLayout = resLayout;
  for (size_t i = dimDiff; i < resShape.size(); i++) {
    if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
      bcastDims.push_back(i);
  }
  if (!bcastDims.empty())
    bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
  // ...
  SmallVector<int64_t> sliceDims;
  for (size_t i = 0; i < dimDiff; i++)
    sliceDims.push_back(i);
  bcastSourceLayout = xegpu::SliceAttr::get(
      resLayout.getContext(), bcastSourceLayout,
      DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
  return bcastSourceLayout;
}
// ...

xegpu::DistributeLayoutAttr
xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                       SmallVector<int64_t> reduceDims) {
  assert(isa<xegpu::SliceAttr>(resLayout) &&
         "reduction result layout must be a slice layout");
  xegpu::SliceAttr sliceLayout = cast<xegpu::SliceAttr>(resLayout);
  assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
         "reduction dims must match the slice dims");
  return sliceLayout.getParent();
}
// ...

xegpu::DistributeLayoutAttr
xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
  // ...
}

xegpu::DistributeLayoutAttr
xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                  ArrayRef<int64_t> permutation) {
  return resLayout.transposeDims(permutation);
}
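// ...

// inferBitCastSourceLayout: a bitcast rescales the innermost dimension by
// the ratio of element bitwidths, so the source layout's innermost
// sg/inst/lane data are scaled by the same ratio. E.g. (illustrative): a
// result vector<16xi32> with laneData 1 maps to laneData 2 on a
// vector<32xi16> source.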
xegpu::DistributeLayoutAttr
xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                int resElemTyBitWidth,
                                int srcElemTyBitWidth) {
  // ...
  size_t sgDataSize = sgData.size();
  size_t instDataSize = instData.size();
  size_t laneDataSize = laneData.size();
  // ...
  int64_t dim = resLayout.getRank() - 1;
  // ...
  if (srcElemTyBitWidth <= resElemTyBitWidth) {
    int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
    // ...
    sgDataValue = sgData.back() * bitWidthRatio;
    // ...
    instDataValue = instData.back() * bitWidthRatio;
    // ...
    laneDataValue = laneData.back() * bitWidthRatio;
  } else {
    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
    // ...
    assert((sgData.back() % bitWidthRatio) == 0 &&
           "sgData not divisible by bitWidthRatio");
    sgDataValue = sgData.back() / bitWidthRatio;
    // ...
    assert((instData.back() % bitWidthRatio) == 0 &&
           "instData not divisible by bitWidthRatio");
    instDataValue = instData.back() / bitWidthRatio;
    // ...
    assert((laneData.back() % bitWidthRatio) == 0 &&
           "laneData not divisible by bitWidthRatio");
    laneDataValue = laneData.back() / bitWidthRatio;
  }
  // ...
  xegpu::DistributeLayoutAttr finalSrcLayout;
  // ...
  finalSrcLayout =
      resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
  // ...
  return finalSrcLayout;
}
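// ...

// inferInterleaveSourceLayout / inferDeinterleaveSourceLayout:
// vector.interleave doubles the innermost dimension, so each source carries
// half of the result's innermost data sizes; deinterleave is the inverse
// and doubles them.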
xegpu::DistributeLayoutAttr
xegpu::inferInterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
  // ...
  size_t sgDataSize = sgData.size();
  size_t instDataSize = instData.size();
  size_t laneDataSize = laneData.size();
  // ...
  int64_t dim = resLayout.getRank() - 1;
  // ...
  constexpr int ratio = 2;
  // ...
  assert((sgData.back() % ratio) == 0 &&
         "sgData not divisible by interleave ratio");
  sgDataValue = sgData.back() / ratio;
  // ...
  assert((instData.back() % ratio) == 0 &&
         "instData not divisible by interleave ratio");
  instDataValue = instData.back() / ratio;
  // ...
  assert((laneData.back() % ratio) == 0 &&
         "laneData not divisible by interleave ratio");
  laneDataValue = laneData.back() / ratio;
  // ...
  return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
}
// ...

xegpu::DistributeLayoutAttr
xegpu::inferDeinterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
  // ...
  size_t sgDataSize = sgData.size();
  size_t instDataSize = instData.size();
  size_t laneDataSize = laneData.size();
  // ...
  int64_t dim = resLayout.getRank() - 1;
  // ...
  constexpr int ratio = 2;
  // ...
  sgDataValue = sgData.back() * ratio;
  // ...
  instDataValue = instData.back() * ratio;
  // ...
  laneDataValue = laneData.back() * ratio;
  // ...
  return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
}
// ...
  int srcShapeSize = srcShape.size();
  int resShapeSize = resShape.size();
  int dimDiff = resShapeSize - srcShapeSize;
  // ...
  auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
  for (int i = 0; i < dimDiff; i++) {
    assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
           (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
           "Leading dimensions being sliced off must not be distributed");
  }
  return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
}
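// ...

// The same logic serves inferInsertSourceLayout and
// inferInsertStridedSliceSourceLayout: the inserted source has at most the
// result's rank, and the extra leading result dims must be undistributed so
// they can simply be dropped from the layout.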
xegpu::DistributeLayoutAttr
// ...
  int srcShapeSize = srcShape.size();
  int resShapeSize = resShape.size();
  int dimDiff = resShapeSize - srcShapeSize;
  // ...
  auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
  for (int i = 0; i < dimDiff; i++) {
    assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
           (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
           "Leading dimensions being sliced off must not be distributed");
  }
  return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
}
// ...

xegpu::DistributeLayoutAttr
xegpu::inferExtractSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                ArrayRef<int64_t> resShape,
                                ArrayRef<int64_t> srcShape) {
  int srcShapeSize = srcShape.size();
  int resShapeSize = resShape.size();
  int dimDiff = srcShapeSize - resShapeSize;
  auto context = resLayout.getContext();
  // ...
  auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
  auto sgData = resLayout.getEffectiveSgDataAsInt();
  auto instData = resLayout.getEffectiveInstDataAsInt();
  auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
  auto laneData = resLayout.getEffectiveLaneDataAsInt();
  auto order = resLayout.getEffectiveOrderAsInt();
  // ...
  for (auto &o : order)
    // ...
  for (int i = 0; i < dimDiff; i++) {
    if (!sgLayout.empty())
      sgLayout.insert(sgLayout.begin(), 1);
    if (!sgData.empty())
      sgData.insert(sgData.begin(), 1);
    if (!instData.empty())
      instData.insert(instData.begin(), 1);
    if (!laneLayout.empty())
      laneLayout.insert(laneLayout.begin(), 1);
    if (!laneData.empty())
      laneData.insert(laneData.begin(), 1);
    order.push_back(dimDiff - 1 - i);
  }
  // ...
  auto srcLayout = xegpu::LayoutAttr::get(
      context, sgLayout.empty() ? nullptr : toAttr(sgLayout),
      sgData.empty() ? nullptr : toAttr(sgData),
      instData.empty() ? nullptr : toAttr(instData),
      laneLayout.empty() ? nullptr : toAttr(laneLayout),
      laneData.empty() ? nullptr : toAttr(laneData),
      (!orderAttr || orderAttr.empty()) ? nullptr : toAttr(order));
// ...

xegpu::DistributeLayoutAttr
xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                  ArrayRef<int64_t> resShape,
                                  ArrayRef<int64_t> srcShape) {
  // ... (unit-dim expansion path)
  return
      xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
  // ... (split-dim expansion path)
  auto srcLayout = resLayout;
  for (const auto &dimGroup : splitDimGroups)
    srcLayout = srcLayout.collapseDims(dimGroup);
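// ...

// matchCollapseToInnermostDim (local helper; reconstructed below as a
// lambda): true when `src` collapses entirely into the innermost dim of a
// 1-D or 2-D `dst`, i.e. dst == [prod(src)] or dst == [1, prod(src)].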
  auto matchCollapseToInnermostDim = [](ArrayRef<int64_t> src,
                                        ArrayRef<int64_t> dst) {
    if ((dst.size() != 2) && (dst.size() != 1))
      return false;
    int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
                                      std::multiplies<int64_t>());
    if (dst.size() == 1)
      return (dst[0] == srcSize);
    return (dst[0] == 1) && (dst[1] == srcSize);
  };
  if (matchCollapseToInnermostDim(srcShape, resShape)) {
    int srcShapeSize = srcShape.size();
    int resShapeSize = resShape.size();
    auto context = resLayout.getContext();
    auto resInstData = resLayout.getEffectiveInstDataAsInt();
    auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
    auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
    // ...
    if (resInstData.size() != 0) {
      SmallVector<int64_t> inferredInstData(srcShapeSize, 1);
      for (int i = 0; i < resShapeSize - 1; i++) {
        assert(resInstData[i] == 1 &&
               "only innermost dim can have non-unit instData");
      }
      inferredInstData[srcShapeSize - 1] =
          std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
      return xegpu::LayoutAttr::get(context, inferredInstData);
    }
    if (resLaneLayout.size() != 0) {
      for (int i = 0; i < resShapeSize - 1; i++) {
        assert(resLaneData[i] == 1 &&
               "only innermost dim can have non-unit laneData");
      }
      assert(srcShape.back() % resLaneLayout.back() == 0 &&
             "source innermost dim must be divisible by result laneLayout");
      SmallVector<int64_t> inferredLaneLayout(srcShapeSize, 1);
      SmallVector<int64_t> inferredLaneData(srcShapeSize, 1);
      inferredLaneLayout.back() = resLaneLayout.back();
      inferredLaneData.back() = std::min(
          resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
      return xegpu::LayoutAttr::get(context, inferredLaneLayout,
                                    inferredLaneData);
    }
  }
  llvm_unreachable("unsupported shape cast scenario");
xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
    xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
  auto rank = payloadLayout.getRank();
  if (chunkSize > 1)
    return payloadLayout.dropDims(
        llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
  return payloadLayout;
}
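// ...

// setupMultiReductionResultLayout: builds a SliceAttr layout for
// vector.multi_reduction. If the consumer already uses a slice layout over
// the same reduction dims, its parent layout is reused with the reduced dims
// redistributed; otherwise a fresh layout is constructed per layout kind
// (workgroup sg_layout/sg_data, inst_data, or lane_layout/lane_data).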
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
    xegpu::LayoutKind layoutKind, VectorType srcVecTy,
    xegpu::DistributeLayoutAttr consumerLayout,
    SmallVector<int64_t> reductionDims, int numSg,
    const xegpu::uArch::uArch *uArch) {
  auto srcShape = srcVecTy.getShape();
  int srcRank = srcShape.size();
  auto context = srcVecTy.getContext();
  // ...
  const int subgroupSize = uArch->getSubgroupSize();
  int64_t maxReduceVectorSize = 1;
  xegpu::DistributeLayoutAttr srcLayout;
  // ...
  xegpu::SliceAttr consumerSliceLayout =
      dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
  if (consumerSliceLayout &&
      consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
    srcLayout = consumerSliceLayout.getParent();
    // ...
        srcLayout.getEffectiveSgLayoutAsInt();
    // ...
    for (int dim = 0; dim < srcRank; dim++) {
      if (llvm::is_contained(reductionDims, dim))
        srcLayout =
            srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
    }
    // ...
  }
  auto consumerSgLayout =
      consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
                     : /* ... */;
  auto consumerSgData =
      consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
                     : /* ... */;
  auto consumerOrder =
      consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
                     : /* ... */;
  auto orderAttr = consumerLayout ? consumerLayout.getOrder() : nullptr;
  // ...
  int remainingSgCount =
      consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
  // ...
  for (int i = 0; i < srcRank; i++) {
    if (!llvm::is_contained(reductionDims, i) &&
        consumerIdx < static_cast<int>(consumerSgLayout.size())) {
      sgLayout[i] = consumerSgLayout[consumerIdx];
      sgData[i] = consumerSgData[consumerIdx];
      remainingSgCount /= sgLayout[i];
      order[i] = consumerOrder[consumerIdx];
      // ...
    }
  }
  // ...
  int64_t remainOrder = consumerSgLayout.size();
  for (int i = 0; i < srcRank; i++) {
    if (llvm::is_contained(reductionDims, i)) {
      sgLayout[i] =
          std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
      assert((srcShape[i] % sgLayout[i] == 0) &&
             "source shape not divisible by sg_layout");
      sgData[i] = srcShape[i] / sgLayout[i];
      remainingSgCount /= sgLayout[i];
      order[i] = remainOrder++;
    }
  }
  // ...
  assert(remainingSgCount == 1 && "not all subgroups distributed");
  srcLayout = xegpu::LayoutAttr::get(
      context, toInt32Attr(sgLayout), toInt32Attr(sgData),
      // ...
      (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
  // ...
  instData[srcRank - 2] = std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
  instData[srcRank - 1] =
      std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
  srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
  // ...
  laneLayout[srcRank - 1] =
      std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
  // ...
  laneData[srcRank - 2] = std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
  srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
                                     toInt32Attr(laneData));
  // ...
  return xegpu::SliceAttr::get(context, srcLayout,
                               DenseI64ArrayAttr::get(context, reductionDims));
}
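// ...

// setupReductionResultLayout: 1-D vector.reduction is only expected after
// workgroup and inst-data distribution, so only the lane level is populated
// and the result is a SliceAttr over the single dimension.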
xegpu::SliceAttr
xegpu::setupReductionResultLayout(xegpu::LayoutKind layoutKind,
                                  VectorType srcVecTy,
                                  const xegpu::uArch::uArch *uArch) {
  auto srcShape = srcVecTy.getShape();
  auto context = srcVecTy.getContext();
  auto subgroupSize = uArch->getSubgroupSize();
  xegpu::LayoutAttr srcLayout;
  // ...
  assert(false && "subgroup layout assignment not supported for reduction (op "
                  "is not expected at this level).");
  // ...
  assert(false && "instData layout assignment not supported for reduction (op "
                  "is not expected at this level).");
  // ...
  laneLayout[0] = std::min(subgroupSize, static_cast<int32_t>(srcShape[0]));
  // ...
  srcLayout = xegpu::LayoutAttr::get(/* ... */);
  // ...
  auto result = xegpu::SliceAttr::get(context, srcLayout, /* ... */);
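// ...

// setupBitCastResultLayout: picks a result layout whose innermost data sizes
// survive the bitcast. For a narrowing cast (wider source elements), the
// innermost inst_data is grown (in the elided loop bodies) until it is a
// multiple of subgroupSize * bitWidthRatio, and lane_data until it is a
// multiple of bitWidthRatio, while still dividing the source shape.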
xegpu::DistributeLayoutAttr
xegpu::setupBitCastResultLayout(xegpu::LayoutKind layoutKind,
                                VectorType srcVecTy, VectorType resVecTy,
                                xegpu::DistributeLayoutAttr consumerLayout,
                                const xegpu::uArch::uArch *uArch) {
  // ...
  int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
  // ...
  assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
         "laneData must be available for all dimensions");
  size_t dim = srcShape.size() - 1;
  // ...
  const int subgroupSize = uArch->getSubgroupSize();
  // ...
  if (srcElemTyBitWidth > resElemTyBitWidth) {
    // ...
    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
    int innermostDimLaneLayout = subgroupSize;
    // ...
    sgDataValue = sgData[dim];
    // ...
    instDataValue = instData[dim];
    // ...
    while ((instDataValue <= srcShape[dim]) &&
           (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
      // ...
    assert((srcShape[dim] % instDataValue) == 0 &&
           "srcShape, instData, and laneLayout for innermost must be 2^n!");
    // ...
    laneDataValue = laneData[dim];
    while ((laneDataValue <= srcShape[dim]) &&
           (laneDataValue % bitWidthRatio != 0))
      // ...
    xegpu::DistributeLayoutAttr resLayout;
    resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
                                          laneDataValue);
    return resLayout;
  }
  return consumerLayout;
}
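// ...

// setupInterleaveResultLayout: vector.interleave halves the per-operand
// element count relative to its result, so the consumer layout's innermost
// sg/inst/lane data are rounded up to even multiples (a multiple of
// laneLayout * 2 for inst_data) that still fit the source shape.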
xegpu::DistributeLayoutAttr
xegpu::setupInterleaveResultLayout(xegpu::LayoutKind layoutKind,
                                   VectorType srcVectorTy,
                                   VectorType resVectorTy,
                                   xegpu::DistributeLayoutAttr consumerLayout,
                                   const xegpu::uArch::uArch *uArch) {
  // ...
  assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
         "consumer layout rank must match source shape rank");
  const size_t innerMostDim = srcShape.size() - 1;
  // ...
  constexpr int ratio = 2;
  int innermostDimLaneLayout = uArch->getSubgroupSize();
  // ...
  sgDataValue = sgData[innerMostDim];
  // ...
  while ((sgDataValue <= srcShape[innerMostDim]) && (sgDataValue % ratio != 0))
    sgDataValue *= ratio;
  // ...
  instDataValue = instData[innerMostDim];
  // ...
  while ((instDataValue <= srcShape[innerMostDim]) &&
         (instDataValue % (innermostDimLaneLayout * ratio) != 0))
    instDataValue *= ratio;
  assert((srcShape[innerMostDim] % instDataValue) == 0 &&
         "srcShape, instData, and laneLayout for innermost must be 2^n!");
  // ...
  laneDataValue = laneData[innerMostDim];
  // ...
  while ((laneDataValue <= srcShape[innerMostDim]) &&
         (laneDataValue % ratio != 0))
    laneDataValue *= ratio;
  // ...
  return consumerLayout.setDimData(innerMostDim, sgDataValue, instDataValue,
                                   laneDataValue);
}
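// ...

// setupInsertStridedSliceResultLayout: clamps the consumer's inst_data and
// lane_data to sizes the inserted source shape can actually hold in every
// dimension.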
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
    xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
    VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
    const xegpu::uArch::uArch *uArch) {
  // ...
  xegpu::DistributeLayoutAttr requiredResLayout;
  auto consumerInstData = consumerLayout.getEffectiveInstDataAsInt();
  auto consumerLaneData = consumerLayout.getEffectiveLaneDataAsInt();
  auto consumerLaneLayout = consumerLayout.getEffectiveLaneLayoutAsInt();
  // ...
  requiredResLayout = consumerLayout;
  int srcRank = srcShape.size();
  // ...
  assert(false &&
         "subgroup layout assignment not supported for insertStridedSlice.");
  // ...
  for (int dim = 0; dim < srcRank; dim++) {
    instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
    requiredResLayout =
        requiredResLayout.setDimData(dim, -1, instDataValue, -1);
  }
  // ...
  for (int dim = 0; dim < srcRank; dim++) {
    assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
           "srcShape must be divisible by laneLayout for all dimensions");
    laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
                             consumerLaneData[dim]);
    requiredResLayout =
        requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
  }
  // ...
  return requiredResLayout;
}
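// ...

// setupGenericLoadAnchorLayout: shared anchor logic for load_gather and
// load_matrix. Flat loads distribute the innermost dim (up to
// maxChunkSize * subgroupSize elements of inst_data, or subgroupSize lanes);
// chunked loads must be 2D, with subgroupSize lanes on dim 0 and up to
// maxChunkSize per-lane elements on dim 1.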
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
    xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
    xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
    int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
  // ...
    return consumerLayout;
  // ...
  auto consumerInstData = consumerLayout.getEffectiveInstDataAsInt();
  auto consumerLaneData = consumerLayout.getEffectiveLaneDataAsInt();
  // ...
  if (!isChunkedLoad) {
    // ...
    instData.back() = std::min(static_cast<int>(consumerInstData.back()),
                               maxChunkSize * subgroupSize);
    return xegpu::LayoutAttr::get(context, instData);
    // ...
    laneData.back() =
        std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
    laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
                                 resShape.back() / laneData.back());
    return xegpu::LayoutAttr::get(context, laneLayout, laneData);
  }
  // ...
  assert(resShape.size() == 2 && "Chunked load must access a 2D tensor tile.");
  // ...
  instData[0] = subgroupSize;
  instData[1] = std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, instData);
  // ...
  laneLayout[0] = subgroupSize;
  laneData[1] = std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
// ...

xegpu::DistributeLayoutAttr
xegpu::setupLoadGatherAnchorLayout(xegpu::LayoutKind layoutKind,
                                   VectorType resVecTy, int chunkSize,
                                   xegpu::DistributeLayoutAttr consumerLayout,
                                   const xegpu::uArch::uArch *uArch) {
  auto context = resVecTy.getContext();
  auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
  // ...
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
          /* ... */);
  int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
  // ...
  return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
                                      (chunkSize > 1), maxChunkSize, resShape,
                                      subgroupSize);
}

// ...
xegpu::DistributeLayoutAttr
xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
                                   VectorType resVecTy,
                                   xegpu::DistributeLayoutAttr consumerLayout,
                                   const xegpu::uArch::uArch *uArch) {
  // ...
  auto context = resVecTy.getContext();
  auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
  // ...
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
          /* ... */);
  int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
  return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
                                      false, maxChunkSize, resShape,
                                      subgroupSize);
}
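// ...

// setupGenericStoreAnchorLayout: store-side mirror of the load path, used by
// store_scatter and store_matrix below.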
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(
    xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
    bool isChunkedStore, int maxChunkSize, ArrayRef<int64_t> srcShape,
    int subgroupSize) {
  // ...
  int srcShapeSize = srcShape.size();
  // ...
  assert(false &&
         "subgroup layout assignment not supported for storeScatter.");
  // ...
  if (!isChunkedStore) {
    // ...
    instData[srcShapeSize - 1] =
        std::min(subgroupSize, static_cast<int>(srcShape.back()));
    return xegpu::LayoutAttr::get(context, instData);
    // ...
    laneLayout[srcShapeSize - 1] =
        std::min(subgroupSize, static_cast<int>(srcShape.back()));
    return xegpu::LayoutAttr::get(context, laneLayout, laneData);
  }
  // ...
  assert(srcShapeSize == 2 && "Chunked store must access a 2D tensor tile.");
  // ...
  instData[0] = subgroupSize;
  instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, instData);
  // ...
  laneLayout[0] = subgroupSize;
  laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
// ...

xegpu::DistributeLayoutAttr
xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
                                     VectorType srcVecTy, int chunkSize,
                                     const xegpu::uArch::uArch *uArch) {
  const int subgroupSize = uArch->getSubgroupSize();
  // ...
  auto context = srcVecTy.getContext();
  auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  // ...
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
          /* ... */);
  int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
  // ...
  return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
                                       maxChunkSize, srcShape, subgroupSize);
}

// ...
xegpu::DistributeLayoutAttr
xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
                                    VectorType srcVecTy,
                                    const xegpu::uArch::uArch *uArch) {
  // ...
  const int subgroupSize = uArch->getSubgroupSize();
  // ...
  auto context = srcVecTy.getContext();
  auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  // ...
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
          /* ... */);
  int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
  // ...
  return setupGenericStoreAnchorLayout(layoutKind, context, false,
                                       maxChunkSize, srcShape, subgroupSize);
}
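// ...

// getDefaultLaneLayout2DBlockIo: default lane layout for 2D block IO. With a
// packing size, elements narrower than packingSize get
// laneData = packingSize / bitwidth on the packed dimension (the
// second-to-last dim when vnni is set, the last otherwise).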
template <typename RankedTy>
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
    RankedTy ty, const xegpu::uArch::uArch *uArch,
    std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
  // ...
  assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
         "Expected 1D non-vnni or 2D vector.");
  // ...
  assert(ty.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // ...
  auto context = ty.getContext();
  auto rank = ty.getRank();
  // ...
  if (packingSize.has_value()) {
    unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
    int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
    laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
  }
  // ...
  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
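// ...

// getValidLayouts: enumerates sg_layout factorizations [sgLayout0, sgLayout1]
// of sgCount whose per-subgroup tile evenly divides wgShape and is a
// multiple of instData, then ranks near-square factorizations first. E.g.
// (illustrative): wgShape = [64, 128], instData = [8, 16], sgCount = 8
// yields (1,8), (2,4), (4,2), (8,1); after sorting, (2,4) ranks first.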
using LayoutRepresentation = std::pair<int64_t, int64_t>;

static SmallVector<LayoutRepresentation>
getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
                int64_t sgCount) {
  // ...
  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
    if (sgCount % sgLayout0)
      continue;
    int64_t sgLayout1 = sgCount / sgLayout0;
    int64_t sgData0 = wgShape[0] / sgLayout0;
    int64_t sgData1 = wgShape[1] / sgLayout1;
    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
        (sgData0 % instData[0] || sgData1 % instData[1]))
      continue;
    candidates.emplace_back(sgLayout0, sgLayout1);
  }
  // ... (comparator used to rank the candidates)
    int diffLhs = std::abs(lhs.first - lhs.second);
    int diffRhs = std::abs(rhs.first - rhs.second);
    if (diffLhs != diffRhs)
      return diffLhs < diffRhs;
    return lhs.first < rhs.first;
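// ...

// getDpasInstDataVectors: queries the uArch DPAS (or scaled-DPAS)
// instruction for supported M/N/K sizes and derives per-operand inst_data
// shapes: [M, K] for A, [K, N] for B, and [M, N] for C/D.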
static std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                                SmallVector<int64_t>>>
getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
                       const xegpu::uArch::uArch *uArch,
                       bool isDpasMx = false) {
  // ...
  uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
      /* ... */);
  // ...
  const unsigned dataALen = aTy.getShape().front();
  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
  // ...
  const unsigned dataBLen = bTy.getShape().back();
  auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
  // ...
  auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
  // ...
  if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
    return std::nullopt;
  // ...
  int kDimSize = subgroupSize;
  // ...
  auto supportedKLen = uArchInstruction->getSupportedK(aTy.getElementType());
  kDimSize = supportedKLen[0];
  // ...
  instDataA[aTy.getRank() - 2] = maxALen;
  instDataA[aTy.getRank() - 1] = kDimSize;
  instDataB[bTy.getRank() - 2] = kDimSize;
  instDataB[bTy.getRank() - 1] = maxBLen;
  instDataCD[cdTy.getRank() - 2] = maxALen;
  instDataCD[cdTy.getRank() - 1] = maxCLen;
  return std::make_tuple(instDataA, instDataB, instDataCD);
}
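// ...

// getupDpasSubgroupLayouts: intersects the valid sg_layout candidates of A,
// B, and C/D, preferring the consumer's sg_layout when it qualifies and
// otherwise a candidate whose per-subgroup K tiles of A and B match; returns
// the three workgroup-level layouts, or std::nullopt if no common candidate
// exists.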
static std::optional<
    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
               xegpu::DistributeLayoutAttr>>
getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
                         VectorType bTy, VectorType cdTy,
                         xegpu::DistributeLayoutAttr consumerLayout, int numSg,
                         const xegpu::uArch::uArch *uArch) {
  auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
  if (!instDataVecs)
    return std::nullopt;
  auto [instDataA, instDataB, instDataCD] = *instDataVecs;
  assert(instDataA.size() == 2 && instDataB.size() == 2 &&
         instDataCD.size() == 2 &&
         "Sg layout creation expects valid 2D inst data");
  // ...
  std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
  if (consumerLayout && consumerLayout.isForWorkgroup()) {
    // ...
    consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
  }
  // ...
  if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
    return std::nullopt;
  // ...
  std::optional<LayoutRepresentation> bestPick;
  auto checkAlignedSgDataAB = [&](const LayoutRepresentation &sgLayout) {
    return aTy.getShape().back() / sgLayout.second ==
           bTy.getShape().front() / sgLayout.first;
  };
  for (auto &sgLayout : layoutsB) {
    if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
      if (!checkAlignedSgDataAB(sgLayout))
        continue;
      if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
        bestPick = sgLayout;
        // ...
      }
      // ...
      bestPick = sgLayout;
    }
  }
  if (!bestPick)
    return std::nullopt;
  SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
                               static_cast<int>(bestPick->second)};
  SmallVector<int> sgDataA = {static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
                              static_cast<int>(aTy.getShape()[1])};
  SmallVector<int> sgDataB = {static_cast<int>(bTy.getShape()[0]),
                              static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
  SmallVector<int> sgDataCD = {
      static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
      static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
  // ...
  auto dpasALayout = xegpu::LayoutAttr::get(
      /* context, sgLayout, sgDataA, ... */ nullptr, nullptr, nullptr);
  auto dpasBLayout = xegpu::LayoutAttr::get(
      /* context, sgLayout, sgDataB, ... */ nullptr, nullptr, nullptr);
  auto dpasCDLayout = xegpu::LayoutAttr::get(
      /* context, sgLayout, sgDataCD, ... */ nullptr, nullptr, nullptr);
  return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
}
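// ...

// setupDpasLayout: dispatches on layoutKind; workgroup layouts come from
// getupDpasSubgroupLayouts, inst_data layouts from getDpasInstDataVectors,
// and lane-level layouts from the uArch packed-format defaults.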
std::optional<
    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
               xegpu::DistributeLayoutAttr>>
xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
                       VectorType bTy, VectorType cdTy,
                       xegpu::DistributeLayoutAttr consumerLayout, int numSg,
                       const xegpu::uArch::uArch *uArch) {
  auto context = aTy.getContext();
  const auto *uArchInstruction = /* ... */;
  // ...
  assert(numSg > 0 &&
         "Number of subgroups must be provided for sg layout creation.");
  // ...
  auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
  if (!instDataVecs)
    return std::nullopt;
  auto [instDataA, instDataB, instDataCD] = *instDataVecs;
  return std::make_tuple(
      xegpu::LayoutAttr::get(/* context, instDataA, ... */),
      xegpu::LayoutAttr::get(/* context, instDataB, ... */),
      xegpu::LayoutAttr::get(/* context, instDataCD, ... */));
  // ...
  auto aLayout = getDefaultLaneLayout2DBlockIo(
      aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
  auto bLayout = getDefaultLaneLayout2DBlockIo(
      bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
  // ... (cdLayout is built analogously)
  return std::make_tuple(aLayout, bLayout, cdLayout);
  // ...
  return std::nullopt;
}
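// ...

// createScaleLayout: derives the layout of a scale operand for scaled DPAS
// from the corresponding matrix operand layout, shrinking sg/inst/lane data
// by the matrix-to-scale shape ratio and transposing the lane layout when
// the hardware lane order disagrees with the scale's orientation.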
static xegpu::DistributeLayoutAttr
createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy,
                  VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
                  bool isBScale, const xegpu::uArch::uArch *uArch) {
  if (!scaleTy || !matrixLayout)
    return nullptr;
  // ...
  if (scaleShape.empty())
    return nullptr;
  // ...
  auto uArchInstruction =
      dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
          /* ... */);
  // ...
  int64_t rank = matrixLayout.getRank();
  assert(rank == 2 && "dpas layouts must be two-dimensional");
  // ...
  auto order = matrixLayout.getOrder();
  // ...
  if (!sgLayout.empty() && !sgData.empty()) {
    scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
    scaleSgData.assign(sgData.begin(), sgData.end());
    scaleSgData[rank - 2] = std::max<int64_t>(
        scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
    scaleSgData[rank - 1] = std::max<int64_t>(
        scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
  }
  // ...
  if (!instData.empty()) {
    scaleInstData.assign(instData.begin(), instData.end());
    scaleInstData[rank - 2] = std::max<int64_t>(
        scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
        1);
    scaleInstData[rank - 1] = std::max<int64_t>(
        scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
        1);
  }
  // ...
  if (!laneLayout.empty() && !laneData.empty()) {
    scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
    scaleLaneData.assign(laneData.begin(), laneData.end());
    bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
    if (isBScale ^ isRowMajor) {
      std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
      scaleLaneLayout[rank - 2] =
          std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
    }
    scaleLaneData[rank - 2] =
        std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
    scaleLaneData[rank - 1] =
        std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
  }
  return xegpu::LayoutAttr::get(
      /* context, */
      scaleSgLayout.empty() ? nullptr : /* ... */,
      scaleSgData.empty() ? nullptr : /* ... */,
      scaleInstData.empty() ? nullptr : /* ... */,
      scaleLaneLayout.empty() ? nullptr : /* ... */,
      scaleLaneData.empty() ? nullptr : /* ... */);
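// ...

// setupDpasMxLayout: scaled (MX) variant of setupDpasLayout; builds the
// A/B/CD layouts first and then derives the A-scale and B-scale layouts from
// them via createScaleLayout.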
std::optional<
    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
               xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
               xegpu::DistributeLayoutAttr>>
xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
                         VectorType bTy, VectorType cdTy, VectorType aScaleTy,
                         VectorType bScaleTy,
                         xegpu::DistributeLayoutAttr consumerLayout, int numSg,
                         const xegpu::uArch::uArch *uArch) {
  auto context = aTy.getContext();
  // ...
  assert(numSg > 0 &&
         "Number of subgroups must be provided for sg layout creation.");
  auto dpasLayouts = getupDpasSubgroupLayouts(context, aTy, bTy, cdTy,
                                              consumerLayout, numSg, uArch);
  if (!dpasLayouts)
    return std::nullopt;
  auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
  // ... (aScaleLayout / bScaleLayout via createScaleLayout)
  return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
                         bScaleLayout);
  // ...
  if (!instDataVecs)
    return std::nullopt;
  auto [instDataA, instDataB, instDataCD] = *instDataVecs;
  // ...
  auto dpasALayout = xegpu::LayoutAttr::get(/* context, instDataA, ... */);
  auto dpasBLayout = xegpu::LayoutAttr::get(/* context, instDataB, ... */);
  auto dpasCDLayout = xegpu::LayoutAttr::get(/* context, instDataCD, ... */);
  // ...
  return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
                         bScaleLayout);
  // ...
  const auto *uArchInstruction = /* ... */;
  auto aLayout = getDefaultLaneLayout2DBlockIo(
      aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
  auto bLayout = getDefaultLaneLayout2DBlockIo(
      bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
  // ...
  return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
                         bScaleLayout);
  // ...
  return std::nullopt;
}
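// ...

// inferSourceLayoutFromResult: central dispatch mapping a result layout back
// onto a specific operand, case by case over the vector ops handled above.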
xegpu::DistributeLayoutAttr
xegpu::inferSourceLayoutFromResult(OpOperand &operand,
                                   xegpu::DistributeLayoutAttr resLayout) {
  // ...
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
    auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
    // ...
    return inferBroadcastSourceLayout(
        resLayout, broadcast.getResultVectorType().getShape(),
        srcTy.getShape());
  }
  // ...
  if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
    // ...
  }
  if (auto reduction = dyn_cast<vector::ReductionOp>(op))
    return inferReductionSourceLayout(resLayout);
  // ...
  if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
    int resElemBitWidth =
        bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
    int srcElemBitWidth =
        bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
    // ...
  }
  // ...
  if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
    return inferShapeCastSourceLayout(
        resLayout, shapeCast.getResultVectorType().getShape(),
        shapeCast.getSourceVectorType().getShape());
  }
  // ...
  if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
    // ...
    return inferInsertStridedSliceSourceLayout(
        resLayout, insertSlice.getDestVectorType().getShape(),
        insertSlice.getSourceVectorType().getShape());
  }
  // ...
  if (auto insert = dyn_cast<vector::InsertOp>(op)) {
    VectorType resVecTy = dyn_cast<VectorType>(insert.getResult().getType());
    VectorType valueToStoreTy =
        dyn_cast<VectorType>(insert.getValueToStore().getType());
    // ...
    if ((idx == 0) && valueToStoreTy) {
      return inferInsertSourceLayout(resLayout, resVecTy.getShape(),
                                     valueToStoreTy.getShape());
    }
    // ...
  }
  // ...
  if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
    VectorType srcVecTy = dyn_cast<VectorType>(extract.getSource().getType());
    VectorType resVecTy = dyn_cast<VectorType>(extract.getResult().getType());
    if (!srcVecTy || !resVecTy)
      // ...
    return inferExtractSourceLayout(resLayout, resVecTy.getShape(),
                                    srcVecTy.getShape());
  }
  // ...
  if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
    return inferTransposeSourceLayout(resLayout,
                                      transpose.getPermutation());
  }
  // ...
  if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
    int resElemBitWidth =
        bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
    int srcElemBitWidth =
        bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
    // ...
  }
  // ...
  if (auto interleave = dyn_cast<vector::InterleaveOp>(op)) {
    return inferInterleaveSourceLayout(resLayout);
  }
  // ...
  if (auto deinterleave = dyn_cast<vector::DeinterleaveOp>(op)) {
    return inferDeinterleaveSourceLayout(resLayout);
  }
  // ...
  if (isa<vector::ExtractStridedSliceOp>(op))
    // ...

// ...
  xegpu::DistributeLayoutAttr resLayout;
  // ...
  if (inferredOperandLayout)
    return inferredOperandLayout;