30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/Support/FormatVariadic.h"
40 out.reserve(attrs.size());
42 for (
auto attr : attrs) {
43 if (
auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
44 auto newLayout = dist.dropSgLayoutAndData();
46 out.emplace_back(attr.getName(), newLayout);
58 out.reserve(attrs.size());
60 for (
auto attr : attrs) {
61 if (
auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
62 auto newLayout = dist.dropInstData();
64 out.emplace_back(attr.getName(), newLayout);
76 auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(val.
getType());
77 if (!tensorDescTy || tensorDescTy.getLayoutAttr())
79 auto typeWithLayout = xegpu::TensorDescType::get(
80 tensorDescTy.getContext(), tensorDescTy.getShape(),
81 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
96 llvm::ReversePostOrderTraversal<Region *> rpot(®ion);
98 for (
Block *block : llvm::reverse(blocks)) {
100 for (
Operation &op : llvm::reverse(*block)) {
108 for (
Region &nested : op.getRegions())
117 xegpu::DistributeLayoutAttr layout =
nullptr;
152 if (op->
getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
163 if (isa<xegpu::TensorDescType>(resultType))
168 if (isa<VectorType>(resultType) || isa<vector::MultiDimReductionOp>(op))
171 if (isa<vector::DeinterleaveOp>(op))
175 xegpu::DistributeLayoutAttr operandLayout =
177 if (isa<VectorType>(opr.get().getType()) && operandLayout)
189 mlir::RegionBranchTerminatorOpInterface yieldOp) {
190 auto regionBranchOp =
191 dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
197 yieldOp.getSuccessorRegions(operandAttrs, successors);
200 OperandRange succOps = yieldOp.getSuccessorOperands(successor);
204 ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
205 unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
207 for (
unsigned i = 0; i < count; ++i) {
208 xegpu::DistributeLayoutAttr layout;
209 if (successor.isOperation()) {
212 auto regionResult = regionBranchOp->getResult(i);
217 if (isa<xegpu::TensorDescType>(regionResult.getType()))
228 auto operandType = succOps[i].
getType();
229 if (isa<VectorType>(operandType) ||
230 dyn_cast<xegpu::TensorDescType>(operandType))
247 for (
Region ®ion : regionOp->getRegions()) {
252 ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
253 for (
auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
254 auto layout = getLayoutOfValue(regionArg);
259 if (isa<xegpu::TensorDescType>(regionArg.getType()))
265 regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
266 for (
Value predVal : predValues) {
268 for (
OpOperand &operand : regionOp->getOpOperands()) {
269 if (operand.get() == predVal)
308 auto processFunc = [&](
Region &body, StringRef funcName) {
310 if (
auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
313 }
else if (
auto yieldOp =
314 dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
316 }
else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
322 rootOp->
walk([&](func::FuncOp
func) {
323 processFunc(
func.getBody(),
func.getSymName());
325 rootOp->
walk([&](gpu::GPUFuncOp
func) {
326 processFunc(
func.getBody(),
func.getName());
332template <
typename T,
typename>
334 Operation *owner = operandOrResult.getOwner();
352 for (
auto namedAttr : nestOp->
getAttrs()) {
353 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
354 attrsToRemove.push_back(namedAttr.getName());
356 for (
auto attrName : attrsToRemove)
365 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
366 attrsToRemove.push_back(namedAttr.getName());
368 for (
auto attrName : attrsToRemove)
375xegpu::DistributeLayoutAttr
381 size_t dimDiff = resShape.size() - srcShape.size();
382 auto bcastSourceLayout = resLayout;
383 for (
size_t i = dimDiff; i < resShape.size(); i++) {
384 if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
385 bcastDims.push_back(i);
390 if (!bcastDims.empty())
391 bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
395 for (
size_t i = 0; i < dimDiff; i++)
396 sliceDims.push_back(i);
397 bcastSourceLayout = xegpu::SliceAttr::get(
398 resLayout.getContext(), bcastSourceLayout,
401 return bcastSourceLayout;
406xegpu::DistributeLayoutAttr
410 assert(isa<xegpu::SliceAttr>(resLayout) &&
411 "reduction result layout must be slice layout");
413 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
415 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
416 "reduction dims must match with slice dims");
418 return sliceLayout.getParent();
421xegpu::DistributeLayoutAttr
432xegpu::DistributeLayoutAttr
443xegpu::DistributeLayoutAttr
445 int resElemTyBitWidth,
int srcElemTyBitWidth) {
450 size_t sgDataSize = sgData.size();
451 size_t instDataSize = instData.size();
452 size_t laneDataSize = laneData.size();
456 int64_t dim = resLayout.getRank() - 1;
458 if (srcElemTyBitWidth <= resElemTyBitWidth) {
459 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
461 sgDataValue = sgData.back() * bitWidthRatio;
463 instDataValue = instData.back() * bitWidthRatio;
465 laneDataValue = laneData.back() * bitWidthRatio;
467 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
469 assert((sgData.back() % bitWidthRatio) == 0 &&
470 "sgData not divisible by bitWidthRatio");
471 sgDataValue = sgData.back() / bitWidthRatio;
474 assert((instData.back() % bitWidthRatio) == 0 &&
475 "instData not divisible by bitWidthRatio");
476 instDataValue = instData.back() / bitWidthRatio;
479 assert((laneData.back() % bitWidthRatio) == 0 &&
480 "laneData not divisible by bitWidthRatio");
481 laneDataValue = laneData.back() / bitWidthRatio;
485 xegpu::DistributeLayoutAttr finalSrcLayout;
487 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
489 return finalSrcLayout;
496xegpu::DistributeLayoutAttr
502 size_t sgDataSize = sgData.size();
503 size_t instDataSize = instData.size();
504 size_t laneDataSize = laneData.size();
508 int64_t dim = resLayout.getRank() - 1;
512 constexpr int ratio = 2;
514 assert((sgData.back() % ratio) == 0 &&
515 "sgData not divisible by interleave ratio");
516 sgDataValue = sgData.back() / ratio;
519 assert((instData.back() % ratio) == 0 &&
520 "instData not divisible by interleave ratio");
521 instDataValue = instData.back() / ratio;
524 assert((laneData.back() % ratio) == 0 &&
525 "laneData not divisible by interleave ratio");
526 laneDataValue = laneData.back() / ratio;
529 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
536xegpu::DistributeLayoutAttr
542 size_t sgDataSize = sgData.size();
543 size_t instDataSize = instData.size();
544 size_t laneDataSize = laneData.size();
548 int64_t dim = resLayout.getRank() - 1;
552 constexpr int ratio = 2;
554 sgDataValue = sgData.back() * ratio;
556 instDataValue = instData.back() * ratio;
558 laneDataValue = laneData.back() * ratio;
560 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
570 int srcShapeSize = srcShape.size();
571 int resShapeSize = resShape.size();
572 int dimDiff = resShapeSize - srcShapeSize;
577 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
578 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
579 for (
int i = 0; i < dimDiff; i++) {
580 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
581 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
582 "Leading dimensions being sliced off must not be distributed");
584 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
593xegpu::DistributeLayoutAttr
598 int srcShapeSize = srcShape.size();
599 int resShapeSize = resShape.size();
600 int dimDiff = resShapeSize - srcShapeSize;
605 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
606 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
607 for (
int i = 0; i < dimDiff; i++) {
608 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
609 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
610 "Leading dimensions being sliced off must not be distributed");
612 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
622xegpu::DistributeLayoutAttr
627 int srcShapeSize = srcShape.size();
628 int resShapeSize = resShape.size();
629 int dimDiff = srcShapeSize - resShapeSize;
630 auto context = resLayout.getContext();
634 auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
635 auto sgData = resLayout.getEffectiveSgDataAsInt();
636 auto instData = resLayout.getEffectiveInstDataAsInt();
637 auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
638 auto laneData = resLayout.getEffectiveLaneDataAsInt();
639 auto order = resLayout.getEffectiveOrderAsInt();
649 for (
auto &o : order)
655 for (
int i = 0; i < dimDiff; i++) {
656 if (!sgLayout.empty())
657 sgLayout.insert(sgLayout.begin(), 1);
659 sgData.insert(sgData.begin(), 1);
660 if (!instData.empty())
661 instData.insert(instData.begin(), 1);
662 if (!laneLayout.empty())
663 laneLayout.insert(laneLayout.begin(), 1);
664 if (!laneData.empty())
665 laneData.insert(laneData.begin(), 1);
666 order.push_back(dimDiff - 1 - i);
676 auto srcLayout = xegpu::LayoutAttr::get(
677 context, sgLayout.empty() ?
nullptr : toAttr(sgLayout),
678 sgData.empty() ?
nullptr : toAttr(sgData),
679 instData.empty() ?
nullptr : toAttr(instData),
680 laneLayout.empty() ?
nullptr : toAttr(laneLayout),
681 laneData.empty() ?
nullptr : toAttr(laneData),
682 (!orderAttr || orderAttr.empty()) ?
nullptr : toAttr(order));
690xegpu::DistributeLayoutAttr
714 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
721 auto srcLayout = resLayout;
722 for (
const auto &dimGroup : splitDimGroups)
723 srcLayout = srcLayout.collapseDims(dimGroup);
744 auto srcLayout = resLayout;
745 for (
int64_t dstIdx =
static_cast<int64_t>(collapseDims.size()) - 1;
746 dstIdx >= 0; --dstIdx) {
748 if (srcDims.empty()) {
750 srcLayout = srcLayout.dropDims({dstIdx});
753 if (srcDims.size() == 1)
757 targetShape.reserve(srcDims.size());
759 targetShape.push_back(srcShape[d]);
760 srcLayout = srcLayout.expandDim(dstIdx, targetShape);
764 llvm_unreachable(
"running into unsupported shape cast scenarios");
771 xegpu::DistributeLayoutAttr payloadLayout,
int chunkSize) {
772 auto rank = payloadLayout.getRank();
774 return payloadLayout.dropDims(
775 llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
776 return payloadLayout;
845 auto srcShape = srcVecTy.getShape();
846 int srcRank = srcShape.size();
847 auto context = srcVecTy.getContext();
855 const int subgroupSize =
uArch->getSubgroupSize();
856 int64_t maxReduceVectorSize = 1;
857 xegpu::DistributeLayoutAttr srcLayout;
859 xegpu::SliceAttr consumerSliceLayout =
860 dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
861 if (consumerSliceLayout &&
862 consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
863 srcLayout = consumerSliceLayout.getParent();
865 srcLayout.getEffectiveSgLayoutAsInt();
868 for (
int dim = 0; dim < srcRank; dim++) {
869 if (llvm::is_contained(reductionDims, dim))
871 srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
875 consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
878 consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
881 consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
884 consumerLayout ? consumerLayout.getOrder() :
nullptr;
886 int remainingSgCount =
887 consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
891 for (
int i = 0; i < srcRank; i++) {
892 if (!llvm::is_contained(reductionDims, i) &&
893 consumerIdx <
static_cast<int>(consumerSgLayout.size())) {
894 sgLayout[i] = consumerSgLayout[consumerIdx];
895 sgData[i] = consumerSgData[consumerIdx];
896 remainingSgCount /= sgLayout[i];
897 order[i] = consumerOrder[consumerIdx];
904 int64_t remainOrder = consumerSgLayout.size();
905 for (
int i = 0; i < srcRank; i++) {
906 if (llvm::is_contained(reductionDims, i)) {
908 std::min(srcShape[i],
static_cast<int64_t>(remainingSgCount));
909 assert((srcShape[i] % sgLayout[i] == 0) &&
910 "source shape not divisible by sg_layout");
911 sgData[i] = srcShape[i] / sgLayout[i];
912 remainingSgCount /= sgLayout[i];
913 order[i] = remainOrder++;
917 assert(remainingSgCount == 1 &&
"not all subgroups distributed");
918 srcLayout = xegpu::LayoutAttr::get(
919 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
922 (!orderAttr || orderAttr.empty()) ?
nullptr : toInt32Attr(order));
928 instData[srcRank - 2] =
929 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
930 instData[srcRank - 1] =
931 std::min(
static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
932 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
936 laneLayout[srcRank - 1] =
937 std::min(
static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
939 laneData[srcRank - 2] =
940 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
941 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
942 toInt32Attr(laneData));
945 return xegpu::SliceAttr::get(context, srcLayout,
956 auto srcShape = srcVecTy.getShape();
957 auto context = srcVecTy.getContext();
958 auto subgroupSize =
uArch->getSubgroupSize();
959 xegpu::LayoutAttr srcLayout;
962 assert(
true &&
"subgroup layout assignment not supported for reduction (op "
963 "is not expected at this level).");
965 assert(
true &&
"instData layout assignment not supported for reduction (op "
966 "is not expected at this level).");
969 laneLayout[0] = std::min(subgroupSize,
static_cast<int32_t
>(srcShape[0]));
971 srcLayout = xegpu::LayoutAttr::get(
976 auto result = xegpu::SliceAttr::get(context, srcLayout,
1007 int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1008 int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1016 consumerLayout.getEffectiveLaneLayoutAsInt();
1018 assert(consumerLayout.getRank() ==
static_cast<int64_t>(srcShape.size()) &&
1019 "laneData must be available for all dimensions");
1020 size_t innerMostDim = srcShape.size() - 1;
1024 if (srcElemTyBitWidth > resElemTyBitWidth) {
1028 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
1030 sgDataValue = sgData[innerMostDim];
1031 while ((sgDataValue <= resShape[innerMostDim]) &&
1032 (sgDataValue % bitWidthRatio) != 0)
1035 instDataValue = instData[innerMostDim];
1036 const int innermostDimLaneLayout = laneLayout.empty()
1037 ?
uArch->getSubgroupSize()
1038 : laneLayout[innerMostDim];
1041 while ((instDataValue <= resShape[innerMostDim]) &&
1042 (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
1044 assert((resShape[innerMostDim] % instDataValue) == 0 &&
1045 "resShape, instData, and lanelayout for innermost must be 2^n !");
1047 laneDataValue = laneData[innerMostDim];
1048 while ((laneDataValue <= resShape[innerMostDim]) &&
1049 (laneDataValue % bitWidthRatio != 0))
1053 xegpu::DistributeLayoutAttr resLayout;
1054 resLayout = consumerLayout.setDimData(innerMostDim, sgDataValue,
1055 instDataValue, laneDataValue);
1058 return consumerLayout;
1085 consumerLayout.getEffectiveLaneLayoutAsInt();
1087 assert(consumerLayout.getRank() ==
static_cast<int64_t>(srcShape.size()) &&
1088 "consumer layout rank must match source shape rank");
1089 const size_t innerMostDim = srcShape.size() - 1;
1095 constexpr int ratio = 2;
1098 sgDataValue = sgData[innerMostDim];
1100 while ((sgDataValue <= srcShape[innerMostDim]) &&
1101 (sgDataValue % ratio != 0))
1102 sgDataValue *= ratio;
1104 instDataValue = instData[innerMostDim];
1105 const int innermostDimLaneLayout = laneLayout.empty()
1106 ?
uArch->getSubgroupSize()
1107 : laneLayout[innerMostDim];
1110 while ((instDataValue <= srcShape[innerMostDim]) &&
1111 (instDataValue % (innermostDimLaneLayout * ratio) != 0))
1112 instDataValue *= ratio;
1113 assert((srcShape[innerMostDim] % instDataValue) == 0 &&
1114 "srcShape, instData, and laneLayout for innermost must be 2^n!");
1116 laneDataValue = laneData[innerMostDim];
1119 while ((laneDataValue <= srcShape[innerMostDim]) &&
1120 (laneDataValue % ratio != 0))
1121 laneDataValue *= ratio;
1124 return consumerLayout.setDimData(innerMostDim, sgDataValue, instDataValue,
1133 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
1136 xegpu::DistributeLayoutAttr requiredResLayout;
1138 consumerLayout.getEffectiveInstDataAsInt();
1140 consumerLayout.getEffectiveLaneDataAsInt();
1142 consumerLayout.getEffectiveLaneLayoutAsInt();
1147 requiredResLayout = consumerLayout;
1148 int srcRank = srcShape.size();
1152 "subgroup layout assignment not supported for insertStridedSlice.");
1154 for (
int dim = 0; dim < srcRank; dim++) {
1155 instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
1157 requiredResLayout.setDimData(dim, -1, instDataValue, -1);
1160 for (
int dim = 0; dim < srcRank; dim++) {
1161 assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
1162 "srcShape must be divisible by laneLayout for all dimensions");
1163 laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
1164 consumerLaneData[dim]);
1166 requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
1169 return requiredResLayout;
1186 xegpu::DistributeLayoutAttr consumerLayout,
bool isChunkedLoad,
1190 return consumerLayout;
1193 consumerLayout.getEffectiveInstDataAsInt();
1195 consumerLayout.getEffectiveLaneDataAsInt();
1201 if (!isChunkedLoad) {
1203 instData.back() = std::min(
static_cast<int>(consumerInstData.back()),
1204 maxChunkSize * subgroupSize);
1205 return xegpu::LayoutAttr::get(context, instData);
1208 std::min(
static_cast<int>(consumerLaneData.back()), maxChunkSize);
1209 laneLayout.back() = std::min(
static_cast<int64_t>(subgroupSize),
1210 resShape.back() / laneData.back());
1211 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1214 assert(resShape.size() == 2 &&
"Chunked Store must access 2D tensor tile.");
1216 instData[0] = subgroupSize;
1218 std::min(
static_cast<int>(consumerInstData[1]), maxChunkSize);
1219 return xegpu::LayoutAttr::get(context, instData);
1221 laneLayout[0] = subgroupSize;
1223 std::min(
static_cast<int>(consumerLaneData[1]), maxChunkSize);
1224 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1237 auto context = resVecTy.getContext();
1238 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1240 const auto *uArchInstruction =
1241 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1243 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1246 (chunkSize > 1), maxChunkSize, resShape,
1252xegpu::DistributeLayoutAttr
1254 VectorType resVecTy,
1255 xegpu::DistributeLayoutAttr consumerLayout,
1260 auto context = resVecTy.getContext();
1261 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1263 const auto *uArchInstruction =
1264 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1266 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1268 false, maxChunkSize, resShape,
1283static xegpu::DistributeLayoutAttr
1289 int srcShapeSize = srcShape.size();
1296 "subgroup layout assignment not supported for storeScatter.");
1300 if (!isChunkedStore) {
1302 instData[srcShapeSize - 1] =
1303 std::min(subgroupSize,
static_cast<int>(srcShape.back()));
1304 return xegpu::LayoutAttr::get(context, instData);
1306 laneLayout[srcShapeSize - 1] =
1307 std::min(subgroupSize,
static_cast<int>(srcShape.back()));
1308 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1311 assert(srcShapeSize == 2 &&
"Chunked Store must access 2D tensor tile.");
1313 instData[0] = subgroupSize;
1314 instData[1] = std::min(
static_cast<int>(srcShape[1]), maxChunkSize);
1315 return xegpu::LayoutAttr::get(context, instData);
1317 laneLayout[0] = subgroupSize;
1318 laneData[1] = std::min(
static_cast<int>(srcShape[1]), maxChunkSize);
1319 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1326xegpu::DistributeLayoutAttr
1328 VectorType srcVecTy,
int chunkSize,
1331 const int subgroupSize =
uArch->getSubgroupSize();
1333 auto context = srcVecTy.getContext();
1334 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1336 const auto *uArchInstruction =
1337 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1339 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1341 maxChunkSize, srcShape, subgroupSize);
1345xegpu::DistributeLayoutAttr
1347 VectorType srcVecTy,
1350 const int subgroupSize =
uArch->getSubgroupSize();
1352 auto context = srcVecTy.getContext();
1353 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1355 const auto *uArchInstruction =
1356 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1358 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1361 srcShape, subgroupSize);
1369template <
typename RankedTy>
1372 std::optional<unsigned> packingSize = std::nullopt,
bool vnni =
false) {
1374 assert(((ty.getRank() >= 1 && !vnni) || ty.getRank() >= 2) &&
1375 "Expected at least 1D non-vnni or 2D vector.");
1377 assert(ty.getElementType().isIntOrFloat() &&
1378 "Expected int or float element type.");
1380 auto context = ty.getContext();
1381 auto rank = ty.getRank();
1384 if (packingSize.has_value()) {
1385 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1386 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
1387 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
1390 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1405 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
1406 if (sgCount % sgLayout0)
1408 int64_t sgLayout1 = sgCount / sgLayout0;
1409 int64_t sgData0 = wgShape[0] / sgLayout0;
1410 int64_t sgData1 = wgShape[1] / sgLayout1;
1411 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
1412 (sgData0 % instData[0] || sgData1 % instData[1]))
1414 candidates.emplace_back(sgLayout0, sgLayout1);
1421 int diffLhs = std::abs(
lhs.first -
lhs.second);
1422 int diffRhs = std::abs(
rhs.first -
rhs.second);
1423 if (diffLhs != diffRhs)
1424 return diffLhs < diffRhs;
1425 return lhs.first <
rhs.first;
1436 bool isDpasMx =
false) {
1441 uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1450 const unsigned dataALen = aTy.getShape()[aTy.getRank() - 2];
1451 auto supportedALen = uArchInstruction->
getSupportedM(aTy.getElementType());
1456 const unsigned dataBLen = bTy.getShape().back();
1457 auto supportedBLen = uArchInstruction->
getSupportedN(bTy.getElementType());
1461 auto supportedCLen = uArchInstruction->
getSupportedN(cdTy.getElementType());
1464 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
1465 return std::nullopt;
1469 int kDimSize = subgroupSize;
1471 auto supportedKLen = uArchInstruction->
getSupportedK(aTy.getElementType());
1472 if (supportedKLen.empty())
1473 return std::nullopt;
1474 kDimSize = supportedKLen[0];
1478 instDataA[aTy.getRank() - 2] = maxALen;
1479 instDataA[aTy.getRank() - 1] = kDimSize;
1481 instDataB[bTy.getRank() - 2] = kDimSize;
1482 instDataB[bTy.getRank() - 1] = maxBLen;
1484 instDataCD[cdTy.getRank() - 2] = maxALen;
1485 instDataCD[cdTy.getRank() - 1] = maxCLen;
1486 return std::make_tuple(instDataA, instDataB, instDataCD);
1491static std::optional<
1492 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1493 xegpu::DistributeLayoutAttr>>
1495 VectorType bTy, VectorType cdTy,
1496 xegpu::DistributeLayoutAttr consumerLayout,
int numSg,
1500 return std::nullopt;
1501 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1502 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
1503 instDataCD.size() == 2 &&
1504 "Sg layout creation expects valid 2D inst data");
1506 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
1507 if (consumerLayout && consumerLayout.isForWorkgroup()) {
1509 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
1516 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1517 return std::nullopt;
1523 std::optional<LayoutRepresentation> bestPick;
1525 return aTy.getShape().back() / sgLayout.second ==
1526 bTy.getShape().front() / sgLayout.first;
1528 for (
auto &sgLayout : layoutsB) {
1529 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1530 if (!checkAlignedSgDataAB(sgLayout))
1533 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1534 bestPick = sgLayout;
1542 bestPick = sgLayout;
1546 return std::nullopt;
1549 static_cast<int>(bestPick->second)};
1550 SmallVector<int> sgDataA = {
static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1551 static_cast<int>(aTy.getShape()[1])};
1553 static_cast<int>(bTy.getShape()[0]),
1554 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1556 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1557 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1562 nullptr,
nullptr,
nullptr);
1566 nullptr,
nullptr,
nullptr);
1570 nullptr,
nullptr,
nullptr);
1572 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1579 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1580 xegpu::DistributeLayoutAttr>>
1582 VectorType bTy, VectorType cdTy,
1583 xegpu::DistributeLayoutAttr consumerLayout,
int numSg,
1585 auto context = aTy.getContext();
1586 const auto *uArchInstruction =
1592 "Number of subgroups must be provided for sg layout creation.");
1598 return std::nullopt;
1599 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1600 return std::make_tuple(
1601 xegpu::LayoutAttr::get(
1603 xegpu::LayoutAttr::get(
1605 xegpu::LayoutAttr::get(
1609 aTy,
uArch, uArchInstruction->getPackedFormatBitSizeA());
1611 bTy,
uArch, uArchInstruction->getPackedFormatBitSizeB(),
true);
1614 return std::make_tuple(aLayout, bLayout, cdLayout);
1616 return std::nullopt;
1623static xegpu::DistributeLayoutAttr
1625 VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
1627 if (!scaleTy || !matrixLayout)
1635 if (scaleShape.empty())
1638 auto uArchInstruction =
1639 dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1643 int64_t rank = matrixLayout.getRank();
1644 assert(rank >= 2 &&
"dpas layouts must be at least two dimensions");
1651 auto order = matrixLayout.getOrder();
1655 if (!sgLayout.empty() && !sgData.empty()) {
1656 scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
1657 scaleSgData.assign(sgData.begin(), sgData.end());
1658 scaleSgData[rank - 2] = std::max<int64_t>(
1659 scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
1660 scaleSgData[rank - 1] = std::max<int64_t>(
1661 scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
1668 if (!instData.empty()) {
1669 scaleInstData.assign(instData.begin(), instData.end());
1671 scaleInstData[rank - 2] = std::max<int64_t>(
1672 scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
1675 scaleInstData[rank - 1] = std::max<int64_t>(
1676 scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
1682 if (!laneLayout.empty() && !laneData.empty()) {
1683 scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
1684 scaleLaneData.assign(laneData.begin(), laneData.end());
1685 bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
1686 if (isBScale ^ isRowMajor) {
1687 std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
1688 scaleLaneLayout[rank - 2] =
1689 std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
1691 scaleLaneData[rank - 2] =
1692 std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
1693 scaleLaneData[rank - 1] =
1694 std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
1696 return xegpu::LayoutAttr::get(
1698 scaleSgLayout.empty() ?
nullptr
1700 scaleSgData.empty() ?
nullptr
1702 scaleInstData.empty() ?
nullptr
1704 scaleLaneLayout.empty()
1707 scaleLaneData.empty() ?
nullptr
1716 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1717 xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1718 xegpu::DistributeLayoutAttr>>
1720 VectorType bTy, VectorType cdTy, VectorType aScaleTy,
1721 VectorType bScaleTy,
1722 xegpu::DistributeLayoutAttr consumerLayout,
int numSg,
1724 auto context = aTy.getContext();
1728 "Number of subgroups must be provided for sg layout creation.");
1730 consumerLayout, numSg,
uArch);
1732 return std::nullopt;
1734 auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
1743 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1749 return std::nullopt;
1750 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1752 auto dpasALayout = xegpu::LayoutAttr::get(
1754 auto dpasBLayout = xegpu::LayoutAttr::get(
1756 auto dpasCDLayout = xegpu::LayoutAttr::get(
1765 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1768 const auto *uArchInstruction =
1772 aTy,
uArch, uArchInstruction->getPackedFormatBitSizeA());
1774 bTy,
uArch, uArchInstruction->getPackedFormatBitSizeB(),
true);
1783 return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
1786 return std::nullopt;
1790 OpOperand &operand, xegpu::DistributeLayoutAttr resLayout) {
1797 if (
auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1798 auto srcTy = dyn_cast<VectorType>(
broadcast.getSourceType());
1802 resLayout,
broadcast.getResultVectorType().getShape(),
1809 if (
auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1818 if (
auto reduction = dyn_cast<vector::ReductionOp>(op))
1823 if (
auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1824 int resElemBitWidth =
1825 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1826 int srcElemBitWidth =
1827 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1834 if (
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1836 resLayout, shapeCast.getResultVectorType().getShape(),
1837 shapeCast.getSourceVectorType().getShape());
1842 if (
auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1845 resLayout, insertSlice.getDestVectorType().getShape(),
1846 insertSlice.getSourceVectorType().getShape());
1854 if (
auto insert = dyn_cast<vector::InsertOp>(op)) {
1855 VectorType resVecTy = dyn_cast<VectorType>(insert.getResult().getType());
1856 VectorType valueToStoreTy =
1857 dyn_cast<VectorType>(insert.getValueToStore().getType());
1859 if ((idx == 0) && valueToStoreTy) {
1861 valueToStoreTy.getShape());
1869 if (
auto extract = dyn_cast<vector::ExtractOp>(op)) {
1870 VectorType srcVecTy = dyn_cast<VectorType>(extract.getSource().getType());
1871 VectorType resVecTy = dyn_cast<VectorType>(extract.getResult().getType());
1872 if (!srcVecTy || !resVecTy)
1875 srcVecTy.getShape());
1880 if (
auto transpose = dyn_cast<vector::TransposeOp>(op)) {
1882 transpose.getPermutation());
1887 if (
auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1888 int resElemBitWidth =
1889 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1890 int srcElemBitWidth =
1891 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1897 if (
auto interleave = dyn_cast<vector::InterleaveOp>(op)) {
1902 if (
auto deinterleave = dyn_cast<vector::DeinterleaveOp>(op)) {
1907 if (dyn_cast<vector::ExtractStridedSliceOp>(op))
1923 if (isa<xegpu::AnchorLayoutInterface>(op))
1927 xegpu::DistributeLayoutAttr resLayout;
1928 if (op->
getNumResults() == 1 || isa<vector::DeinterleaveOp>(op))
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy, VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout, bool isBScale, const xegpu::uArch::uArch *uArch)
Helper to create a scale layout derived from a matrix operand layout.
static std::optional< std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > > getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy, const xegpu::uArch::uArch *uArch, bool isDpasMx=false)
Helper function to compute inst_data vectors for DPAS operands A, B, and C/D.
static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result)
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operation.
static std::optional< std::tuple< xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr > > getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy, VectorType bTy, VectorType cdTy, xegpu::DistributeLayoutAttr consumerLayout, int numSg, const xegpu::uArch::uArch *uArch)
Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
static void propagateResultsToRegularOperands(Operation *op)
static void propagateRegionResultsToYieldOperands(mlir::RegionBranchTerminatorOpInterface yieldOp)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static void walkRegionBackward(Region ®ion, llvm::function_ref< void(Operation *)> visit)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operation.
Block represents an ordered list of Operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getType() const
Operation is the basic unit of execution within MLIR.
bool hasAttrOfType(NameT &&name)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
@ SubgroupMatrixMultiplyAcc
@ SubgroupScaledMatrixMultiplyAcc
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
bool matchDimCollapse(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &collapseDims)
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferSourceLayoutFromResultForNonAnchorOp(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attaches layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
llvm::function_ref< DistributeLayoutAttr(Value)> GetLayoutFnTy
Callable returning the propagated layout for a given Value, used by the layout-propagation helpers be...
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for the mask and offset operands of chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateRegionArgsToInits(RegionBranchOpInterface regionOp, GetLayoutFnTy getLayoutOfValue)
Propagate layouts from a region branch op's region entry block arguments back to its init operands.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for the dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
const Instruction * getInstruction(InstructionKind instKind) const