27#include "llvm/Support/FormatVariadic.h"
50 out.reserve(attrs.size());
52 for (
auto attr : attrs) {
53 if (
auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
54 auto newLayout = dist.dropSgLayoutAndData();
56 out.emplace_back(attr.getName(), newLayout);
68 out.reserve(attrs.size());
70 for (
auto attr : attrs) {
71 if (
auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
72 auto newLayout = dist.dropInstData();
74 out.emplace_back(attr.getName(), newLayout);
90 if (!isa<VectorType>(operand.get().getType()))
94 if (isa<BlockArgument>(operand.get()))
98 op->
emitWarning(
"Could not find layout attribute for operand ")
99 << operand.getOperandNumber() <<
" of operation " << op->
getName();
106 return !
result.wasInterrupted();
109template <
typename T,
typename>
111 Operation *owner = operandOrResult.getOwner();
129 for (
auto namedAttr : nestOp->
getAttrs()) {
130 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
131 attrsToRemove.push_back(namedAttr.getName());
133 for (
auto attrName : attrsToRemove)
140xegpu::DistributeLayoutAttr
146 auto returnLayout = resLayout;
149 int dimDiff = resShape.size() - srcShape.size();
153 for (
int i = 0; i < dimDiff; i++)
154 bcastDims.push_back(i);
157 returnLayout = xegpu::SliceAttr::get(
158 resLayout.getContext(), resLayout,
166xegpu::DistributeLayoutAttr
170 assert(isa<xegpu::SliceAttr>(resLayout) &&
171 "reduction result layout must be slice layout");
173 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
175 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
176 "reduction dims must match with slice dims");
178 return sliceLayout.getParent();
184xegpu::DistributeLayoutAttr
186 int resElemTyBitWidth,
int srcElemTyBitWidth) {
191 size_t sgDataSize = sgData.size();
192 size_t instDataSize = instData.size();
193 size_t laneDataSize = laneData.size();
197 int64_t dim = resLayout.getRank() - 1;
199 if (srcElemTyBitWidth <= resElemTyBitWidth) {
200 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
202 sgDataValue = sgData.back() * bitWidthRatio;
204 instDataValue = instData.back() * bitWidthRatio;
206 laneDataValue = laneData.back() * bitWidthRatio;
208 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
210 assert((sgData.back() % bitWidthRatio) == 0 &&
211 "sgData not divisible by bitWidthRatio");
212 sgDataValue = sgData.back() / bitWidthRatio;
215 assert((instData.back() % bitWidthRatio) == 0 &&
216 "instData not divisible by bitWidthRatio");
217 instDataValue = instData.back() / bitWidthRatio;
220 assert((laneData.back() % bitWidthRatio) == 0 &&
221 "laneData not divisible by bitWidthRatio");
222 laneDataValue = laneData.back() / bitWidthRatio;
226 xegpu::DistributeLayoutAttr finalSrcLayout;
228 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
230 return finalSrcLayout;
240 int srcShapeSize = srcShape.size();
241 int resShapeSize = resShape.size();
242 int dimDiff = resShapeSize - srcShapeSize;
244 assert(isa<xegpu::LayoutAttr>(resLayout) &&
245 "insertStridedSlice result layout must be plain layout");
246 auto context = resLayout.getContext();
247 auto resInstData = resLayout.getEffectiveInstDataAsInt();
248 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
249 auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
251 if (resInstData.size() != 0) {
253 for (
int i = 0; i < srcShapeSize; i++)
254 inferredInstData[i] = resInstData[i + dimDiff];
255 return xegpu::LayoutAttr::get(context, inferredInstData);
258 if (resLaneLayout.size() != 0) {
261 for (
int i = 0; i < srcShapeSize; i++) {
262 inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
263 inferredLaneData[i] = resLaneData[i + dimDiff];
265 return xegpu::LayoutAttr::get(context, inferredLaneLayout,
273xegpu::DistributeLayoutAttr
297 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
304 auto srcLayout = resLayout;
305 for (
const auto &dimGroup : splitDimGroups)
306 srcLayout = srcLayout.collapseDims(dimGroup);
315 if ((dst.size() != 2) && (dst.size() != 1))
317 int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
318 std::multiplies<int64_t>());
320 return (dst[0] == srcSize);
321 return (dst[0] == 1) && (dst[1] == srcSize);
324 if (matchCollapseToInnermostDim(srcShape, resShape)) {
325 int srcShapeSize = srcShape.size();
326 int resShapeSize = resShape.size();
327 auto context = resLayout.getContext();
328 auto resInstData = resLayout.getEffectiveInstDataAsInt();
329 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
330 auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
346 if (resInstData.size() != 0) {
348 for (
int i = 0; i < resShapeSize - 1; i++) {
349 assert(resInstData[i] == 1 &&
350 "only innermost dim can have non-unit instData");
353 inferredInstData[srcShapeSize - 1] =
354 std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
355 return xegpu::LayoutAttr::get(context, inferredInstData);
358 if (resLaneLayout.size() != 0) {
359 for (
int i = 0; i < resShapeSize - 1; i++) {
360 assert(resLaneData[i] == 1 &&
361 "only innermost dim can have non-unit instData");
363 assert(srcShape.back() % resLaneLayout.back() == 0 &&
364 "source innermost dim must be >= result lane layout");
367 inferredLaneLayout.back() = resLaneLayout.back();
368 inferredLaneData.back() = std::min(
369 resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
370 return xegpu::LayoutAttr::get(context, inferredLaneLayout,
374 llvm_unreachable(
"running into unsupported shape cast scenarios");
425 auto srcShape = srcVecTy.getShape();
426 int srcRank = srcShape.size();
427 auto context = consumerLayout.getContext();
440 xegpu::SliceAttr consumerSliceLayout =
441 dyn_cast<xegpu::SliceAttr>(consumerLayout);
442 DistributeLayoutAttr plainLayout =
443 consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
446 const int subgroupSize =
uArch->getSubgroupSize();
447 int64_t maxReduceVectorSize = 1;
449 xegpu::DistributeLayoutAttr srcLayout;
452 auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
453 const int workgroupSize = std::accumulate(
454 sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
457 consumerLayout.getEffectiveSgLayoutAsInt();
458 int remainingSgCount = workgroupSize;
459 int consumerIdx = consumerSgLayout.size() - 1;
462 for (
int i = srcRank - 1; i >= 0; i--) {
463 if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
464 sgLayout[i] = consumerSgLayout[consumerIdx];
465 assert((srcShape[i] % sgLayout[i] == 0) &&
466 "source shape not divisible by consumer sg_layout");
467 sgData[i] = srcShape[i] / sgLayout[i];
468 remainingSgCount /= sgLayout[i];
474 for (
int i = srcRank - 1; i >= 0; i--) {
475 if (llvm::is_contained(reductionDims, i)) {
477 std::min(srcShape[i],
static_cast<int64_t>(remainingSgCount));
478 assert((srcShape[i] % sgLayout[i] == 0) &&
479 "source shape not divisible by sg_layout");
480 sgData[i] = srcShape[i] / sgLayout[i];
481 remainingSgCount /= sgLayout[i];
485 assert(remainingSgCount == 1 &&
"not all subgroups distributed");
486 srcLayout = xegpu::LayoutAttr::get(
487 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
494 instData[srcRank - 2] =
495 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
496 instData[srcRank - 1] = subgroupSize;
497 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
502 laneLayout[srcRank - 1] = subgroupSize;
503 laneData[srcRank - 2] =
504 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
505 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
506 toInt32Attr(laneData),
507 consumerLayout.getOrder());
510 return xegpu::SliceAttr::get(context, srcLayout,
540 int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
541 int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
547 size_t dim = srcShape.size() - 1;
552 const int subgroupSize =
uArch->getSubgroupSize();
554 if (srcElemTyBitWidth > resElemTyBitWidth) {
558 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
559 int innermostDimLaneLayout = subgroupSize;
561 assert(sgData.size() == srcShape.size() &&
562 "sgData must be available for all dimensions");
563 sgDataValue = sgData[dim];
565 assert(instData.size() == srcShape.size() &&
566 "instData must be available for all dimensions");
567 instDataValue = instData[dim];
570 while ((instDataValue <= srcShape[dim]) &&
571 (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
573 assert((srcShape[dim] % instDataValue) == 0 &&
574 "srcShape, instData, and lanelayout for innermost must be 2^n !");
576 assert(laneData.size() == srcShape.size() &&
577 "laneData must be available for all dimensions");
578 laneDataValue = laneData[dim];
579 while ((laneDataValue <= srcShape[dim]) &&
580 (laneDataValue % bitWidthRatio != 0))
584 xegpu::DistributeLayoutAttr resLayout;
585 resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
589 return consumerLayout;
624 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
627 xegpu::DistributeLayoutAttr requiredResLayout;
629 auto context = resVectorTy.getContext();
630 auto resShape = resVectorTy.getShape();
631 int resShapeSize = resShape.size();
632 auto srcShape = srcVectorTy.getShape();
634 consumerLayout.getEffectiveInstDataAsInt();
636 consumerLayout.getEffectiveLaneDataAsInt();
643 unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
644 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
645 int packedDataSize = subgroupSize * packingFactor;
649 "subgroup layout assignment not supported for insertStridedSlice.");
651 assert(srcShape.back() >= subgroupSize &&
652 "source innermost dim must be >= subgroupSize");
653 instData.back() = subgroupSize;
654 if (consumerInstData.back() == packedDataSize &&
655 srcShape.back() >= packedDataSize)
656 instData.back() = packedDataSize;
657 requiredResLayout = xegpu::LayoutAttr::get(context, instData);
659 laneLayout.back() = subgroupSize;
661 if (consumerLaneData.back() == packingFactor &&
662 srcShape.back() >= packedDataSize)
663 laneData.back() = packingFactor;
664 requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
666 return requiredResLayout;
683 xegpu::DistributeLayoutAttr consumerLayout,
bool isChunkedLoad,
687 return consumerLayout;
690 consumerLayout.getEffectiveInstDataAsInt();
692 consumerLayout.getEffectiveLaneDataAsInt();
698 if (!isChunkedLoad) {
700 instData.back() = std::min(
static_cast<int>(consumerInstData.back()),
701 maxChunkSize * subgroupSize);
702 return xegpu::LayoutAttr::get(context, instData);
705 std::min(
static_cast<int>(consumerLaneData.back()), maxChunkSize);
706 laneLayout.back() = std::min(
static_cast<int64_t>(subgroupSize),
707 resShape.back() / laneData.back());
708 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
711 assert(resShape.size() == 2 &&
"Chunked Store must access 2D tensor tile.");
713 instData[0] = subgroupSize;
715 std::min(
static_cast<int>(consumerInstData[1]), maxChunkSize);
716 return xegpu::LayoutAttr::get(context, instData);
718 laneLayout[0] = subgroupSize;
720 std::min(
static_cast<int>(consumerLaneData[1]), maxChunkSize);
721 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
734 auto context = resVecTy.getContext();
735 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
737 const auto *uArchInstruction =
738 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
740 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
743 (chunkSize > 1), maxChunkSize, resShape,
749xegpu::DistributeLayoutAttr
752 xegpu::DistributeLayoutAttr consumerLayout,
757 auto context = resVecTy.getContext();
758 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
760 const auto *uArchInstruction =
761 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
763 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
765 false, maxChunkSize, resShape,
780static xegpu::DistributeLayoutAttr
786 int srcShapeSize = srcShape.size();
793 "subgroup layout assignment not supported for storeScatter.");
797 if (!isChunkedStore) {
799 instData[srcShapeSize - 1] =
800 std::min(subgroupSize,
static_cast<int>(srcShape.back()));
801 return xegpu::LayoutAttr::get(context, instData);
803 laneLayout[srcShapeSize - 1] =
804 std::min(subgroupSize,
static_cast<int>(srcShape.back()));
805 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
808 assert(srcShapeSize == 2 &&
"Chunked Store must access 2D tensor tile.");
810 instData[0] = subgroupSize;
811 instData[1] = std::min(
static_cast<int>(srcShape[1]), maxChunkSize);
812 return xegpu::LayoutAttr::get(context, instData);
814 laneLayout[0] = subgroupSize;
815 laneData[1] = std::min(
static_cast<int>(srcShape[1]), maxChunkSize);
816 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
823xegpu::DistributeLayoutAttr
825 VectorType srcVecTy,
int chunkSize,
828 const int subgroupSize =
uArch->getSubgroupSize();
830 auto context = srcVecTy.getContext();
831 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
833 const auto *uArchInstruction =
834 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
836 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
838 maxChunkSize, srcShape, subgroupSize);
842xegpu::DistributeLayoutAttr
847 const int subgroupSize =
uArch->getSubgroupSize();
849 auto context = srcVecTy.getContext();
850 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
852 const auto *uArchInstruction =
853 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
855 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
858 srcShape, subgroupSize);
866template <
typename RankedTy>
869 std::optional<unsigned> packingSize = std::nullopt,
bool vnni =
false) {
871 assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
872 "Expected 1D non-vnni or 2D vector.");
874 assert(ty.getElementType().isIntOrFloat() &&
875 "Expected int or float element type.");
877 auto context = ty.getContext();
878 auto rank = ty.getRank();
881 if (packingSize.has_value()) {
882 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
883 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
884 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
887 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
902 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
903 if (sgCount % sgLayout0)
905 int64_t sgLayout1 = sgCount / sgLayout0;
906 int64_t sgData0 = wgShape[0] / sgLayout0;
907 int64_t sgData1 = wgShape[1] / sgLayout1;
908 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
909 (sgData0 % instData[0] || sgData1 % instData[1]))
911 candidates.emplace_back(sgLayout0, sgLayout1);
918 int diffLhs = std::abs(
lhs.first -
lhs.second);
919 int diffRhs = std::abs(
rhs.first -
rhs.second);
920 if (diffLhs != diffRhs)
921 return diffLhs < diffRhs;
922 return lhs.first <
rhs.first;
930 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
931 xegpu::DistributeLayoutAttr>>
933 VectorType bTy, VectorType cdTy,
934 xegpu::DistributeLayoutAttr consumerLayout,
936 auto context = aTy.getContext();
937 const auto *uArchInstruction =
941 auto getInstDataVectors = [&]()
945 const unsigned dataALen = aTy.getShape().front();
946 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
950 const unsigned dataBLen = bTy.getShape().back();
951 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
955 auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
958 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
962 instDataA[aTy.getRank() - 2] = maxALen;
963 instDataA[aTy.getRank() - 1] = subgroupSize;
965 instDataB[bTy.getRank() - 2] = subgroupSize;
966 instDataB[bTy.getRank() - 1] = maxBLen;
968 instDataCD[cdTy.getRank() - 2] = maxALen;
969 instDataCD[cdTy.getRank() - 1] = maxCLen;
970 return std::make_tuple(instDataA, instDataB, instDataCD);
975 "Number of subgroups must be provided for sg layout creation.");
976 auto instDataVecs = getInstDataVectors();
979 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
980 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
981 instDataCD.size() == 2 &&
982 "Sg layout creation expects valid 2D inst data");
984 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
985 if (consumerLayout && consumerLayout.isForWorkgroup()) {
987 consumerLayout.getEffectiveSgLayoutAsInt();
988 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
996 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1005 std::optional<LayoutRepresentation> bestPick;
1006 for (
auto &sgLayout : layoutsB) {
1007 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1009 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1010 bestPick = sgLayout;
1018 bestPick = sgLayout;
1024 return std::nullopt;
1026 static_cast<int>(bestPick->second)};
1028 static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1029 static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
1031 static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
1032 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1034 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1035 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1037 auto dpasALayout = xegpu::LayoutAttr::get(
1043 auto dpasBLayout = xegpu::LayoutAttr::get(
1049 auto dpasCDLayout = xegpu::LayoutAttr::get(
1054 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1056 auto instDataVecs = getInstDataVectors();
1058 return std::nullopt;
1059 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1060 return std::make_tuple(
1061 xegpu::LayoutAttr::get(
1063 xegpu::LayoutAttr::get(
1065 xegpu::LayoutAttr::get(
1069 aTy,
uArch, uArchInstruction->getPackedFormatBitSizeA());
1071 bTy,
uArch, uArchInstruction->getPackedFormatBitSizeB(),
true);
1073 cdTy,
uArch, uArchInstruction->getPackedFormatBitSizeB());
1074 return std::make_tuple(aLayout, bLayout, cdLayout);
1076 return std::nullopt;
1082 xegpu::DistributeLayoutAttr resLayout;
1087 if (
auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1089 return xegpu::DistributeLayoutAttr();
1090 auto srcTy = dyn_cast<VectorType>(
broadcast.getSourceType());
1092 return xegpu::DistributeLayoutAttr();
1094 resLayout,
broadcast.getResultVectorType().getShape(),
1101 if (
auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1103 return xegpu::DistributeLayoutAttr();
1114 if (
auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1116 return xegpu::DistributeLayoutAttr();
1117 int resElemBitWidth =
1118 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1119 int srcElemBitWidth =
1120 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1127 if (
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1129 return xegpu::DistributeLayoutAttr();
1131 resLayout, shapeCast.getResultVectorType().getShape(),
1132 shapeCast.getSourceVectorType().getShape());
1137 if (
auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1139 return xegpu::DistributeLayoutAttr();
1142 resLayout, insertSlice.getDestVectorType().getShape(),
1143 insertSlice.getSourceVectorType().getShape());
1151 return xegpu::DistributeLayoutAttr();
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to a vector with numElements elements.
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operations.
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operations.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool hasAttrOfType(NameT &&name)
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
unsigned getNumResults()
Return the number of results held by this operation.
Type getType() const
Return the type of this value.
static WalkResult advance()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
@ SubgroupMatrixMultiplyAcc
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
void recoverTemporaryLayoutsDeprecated(Operation *op)
[to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and OpResult of the given opera...
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
virtual unsigned getGeneralPackedFormatBitSize() const =0
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
const Instruction * getInstruction(InstructionKind instKind) const