30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/Support/FormatVariadic.h"
40 out.reserve(attrs.size());
42 for (
auto attr : attrs) {
43 if (
auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
44 auto newLayout = dist.dropSgLayoutAndData();
46 out.emplace_back(attr.getName(), newLayout);
58 out.reserve(attrs.size());
60 for (
auto attr : attrs) {
61 if (
auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
62 auto newLayout = dist.dropInstData();
64 out.emplace_back(attr.getName(), newLayout);
76 auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(val.
getType());
77 if (!tensorDescTy || tensorDescTy.getLayoutAttr())
79 auto typeWithLayout = xegpu::TensorDescType::get(
80 tensorDescTy.getContext(), tensorDescTy.getShape(),
81 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
96 llvm::ReversePostOrderTraversal<Region *> rpot(®ion);
98 for (
Block *block : llvm::reverse(blocks)) {
100 for (
Operation &op : llvm::reverse(*block)) {
108 for (
Region &nested : op.getRegions())
117 xegpu::DistributeLayoutAttr layout =
nullptr;
152 if (op->
getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
163 if (isa<xegpu::TensorDescType>(resultType))
168 if (isa<VectorType>(resultType) || isa<vector::MultiDimReductionOp>(op))
171 if (isa<vector::DeinterleaveOp>(op))
175 xegpu::DistributeLayoutAttr operandLayout =
177 if (isa<VectorType>(opr.get().getType()) && operandLayout)
189 mlir::RegionBranchTerminatorOpInterface yieldOp) {
190 auto regionBranchOp =
191 dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
197 yieldOp.getSuccessorRegions(operandAttrs, successors);
200 OperandRange succOps = yieldOp.getSuccessorOperands(successor);
204 ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
205 unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
207 for (
unsigned i = 0; i < count; ++i) {
208 xegpu::DistributeLayoutAttr layout;
209 if (successor.isParent()) {
212 auto regionResult = regionBranchOp->getResult(i);
217 if (isa<xegpu::TensorDescType>(regionResult.getType()))
228 auto operandType = succOps[i].
getType();
229 if (isa<VectorType>(operandType) ||
230 dyn_cast<xegpu::TensorDescType>(operandType))
245 for (
Region ®ion : regionOp->getRegions()) {
250 ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
251 for (
auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
257 if (isa<xegpu::TensorDescType>(regionArg.getType()))
263 regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
264 for (
Value predVal : predValues) {
266 for (
OpOperand &operand : regionOp->getOpOperands()) {
267 if (operand.get() == predVal)
305 auto processFunc = [&](
Region &body, StringRef funcName) {
307 if (
auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
309 }
else if (
auto yieldOp =
310 dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
312 }
else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
318 rootOp->
walk([&](func::FuncOp
func) {
319 processFunc(
func.getBody(),
func.getSymName());
321 rootOp->
walk([&](gpu::GPUFuncOp
func) {
322 processFunc(
func.getBody(),
func.getName());
328template <
typename T,
typename>
330 Operation *owner = operandOrResult.getOwner();
348 for (
auto namedAttr : nestOp->
getAttrs()) {
349 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
350 attrsToRemove.push_back(namedAttr.getName());
352 for (
auto attrName : attrsToRemove)
361 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
362 attrsToRemove.push_back(namedAttr.getName());
364 for (
auto attrName : attrsToRemove)
371xegpu::DistributeLayoutAttr
377 size_t dimDiff = resShape.size() - srcShape.size();
378 auto bcastSourceLayout = resLayout;
379 for (
size_t i = dimDiff; i < resShape.size(); i++) {
380 if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
381 bcastDims.push_back(i);
386 if (!bcastDims.empty())
387 bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
391 for (
size_t i = 0; i < dimDiff; i++)
392 sliceDims.push_back(i);
393 bcastSourceLayout = xegpu::SliceAttr::get(
394 resLayout.getContext(), bcastSourceLayout,
397 return bcastSourceLayout;
402xegpu::DistributeLayoutAttr
406 assert(isa<xegpu::SliceAttr>(resLayout) &&
407 "reduction result layout must be slice layout");
409 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
411 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
412 "reduction dims must match with slice dims");
414 return sliceLayout.getParent();
417xegpu::DistributeLayoutAttr
424xegpu::DistributeLayoutAttr
427 return resLayout.transposeDims(permutation);
433xegpu::DistributeLayoutAttr
435 int resElemTyBitWidth,
int srcElemTyBitWidth) {
440 size_t sgDataSize = sgData.size();
441 size_t instDataSize = instData.size();
442 size_t laneDataSize = laneData.size();
446 int64_t dim = resLayout.getRank() - 1;
448 if (srcElemTyBitWidth <= resElemTyBitWidth) {
449 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
451 sgDataValue = sgData.back() * bitWidthRatio;
453 instDataValue = instData.back() * bitWidthRatio;
455 laneDataValue = laneData.back() * bitWidthRatio;
457 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
459 assert((sgData.back() % bitWidthRatio) == 0 &&
460 "sgData not divisible by bitWidthRatio");
461 sgDataValue = sgData.back() / bitWidthRatio;
464 assert((instData.back() % bitWidthRatio) == 0 &&
465 "instData not divisible by bitWidthRatio");
466 instDataValue = instData.back() / bitWidthRatio;
469 assert((laneData.back() % bitWidthRatio) == 0 &&
470 "laneData not divisible by bitWidthRatio");
471 laneDataValue = laneData.back() / bitWidthRatio;
475 xegpu::DistributeLayoutAttr finalSrcLayout;
477 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
479 return finalSrcLayout;
486xegpu::DistributeLayoutAttr
492 size_t sgDataSize = sgData.size();
493 size_t instDataSize = instData.size();
494 size_t laneDataSize = laneData.size();
498 int64_t dim = resLayout.getRank() - 1;
502 constexpr int ratio = 2;
504 assert((sgData.back() % ratio) == 0 &&
505 "sgData not divisible by interleave ratio");
506 sgDataValue = sgData.back() / ratio;
509 assert((instData.back() % ratio) == 0 &&
510 "instData not divisible by interleave ratio");
511 instDataValue = instData.back() / ratio;
514 assert((laneData.back() % ratio) == 0 &&
515 "laneData not divisible by interleave ratio");
516 laneDataValue = laneData.back() / ratio;
519 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
526xegpu::DistributeLayoutAttr
532 size_t sgDataSize = sgData.size();
533 size_t instDataSize = instData.size();
534 size_t laneDataSize = laneData.size();
538 int64_t dim = resLayout.getRank() - 1;
542 constexpr int ratio = 2;
544 sgDataValue = sgData.back() * ratio;
546 instDataValue = instData.back() * ratio;
548 laneDataValue = laneData.back() * ratio;
550 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
560 int srcShapeSize = srcShape.size();
561 int resShapeSize = resShape.size();
562 int dimDiff = resShapeSize - srcShapeSize;
567 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
568 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
569 for (
int i = 0; i < dimDiff; i++) {
570 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
571 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
572 "Leading dimensions being sliced off must not be distributed");
574 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
583xegpu::DistributeLayoutAttr
588 int srcShapeSize = srcShape.size();
589 int resShapeSize = resShape.size();
590 int dimDiff = resShapeSize - srcShapeSize;
595 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
596 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
597 for (
int i = 0; i < dimDiff; i++) {
598 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
599 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
600 "Leading dimensions being sliced off must not be distributed");
602 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
612xegpu::DistributeLayoutAttr
617 int srcShapeSize = srcShape.size();
618 int resShapeSize = resShape.size();
619 int dimDiff = srcShapeSize - resShapeSize;
620 auto context = resLayout.getContext();
624 auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
625 auto sgData = resLayout.getEffectiveSgDataAsInt();
626 auto instData = resLayout.getEffectiveInstDataAsInt();
627 auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
628 auto laneData = resLayout.getEffectiveLaneDataAsInt();
629 auto order = resLayout.getEffectiveOrderAsInt();
639 for (
auto &o : order)
645 for (
int i = 0; i < dimDiff; i++) {
646 if (!sgLayout.empty())
647 sgLayout.insert(sgLayout.begin(), 1);
649 sgData.insert(sgData.begin(), 1);
650 if (!instData.empty())
651 instData.insert(instData.begin(), 1);
652 if (!laneLayout.empty())
653 laneLayout.insert(laneLayout.begin(), 1);
654 if (!laneData.empty())
655 laneData.insert(laneData.begin(), 1);
656 order.push_back(dimDiff - 1 - i);
666 auto srcLayout = xegpu::LayoutAttr::get(
667 context, sgLayout.empty() ?
nullptr : toAttr(sgLayout),
668 sgData.empty() ?
nullptr : toAttr(sgData),
669 instData.empty() ?
nullptr : toAttr(instData),
670 laneLayout.empty() ?
nullptr : toAttr(laneLayout),
671 laneData.empty() ?
nullptr : toAttr(laneData),
672 (!orderAttr || orderAttr.empty()) ?
nullptr : toAttr(order));
680xegpu::DistributeLayoutAttr
704 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
711 auto srcLayout = resLayout;
712 for (
const auto &dimGroup : splitDimGroups)
713 srcLayout = srcLayout.collapseDims(dimGroup);
722 if ((dst.size() != 2) && (dst.size() != 1))
724 int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
725 std::multiplies<int64_t>());
727 return (dst[0] == srcSize);
728 return (dst[0] == 1) && (dst[1] == srcSize);
731 if (matchCollapseToInnermostDim(srcShape, resShape)) {
732 int srcShapeSize = srcShape.size();
733 int resShapeSize = resShape.size();
734 auto context = resLayout.getContext();
735 auto resInstData = resLayout.getEffectiveInstDataAsInt();
736 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
737 auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
753 if (resInstData.size() != 0) {
755 for (
int i = 0; i < resShapeSize - 1; i++) {
756 assert(resInstData[i] == 1 &&
757 "only innermost dim can have non-unit instData");
760 inferredInstData[srcShapeSize - 1] =
761 std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
762 return xegpu::LayoutAttr::get(context, inferredInstData);
765 if (resLaneLayout.size() != 0) {
766 for (
int i = 0; i < resShapeSize - 1; i++) {
767 assert(resLaneData[i] == 1 &&
768 "only innermost dim can have non-unit instData");
770 assert(srcShape.back() % resLaneLayout.back() == 0 &&
771 "source innermost dim must be >= result lane layout");
774 inferredLaneLayout.back() = resLaneLayout.back();
775 inferredLaneData.back() = std::min(
776 resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
777 return xegpu::LayoutAttr::get(context, inferredLaneLayout,
781 llvm_unreachable(
"running into unsupported shape cast scenarios");
788 xegpu::DistributeLayoutAttr payloadLayout,
int chunkSize) {
789 auto rank = payloadLayout.getRank();
791 return payloadLayout.dropDims(
792 llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
793 return payloadLayout;
862 auto srcShape = srcVecTy.getShape();
863 int srcRank = srcShape.size();
864 auto context = srcVecTy.getContext();
872 const int subgroupSize =
uArch->getSubgroupSize();
873 int64_t maxReduceVectorSize = 1;
874 xegpu::DistributeLayoutAttr srcLayout;
876 xegpu::SliceAttr consumerSliceLayout =
877 dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
878 if (consumerSliceLayout &&
879 consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
880 srcLayout = consumerSliceLayout.getParent();
882 srcLayout.getEffectiveSgLayoutAsInt();
885 for (
int dim = 0; dim < srcRank; dim++) {
886 if (llvm::is_contained(reductionDims, dim))
888 srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
892 consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
895 consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
898 consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
901 consumerLayout ? consumerLayout.getOrder() :
nullptr;
903 int remainingSgCount =
904 consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
908 for (
int i = 0; i < srcRank; i++) {
909 if (!llvm::is_contained(reductionDims, i) &&
910 consumerIdx <
static_cast<int>(consumerSgLayout.size())) {
911 sgLayout[i] = consumerSgLayout[consumerIdx];
912 sgData[i] = consumerSgData[consumerIdx];
913 remainingSgCount /= sgLayout[i];
914 order[i] = consumerOrder[consumerIdx];
921 int64_t remainOrder = consumerSgLayout.size();
922 for (
int i = 0; i < srcRank; i++) {
923 if (llvm::is_contained(reductionDims, i)) {
925 std::min(srcShape[i],
static_cast<int64_t>(remainingSgCount));
926 assert((srcShape[i] % sgLayout[i] == 0) &&
927 "source shape not divisible by sg_layout");
928 sgData[i] = srcShape[i] / sgLayout[i];
929 remainingSgCount /= sgLayout[i];
930 order[i] = remainOrder++;
934 assert(remainingSgCount == 1 &&
"not all subgroups distributed");
935 srcLayout = xegpu::LayoutAttr::get(
936 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
939 (!orderAttr || orderAttr.empty()) ?
nullptr : toInt32Attr(order));
945 instData[srcRank - 2] =
946 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
947 instData[srcRank - 1] =
948 std::min(
static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
949 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
953 laneLayout[srcRank - 1] =
954 std::min(
static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
956 laneData[srcRank - 2] =
957 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
958 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
959 toInt32Attr(laneData));
962 return xegpu::SliceAttr::get(context, srcLayout,
973 auto srcShape = srcVecTy.getShape();
974 auto context = srcVecTy.getContext();
975 auto subgroupSize =
uArch->getSubgroupSize();
976 xegpu::LayoutAttr srcLayout;
979 assert(
true &&
"subgroup layout assignment not supported for reduction (op "
980 "is not expected at this level).");
982 assert(
true &&
"instData layout assignment not supported for reduction (op "
983 "is not expected at this level).");
986 laneLayout[0] = std::min(subgroupSize,
static_cast<int32_t
>(srcShape[0]));
988 srcLayout = xegpu::LayoutAttr::get(
993 auto result = xegpu::SliceAttr::get(context, srcLayout,
1024 int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1025 int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1033 consumerLayout.getEffectiveLaneLayoutAsInt();
1035 assert(consumerLayout.getRank() ==
static_cast<int64_t>(srcShape.size()) &&
1036 "laneData must be available for all dimensions");
1037 size_t innerMostDim = srcShape.size() - 1;
1041 if (srcElemTyBitWidth > resElemTyBitWidth) {
1045 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
1047 sgDataValue = sgData[innerMostDim];
1048 while ((sgDataValue <= resShape[innerMostDim]) &&
1049 (sgDataValue % bitWidthRatio) != 0)
1052 instDataValue = instData[innerMostDim];
1053 const int innermostDimLaneLayout = laneLayout.empty()
1054 ?
uArch->getSubgroupSize()
1055 : laneLayout[innerMostDim];
1058 while ((instDataValue <= resShape[innerMostDim]) &&
1059 (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
1061 assert((resShape[innerMostDim] % instDataValue) == 0 &&
1062 "resShape, instData, and lanelayout for innermost must be 2^n !");
1064 laneDataValue = laneData[innerMostDim];
1065 while ((laneDataValue <= resShape[innerMostDim]) &&
1066 (laneDataValue % bitWidthRatio != 0))
1070 xegpu::DistributeLayoutAttr resLayout;
1071 resLayout = consumerLayout.setDimData(innerMostDim, sgDataValue,
1072 instDataValue, laneDataValue);
1075 return consumerLayout;
1102 consumerLayout.getEffectiveLaneLayoutAsInt();
1104 assert(consumerLayout.getRank() ==
static_cast<int64_t>(srcShape.size()) &&
1105 "consumer layout rank must match source shape rank");
1106 const size_t innerMostDim = srcShape.size() - 1;
1112 constexpr int ratio = 2;
1115 sgDataValue = sgData[innerMostDim];
1117 while ((sgDataValue <= srcShape[innerMostDim]) &&
1118 (sgDataValue % ratio != 0))
1119 sgDataValue *= ratio;
1121 instDataValue = instData[innerMostDim];
1122 const int innermostDimLaneLayout = laneLayout.empty()
1123 ?
uArch->getSubgroupSize()
1124 : laneLayout[innerMostDim];
1127 while ((instDataValue <= srcShape[innerMostDim]) &&
1128 (instDataValue % (innermostDimLaneLayout * ratio) != 0))
1129 instDataValue *= ratio;
1130 assert((srcShape[innerMostDim] % instDataValue) == 0 &&
1131 "srcShape, instData, and laneLayout for innermost must be 2^n!");
1133 laneDataValue = laneData[innerMostDim];
1136 while ((laneDataValue <= srcShape[innerMostDim]) &&
1137 (laneDataValue % ratio != 0))
1138 laneDataValue *= ratio;
1141 return consumerLayout.setDimData(innerMostDim, sgDataValue, instDataValue,
1150 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
1153 xegpu::DistributeLayoutAttr requiredResLayout;
1155 consumerLayout.getEffectiveInstDataAsInt();
1157 consumerLayout.getEffectiveLaneDataAsInt();
1159 consumerLayout.getEffectiveLaneLayoutAsInt();
1164 requiredResLayout = consumerLayout;
1165 int srcRank = srcShape.size();
1169 "subgroup layout assignment not supported for insertStridedSlice.");
1171 for (
int dim = 0; dim < srcRank; dim++) {
1172 instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
1174 requiredResLayout.setDimData(dim, -1, instDataValue, -1);
1177 for (
int dim = 0; dim < srcRank; dim++) {
1178 assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
1179 "srcShape must be divisible by laneLayout for all dimensions");
1180 laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
1181 consumerLaneData[dim]);
1183 requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
1186 return requiredResLayout;
1203 xegpu::DistributeLayoutAttr consumerLayout,
bool isChunkedLoad,
1207 return consumerLayout;
1210 consumerLayout.getEffectiveInstDataAsInt();
1212 consumerLayout.getEffectiveLaneDataAsInt();
1218 if (!isChunkedLoad) {
1220 instData.back() = std::min(
static_cast<int>(consumerInstData.back()),
1221 maxChunkSize * subgroupSize);
1222 return xegpu::LayoutAttr::get(context, instData);
1225 std::min(
static_cast<int>(consumerLaneData.back()), maxChunkSize);
1226 laneLayout.back() = std::min(
static_cast<int64_t>(subgroupSize),
1227 resShape.back() / laneData.back());
1228 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1231 assert(resShape.size() == 2 &&
"Chunked Store must access 2D tensor tile.");
1233 instData[0] = subgroupSize;
1235 std::min(
static_cast<int>(consumerInstData[1]), maxChunkSize);
1236 return xegpu::LayoutAttr::get(context, instData);
1238 laneLayout[0] = subgroupSize;
1240 std::min(
static_cast<int>(consumerLaneData[1]), maxChunkSize);
1241 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1254 auto context = resVecTy.getContext();
1255 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1257 const auto *uArchInstruction =
1258 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1260 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1263 (chunkSize > 1), maxChunkSize, resShape,
1269xegpu::DistributeLayoutAttr
1271 VectorType resVecTy,
1272 xegpu::DistributeLayoutAttr consumerLayout,
1277 auto context = resVecTy.getContext();
1278 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1280 const auto *uArchInstruction =
1281 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1283 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1285 false, maxChunkSize, resShape,
1300static xegpu::DistributeLayoutAttr
1306 int srcShapeSize = srcShape.size();
1313 "subgroup layout assignment not supported for storeScatter.");
1317 if (!isChunkedStore) {
1319 instData[srcShapeSize - 1] =
1320 std::min(subgroupSize,
static_cast<int>(srcShape.back()));
1321 return xegpu::LayoutAttr::get(context, instData);
1323 laneLayout[srcShapeSize - 1] =
1324 std::min(subgroupSize,
static_cast<int>(srcShape.back()));
1325 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1328 assert(srcShapeSize == 2 &&
"Chunked Store must access 2D tensor tile.");
1330 instData[0] = subgroupSize;
1331 instData[1] = std::min(
static_cast<int>(srcShape[1]), maxChunkSize);
1332 return xegpu::LayoutAttr::get(context, instData);
1334 laneLayout[0] = subgroupSize;
1335 laneData[1] = std::min(
static_cast<int>(srcShape[1]), maxChunkSize);
1336 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1343xegpu::DistributeLayoutAttr
1345 VectorType srcVecTy,
int chunkSize,
1348 const int subgroupSize =
uArch->getSubgroupSize();
1350 auto context = srcVecTy.getContext();
1351 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1353 const auto *uArchInstruction =
1354 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1356 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1358 maxChunkSize, srcShape, subgroupSize);
1362xegpu::DistributeLayoutAttr
1364 VectorType srcVecTy,
1367 const int subgroupSize =
uArch->getSubgroupSize();
1369 auto context = srcVecTy.getContext();
1370 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1372 const auto *uArchInstruction =
1373 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1375 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1378 srcShape, subgroupSize);
1386template <
typename RankedTy>
1389 std::optional<unsigned> packingSize = std::nullopt,
bool vnni =
false) {
1391 assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
1392 "Expected 1D non-vnni or 2D vector.");
1394 assert(ty.getElementType().isIntOrFloat() &&
1395 "Expected int or float element type.");
1397 auto context = ty.getContext();
1398 auto rank = ty.getRank();
1401 if (packingSize.has_value()) {
1402 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1403 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
1404 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
1407 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1422 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
1423 if (sgCount % sgLayout0)
1425 int64_t sgLayout1 = sgCount / sgLayout0;
1426 int64_t sgData0 = wgShape[0] / sgLayout0;
1427 int64_t sgData1 = wgShape[1] / sgLayout1;
1428 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
1429 (sgData0 % instData[0] || sgData1 % instData[1]))
1431 candidates.emplace_back(sgLayout0, sgLayout1);
1438 int diffLhs = std::abs(
lhs.first -
lhs.second);
1439 int diffRhs = std::abs(
rhs.first -
rhs.second);
1440 if (diffLhs != diffRhs)
1441 return diffLhs < diffRhs;
1442 return lhs.first <
rhs.first;
1453 bool isDpasMx =
false) {
1458 uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1466 const unsigned dataALen = aTy.getShape().front();
1467 auto supportedALen = uArchInstruction->
getSupportedM(aTy.getElementType());
1471 const unsigned dataBLen = bTy.getShape().back();
1472 auto supportedBLen = uArchInstruction->
getSupportedN(bTy.getElementType());
1476 auto supportedCLen = uArchInstruction->
getSupportedN(cdTy.getElementType());
1479 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
1480 return std::nullopt;
1484 int kDimSize = subgroupSize;
1486 auto supportedKLen = uArchInstruction->
getSupportedK(aTy.getElementType());
1487 if (supportedKLen.empty())
1488 return std::nullopt;
1489 kDimSize = supportedKLen[0];
1493 instDataA[aTy.getRank() - 2] = maxALen;
1494 instDataA[aTy.getRank() - 1] = kDimSize;
1496 instDataB[bTy.getRank() - 2] = kDimSize;
1497 instDataB[bTy.getRank() - 1] = maxBLen;
1499 instDataCD[cdTy.getRank() - 2] = maxALen;
1500 instDataCD[cdTy.getRank() - 1] = maxCLen;
1501 return std::make_tuple(instDataA, instDataB, instDataCD);
1506static std::optional<
1507 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1508 xegpu::DistributeLayoutAttr>>
1510 VectorType bTy, VectorType cdTy,
1511 xegpu::DistributeLayoutAttr consumerLayout,
int numSg,
1515 return std::nullopt;
1516 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1517 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
1518 instDataCD.size() == 2 &&
1519 "Sg layout creation expects valid 2D inst data");
1521 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
1522 if (consumerLayout && consumerLayout.isForWorkgroup()) {
1524 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
1531 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1532 return std::nullopt;
1538 std::optional<LayoutRepresentation> bestPick;
1540 return aTy.getShape().back() / sgLayout.second ==
1541 bTy.getShape().front() / sgLayout.first;
1543 for (
auto &sgLayout : layoutsB) {
1544 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1545 if (!checkAlignedSgDataAB(sgLayout))
1548 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1549 bestPick = sgLayout;
1557 bestPick = sgLayout;
1561 return std::nullopt;
1564 static_cast<int>(bestPick->second)};
1565 SmallVector<int> sgDataA = {
static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1566 static_cast<int>(aTy.getShape()[1])};
1568 static_cast<int>(bTy.getShape()[0]),
1569 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1571 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1572 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1577 nullptr,
nullptr,
nullptr);
1581 nullptr,
nullptr,
nullptr);
1585 nullptr,
nullptr,
nullptr);
1587 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1594 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1595 xegpu::DistributeLayoutAttr>>
1597 VectorType bTy, VectorType cdTy,
1598 xegpu::DistributeLayoutAttr consumerLayout,
int numSg,
1600 auto context = aTy.getContext();
1601 const auto *uArchInstruction =
1607 "Number of subgroups must be provided for sg layout creation.");
1613 return std::nullopt;
1614 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1615 return std::make_tuple(
1616 xegpu::LayoutAttr::get(
1618 xegpu::LayoutAttr::get(
1620 xegpu::LayoutAttr::get(
1624 aTy,
uArch, uArchInstruction->getPackedFormatBitSizeA());
1626 bTy,
uArch, uArchInstruction->getPackedFormatBitSizeB(),
true);
1629 return std::make_tuple(aLayout, bLayout, cdLayout);
1631 return std::nullopt;
1638static xegpu::DistributeLayoutAttr
1640 VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
1642 if (!scaleTy || !matrixLayout)
1650 if (scaleShape.empty())
1653 auto uArchInstruction =
1654 dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1658 int64_t rank = matrixLayout.getRank();
1659 assert(rank == 2 &&
"dpas layouts must be two dimensions");
1666 auto order = matrixLayout.getOrder();
1670 if (!sgLayout.empty() && !sgData.empty()) {
1671 scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
1672 scaleSgData.assign(sgData.begin(), sgData.end());
1673 scaleSgData[rank - 2] = std::max<int64_t>(
1674 scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
1675 scaleSgData[rank - 1] = std::max<int64_t>(
1676 scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
1683 if (!instData.empty()) {
1684 scaleInstData.assign(instData.begin(), instData.end());
1686 scaleInstData[rank - 2] = std::max<int64_t>(
1687 scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
1690 scaleInstData[rank - 1] = std::max<int64_t>(
1691 scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
1697 if (!laneLayout.empty() && !laneData.empty()) {
1698 scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
1699 scaleLaneData.assign(laneData.begin(), laneData.end());
1700 bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
1701 if (isBScale ^ isRowMajor) {
1702 std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
1703 scaleLaneLayout[rank - 2] =
1704 std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
1706 scaleLaneData[rank - 2] =
1707 std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
1708 scaleLaneData[rank - 1] =
1709 std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
1711 return xegpu::LayoutAttr::get(
1713 scaleSgLayout.empty() ?
nullptr
1715 scaleSgData.empty() ?
nullptr
1717 scaleInstData.empty() ?
nullptr
1719 scaleLaneLayout.empty()
1722 scaleLaneData.empty() ?
nullptr
1731 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1732 xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1733 xegpu::DistributeLayoutAttr>>
1735 VectorType bTy, VectorType cdTy, VectorType aScaleTy,
1736 VectorType bScaleTy,
1737 xegpu::DistributeLayoutAttr consumerLayout,
int numSg,
1739 auto context = aTy.getContext();
1743 "Number of subgroups must be provided for sg layout creation.");
1745 consumerLayout, numSg,
uArch);
1747 return std::nullopt;
1749 auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
1758 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1764 return std::nullopt;
1765 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1767 auto dpasALayout = xegpu::LayoutAttr::get(
1769 auto dpasBLayout = xegpu::LayoutAttr::get(
1771 auto dpasCDLayout = xegpu::LayoutAttr::get(
1780 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1783 const auto *uArchInstruction =
1787 aTy,
uArch, uArchInstruction->getPackedFormatBitSizeA());
1789 bTy,
uArch, uArchInstruction->getPackedFormatBitSizeB(),
true);
1798 return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
1801 return std::nullopt;
1805 OpOperand &operand, xegpu::DistributeLayoutAttr resLayout) {
1812 if (
auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1813 auto srcTy = dyn_cast<VectorType>(
broadcast.getSourceType());
1817 resLayout,
broadcast.getResultVectorType().getShape(),
1824 if (
auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1833 if (
auto reduction = dyn_cast<vector::ReductionOp>(op))
1838 if (
auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1839 int resElemBitWidth =
1840 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1841 int srcElemBitWidth =
1842 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1849 if (
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1851 resLayout, shapeCast.getResultVectorType().getShape(),
1852 shapeCast.getSourceVectorType().getShape());
1857 if (
auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1860 resLayout, insertSlice.getDestVectorType().getShape(),
1861 insertSlice.getSourceVectorType().getShape());
1869 if (
auto insert = dyn_cast<vector::InsertOp>(op)) {
1870 VectorType resVecTy = dyn_cast<VectorType>(insert.getResult().getType());
1871 VectorType valueToStoreTy =
1872 dyn_cast<VectorType>(insert.getValueToStore().getType());
1874 if ((idx == 0) && valueToStoreTy) {
1876 valueToStoreTy.getShape());
1884 if (
auto extract = dyn_cast<vector::ExtractOp>(op)) {
1885 VectorType srcVecTy = dyn_cast<VectorType>(extract.getSource().getType());
1886 VectorType resVecTy = dyn_cast<VectorType>(extract.getResult().getType());
1887 if (!srcVecTy || !resVecTy)
1890 srcVecTy.getShape());
1895 if (
auto transpose = dyn_cast<vector::TransposeOp>(op)) {
1897 transpose.getPermutation());
1902 if (
auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1903 int resElemBitWidth =
1904 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1905 int srcElemBitWidth =
1906 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1912 if (
auto interleave = dyn_cast<vector::InterleaveOp>(op)) {
1917 if (
auto deinterleave = dyn_cast<vector::DeinterleaveOp>(op)) {
1922 if (dyn_cast<vector::ExtractStridedSliceOp>(op))
1938 if (isa<xegpu::AnchorLayoutInterface>(op))
1942 xegpu::DistributeLayoutAttr resLayout;
1943 if (op->
getNumResults() == 1 || isa<vector::DeinterleaveOp>(op))
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy, VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout, bool isBScale, const xegpu::uArch::uArch *uArch)
Helper to create a scale layout derived from a matrix operand layout.
static std::optional< std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > > getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy, const xegpu::uArch::uArch *uArch, bool isDpasMx=false)
Helper function to compute inst_data vectors for DPAS operands A, B, and C/D.
static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result)
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operation.
static std::optional< std::tuple< xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr > > getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy, VectorType bTy, VectorType cdTy, xegpu::DistributeLayoutAttr consumerLayout, int numSg, const xegpu::uArch::uArch *uArch)
Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
static void propagateResultsToRegularOperands(Operation *op)
static void propagateRegionResultsToYieldOperands(mlir::RegionBranchTerminatorOpInterface yieldOp)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp)
static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static void walkRegionBackward(Region ®ion, llvm::function_ref< void(Operation *)> visit)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operation.
Block represents an ordered list of Operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getType() const
Operation is the basic unit of execution within MLIR.
bool hasAttrOfType(NameT &&name)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
@ SubgroupMatrixMultiplyAcc
@ SubgroupScaledMatrixMultiplyAcc
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferSourceLayoutFromResultForNonAnchorOp(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
const Instruction * getInstruction(InstructionKind instKind) const