#include "llvm/Support/FormatVariadic.h"
/// Drops sg-layout and sg-data information from any DistributeLayoutAttr
/// found in the given NamedAttribute sequence.
SmallVector<NamedAttribute>
xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> out;
  out.reserve(attrs.size());
  for (auto attr : attrs) {
    if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
      auto newLayout = dist.dropSgLayoutAndData();
      out.emplace_back(attr.getName(), newLayout);
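      // Illustrative example (hypothetical attribute): after
      // dropSgLayoutAndData(),
      //   #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 16], inst_data = [8, 16]>
      // becomes #xegpu.layout<inst_data = [8, 16]>, keeping only the
      // instruction/lane-level fields for the later distribution stages.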
/// Drops inst-data information from any DistributeLayoutAttr found in the
/// given NamedAttribute sequence.
SmallVector<NamedAttribute>
xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> out;
  out.reserve(attrs.size());
  for (auto attr : attrs) {
    if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
      auto newLayout = dist.dropInstData();
      out.emplace_back(attr.getName(), newLayout);
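      // Counterpart example: dropInstData() removes only the inst_data field,
      // so the attribute above would instead become
      //   #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 16]>.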
/// Attaches layout attributes to all vector-type operands of the operations
/// nested under the given root operation.
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {

      if (!isa<VectorType>(operand.get().getType()))

      if (isa<BlockArgument>(operand.get()))

      op->emitWarning("Could not find layout attribute for operand ")
          << operand.getOperandNumber() << " of operation " << op->getName();
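  // The walk is interrupted at the first vector operand that still lacks a
  // DistributeLayoutAttr (after the warning above), so the function returns
  // true only when every vector operand, block arguments excluded, carries a
  // layout.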
  return !result.wasInterrupted();
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
  Operation *owner = operandOrResult.getOwner();
/// Removes the DistributeLayoutAttr from each operation nested under `op`.
void xegpu::removeLayoutAttrs(Operation *op) {
  op->walk([&](Operation *nestOp) {
    // Collect the names first and erase afterwards, so the attribute list is
    // not mutated while it is being iterated.
    SmallVector<StringAttr> attrsToRemove;
    for (auto namedAttr : nestOp->getAttrs()) {
      if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
        attrsToRemove.push_back(namedAttr.getName());
    }
    for (auto attrName : attrsToRemove)
      nestOp->removeAttr(attrName);
/// Infers the source layout attribute for a broadcast operation given the
/// result layout attribute, result shape, and source shape.
xegpu::DistributeLayoutAttr
xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                  ArrayRef<int64_t> resShape,
                                  ArrayRef<int64_t> srcShape) {
  auto returnLayout = resLayout;
  int dimDiff = resShape.size() - srcShape.size();
  SmallVector<int64_t> bcastDims;
  for (int i = 0; i < dimDiff; i++)
    bcastDims.push_back(i);
  returnLayout = xegpu::SliceAttr::get(
      resLayout.getContext(), resLayout,
      DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
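  // Worked example: broadcasting vector<16xf32> to vector<8x16xf32> gives
  // dimDiff = 1 and bcastDims = [0], so the inferred source layout is the
  // result layout sliced over the broadcasted leading dim.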
/// Infers the source layout attribute for a reduction operation given the
/// result layout attribute and the reduction dimensions.
xegpu::DistributeLayoutAttr
xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                       SmallVector<int64_t> reduceDims) {
  assert(isa<xegpu::SliceAttr>(resLayout) &&
         "reduction result layout must be slice layout");
  xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
  assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
         "reduction dims must match with slice dims");
  return sliceLayout.getParent();
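  // Example: a reduction over dims = [1] whose result carries
  // #xegpu.slice<#xegpu.layout<...>, dims = [1]> simply reuses the parent
  // #xegpu.layout<...> for its source, since the slice already encodes the
  // reduced dim.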
/// Infers the source layout attribute for a bitcast operation given the
/// result layout attribute and the element bit widths of result and source.
xegpu::DistributeLayoutAttr
xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                int resElemTyBitWidth,
                                int srcElemTyBitWidth) {

  // sgData, instData, and laneData hold the layout's effective data sizes.
  size_t sgDataSize = sgData.size();
  size_t instDataSize = instData.size();
  size_t laneDataSize = laneData.size();
  int64_t sgDataValue, instDataValue, laneDataValue;

  // Only the innermost dim is rescaled by the bit width ratio.
  int64_t dim = resLayout.getRank() - 1;

  if (srcElemTyBitWidth <= resElemTyBitWidth) {
    // Source elements are narrower, so the source carries more elements.
    int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
    sgDataValue = sgData.back() * bitWidthRatio;
    instDataValue = instData.back() * bitWidthRatio;
    laneDataValue = laneData.back() * bitWidthRatio;
  } else {
    // Source elements are wider: divide, and require exact divisibility.
    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
    assert((sgData.back() % bitWidthRatio) == 0 &&
           "sgData not divisible by bitWidthRatio");
    sgDataValue = sgData.back() / bitWidthRatio;
    assert((instData.back() % bitWidthRatio) == 0 &&
           "instData not divisible by bitWidthRatio");
    instDataValue = instData.back() / bitWidthRatio;
    assert((laneData.back() % bitWidthRatio) == 0 &&
           "laneData not divisible by bitWidthRatio");
    laneDataValue = laneData.back() / bitWidthRatio;
  }

  xegpu::DistributeLayoutAttr finalSrcLayout;
  finalSrcLayout =
      resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);

  return finalSrcLayout;
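  // Worked example: for a result vector<16xi32> bitcast from a source
  // vector<64xi8>, srcElemTyBitWidth = 8 <= resElemTyBitWidth = 32, so
  // bitWidthRatio = 4 and every innermost data size quadruples for the source
  // (e.g., result sg_data 16 -> source sg_data 64); the wider-source branch
  // divides instead, guarded by the divisibility asserts.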
/// Infers the source layout attribute for an insert_strided_slice operation
/// given the result layout attribute, result shape, and source shape.
xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
    xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
    ArrayRef<int64_t> srcShape) {
  int srcShapeSize = srcShape.size();
  int resShapeSize = resShape.size();
  int dimDiff = resShapeSize - srcShapeSize;

  assert(isa<xegpu::LayoutAttr>(resLayout) &&
         "insertStridedSlice result layout must be plain layout");
  auto context = resLayout.getContext();
  auto resInstData = resLayout.getEffectiveInstDataAsInt();
  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
  auto resLaneData = resLayout.getEffectiveLaneDataAsInt();

  // The source reuses the trailing dims of the result layout.
  if (resInstData.size() != 0) {
    SmallVector<int64_t> inferredInstData(srcShapeSize, 1);
    for (int i = 0; i < srcShapeSize; i++)
      inferredInstData[i] = resInstData[i + dimDiff];
    return xegpu::LayoutAttr::get(context, inferredInstData);
  }

  if (resLaneLayout.size() != 0) {
    SmallVector<int64_t> inferredLaneLayout(srcShapeSize, 1);
    SmallVector<int64_t> inferredLaneData(srcShapeSize, 1);
    for (int i = 0; i < srcShapeSize; i++) {
      inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
      inferredLaneData[i] = resLaneData[i + dimDiff];
    }
    return xegpu::LayoutAttr::get(context, inferredLaneLayout,
                                  inferredLaneData);
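  // Worked example: inserting vector<8x16xf16> into vector<2x8x16xf16> whose
  // layout has inst_data = [1, 8, 16] gives dimDiff = 1, so the inferred
  // source layout is inst_data = [8, 16] (the trailing dims of the result).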
/// Infers the source layout attribute for a shape cast operation.
xegpu::DistributeLayoutAttr
xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                  ArrayRef<int64_t> resShape,
                                  ArrayRef<int64_t> srcShape) {

  // Unit-dim expansion: slice the result layout over the inserted unit dims.
  return xegpu::SliceAttr::get(resLayout.getContext(), resLayout,
                               sliceDimsAttr);

  // Split-dim expansion: collapse each split dim group back onto the source.
  auto srcLayout = resLayout;
  for (const auto &dimGroup : splitDimGroups)
    srcLayout = srcLayout.collapseDims(dimGroup);

  // Collapse-to-innermost matcher: dst must be [prod(src)] or [1, prod(src)].
  if ((dst.size() != 2) && (dst.size() != 1))
    return false;
  int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
                                    std::multiplies<int64_t>());
  if (dst.size() == 1)
    return (dst[0] == srcSize);
  return (dst[0] == 1) && (dst[1] == srcSize);
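  // Example: src = [4, 8, 16] has product 512, so dst = [512] and
  // dst = [1, 512] both match, while dst = [512, 1] and dst = [2, 256] do not.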
  if (matchCollapseToInnermostDim(srcShape, resShape)) {
    int srcShapeSize = srcShape.size();
    int resShapeSize = resShape.size();
    auto context = resLayout.getContext();
    auto resInstData = resLayout.getEffectiveInstDataAsInt();
    auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
    auto resLaneData = resLayout.getEffectiveLaneDataAsInt();

    if (resInstData.size() != 0) {
      for (int i = 0; i < resShapeSize - 1; i++) {
        assert(resInstData[i] == 1 &&
               "only innermost dim can have non-unit instData");
      }
      SmallVector<int64_t> inferredInstData(srcShapeSize, 1);
      inferredInstData[srcShapeSize - 1] =
          std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
      return xegpu::LayoutAttr::get(context, inferredInstData);
    }

    if (resLaneLayout.size() != 0) {
      for (int i = 0; i < resShapeSize - 1; i++) {
        assert(resLaneData[i] == 1 &&
               "only innermost dim can have non-unit laneData");
      }
      assert(srcShape.back() % resLaneLayout.back() == 0 &&
             "source innermost dim must be divisible by result lane layout");
      SmallVector<int64_t> inferredLaneLayout(srcShapeSize, 1);
      SmallVector<int64_t> inferredLaneData(srcShapeSize, 1);
      inferredLaneLayout.back() = resLaneLayout.back();
      inferredLaneData.back() = std::min(
          resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
      return xegpu::LayoutAttr::get(context, inferredLaneLayout,
                                    inferredLaneData);
    }
  }

  llvm_unreachable("unsupported shape cast scenario");
/// Sets up the layout for a multi-reduction by creating a SliceAttr over the
/// reduction dims for the result.
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
    xegpu::LayoutKind layoutKind, VectorType srcVecTy,
    xegpu::DistributeLayoutAttr consumerLayout,
    SmallVector<int64_t> reductionDims, const xegpu::uArch::uArch *uArch) {
  auto srcShape = srcVecTy.getShape();
  int srcRank = srcShape.size();
  auto context = consumerLayout.getContext();

  // Peel a slice consumer layout back to its plain parent layout.
  xegpu::SliceAttr consumerSliceLayout =
      dyn_cast<xegpu::SliceAttr>(consumerLayout);
  DistributeLayoutAttr plainLayout =
      consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
                          : consumerLayout;

  const int subgroupSize = uArch->getSubgroupSize();
  int64_t maxReduceVectorSize = 1;

  xegpu::DistributeLayoutAttr srcLayout;

  // Workgroup kind: distribute the consumer layout's subgroups.
  auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
  const int workgroupSize = std::accumulate(
      sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());

  auto consumerSgLayout = consumerLayout.getEffectiveSgLayoutAsInt();
  int remainingSgCount = workgroupSize;
  int consumerIdx = consumerSgLayout.size() - 1;
  SmallVector<int64_t> sgLayout(srcRank, 1);
  SmallVector<int64_t> sgData(srcRank, 1);

  // First pass: give the consumer's sg_layout to the non-reduced dims,
  // walking from the innermost dim outwards.
  for (int i = srcRank - 1; i >= 0; i--) {
    if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
      sgLayout[i] = consumerSgLayout[consumerIdx];
      assert((srcShape[i] % sgLayout[i] == 0) &&
             "source shape not divisible by consumer sg_layout");
      sgData[i] = srcShape[i] / sgLayout[i];
      remainingSgCount /= sgLayout[i];
      consumerIdx--;
    }
  }

  // Second pass: spread the remaining subgroups over the reduction dims.
  for (int i = srcRank - 1; i >= 0; i--) {
    if (llvm::is_contained(reductionDims, i)) {
      sgLayout[i] =
          std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
      assert((srcShape[i] % sgLayout[i] == 0) &&
             "source shape not divisible by sg_layout");
      sgData[i] = srcShape[i] / sgLayout[i];
      remainingSgCount /= sgLayout[i];
    }
  }

  assert(remainingSgCount == 1 && "not all subgroups distributed");
  srcLayout = xegpu::LayoutAttr::get(
      context, toInt32Attr(sgLayout), toInt32Attr(sgData),

  // Inst-data kind:
  SmallVector<int64_t> instData(srcRank, 1);
  instData[srcRank - 2] = std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
  instData[srcRank - 1] = subgroupSize;
  srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));

  // Lane kind:
  SmallVector<int64_t> laneLayout(srcRank, 1);
  SmallVector<int64_t> laneData(srcRank, 1);
  laneLayout[srcRank - 1] = subgroupSize;
  laneData[srcRank - 2] = std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
  srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
                                     toInt32Attr(laneData),
                                     consumerLayout.getOrder());

  return xegpu::SliceAttr::get(context, srcLayout,
                               DenseI64ArrayAttr::get(context, reductionDims));
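  // Worked example (workgroup kind): srcShape = [32, 64], reductionDims = [1],
  // and a consumer sg_layout of [8] for the 1-D result: dim 0 takes
  // sg_layout 8 (sg_data 4), the reduction dim absorbs the leftover subgroup
  // count of 1 (sg_data 64), and the result is sg_layout = [8, 1],
  // sg_data = [4, 64], wrapped in a SliceAttr over dim 1.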
/// Sets up the result layout for a bitcast based on the element type bit
/// widths of the source and result vectors.
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
    xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
    xegpu::DistributeLayoutAttr consumerLayout,
    const xegpu::uArch::uArch *uArch) {
  int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();

  auto srcShape = srcVecTy.getShape();
  size_t dim = srcShape.size() - 1;
  int64_t sgDataValue, instDataValue, laneDataValue;

  const int subgroupSize = uArch->getSubgroupSize();

  if (srcElemTyBitWidth > resElemTyBitWidth) {
    // Source elements are wider than result elements (e.g., i32 -> i8), so
    // the innermost data sizes must stay aligned to the lane layout.
    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
    int innermostDimLaneLayout = subgroupSize;

    assert(sgData.size() == srcShape.size() &&
           "sgData must be available for all dimensions");
    sgDataValue = sgData[dim];

    assert(instData.size() == srcShape.size() &&
           "instData must be available for all dimensions");
    instDataValue = instData[dim];

    // Grow instData until it is a multiple of laneLayout * bitWidthRatio
    // (growth step assumed to double, per the power-of-two assert below).
    while ((instDataValue <= srcShape[dim]) &&
           (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
      instDataValue *= 2;
    assert((srcShape[dim] % instDataValue) == 0 &&
           "srcShape, instData, and laneLayout for the innermost dim must be "
           "powers of 2!");

    assert(laneData.size() == srcShape.size() &&
           "laneData must be available for all dimensions");
    laneDataValue = laneData[dim];
    while ((laneDataValue <= srcShape[dim]) &&
           (laneDataValue % bitWidthRatio != 0))
      laneDataValue *= 2;
  }

  xegpu::DistributeLayoutAttr resLayout;
  resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
                                        laneDataValue);

  return consumerLayout;
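// Worked example: bitcasting vector<128xi32> to vector<512xi8> gives
// bitWidthRatio = 32 / 8 = 4. With subgroupSize = 16, the innermost
// inst_data must become a multiple of 16 * 4 = 64, so a consumer inst_data
// of 32 is doubled to 64, which still divides srcShape[dim] = 128.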
/// Sets up the result layout for an insert_strided_slice operation.
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
    xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
    VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
    const xegpu::uArch::uArch *uArch) {
  xegpu::DistributeLayoutAttr requiredResLayout;

  auto context = resVectorTy.getContext();
  auto resShape = resVectorTy.getShape();
  int resShapeSize = resShape.size();
  auto srcShape = srcVectorTy.getShape();

  auto consumerInstData = consumerLayout.getEffectiveInstDataAsInt();
  auto consumerLaneData = consumerLayout.getEffectiveLaneDataAsInt();

  const int subgroupSize = uArch->getSubgroupSize();
  const unsigned packingSize = uArch->getGeneralPackedFormatBitSize();
  unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  int packedDataSize = subgroupSize * packingFactor;

         "subgroup layout assignment not supported for insertStridedSlice.");

  // Inst-data kind:
  SmallVector<int64_t> instData(resShapeSize, 1);
  assert(srcShape.back() >= subgroupSize &&
         "source innermost dim must be >= subgroupSize");
  instData.back() = subgroupSize;
  if (consumerInstData.back() == packedDataSize &&
      srcShape.back() >= packedDataSize)
    instData.back() = packedDataSize;
  requiredResLayout = xegpu::LayoutAttr::get(context, instData);

  // Lane kind:
  SmallVector<int64_t> laneLayout(resShapeSize, 1);
  SmallVector<int64_t> laneData(resShapeSize, 1);
  laneLayout.back() = subgroupSize;

  if (consumerLaneData.back() == packingFactor &&
      srcShape.back() >= packedDataSize)
    laneData.back() = packingFactor;
  requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);

  return requiredResLayout;
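// Worked example (illustrative numbers): for f16 with packingSize = 32 and
// subgroupSize = 16, packingFactor = 2 and packedDataSize = 32; the innermost
// inst_data is widened from 16 to 32 only when the consumer already uses the
// packed size and the source innermost dim is large enough.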
/// Sets up the anchor layout for load_gather and load_matrix operations.
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
    xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
    xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
    int maxChunkSize, int valShapeSize, int subgroupSize) {

    return consumerLayout;

  auto consumerInstData = consumerLayout.getEffectiveInstDataAsInt();
  auto consumerLaneData = consumerLayout.getEffectiveLaneDataAsInt();

  // instData/laneLayout/laneData below start as all-ones vectors of size
  // valShapeSize (initialization elided).
  if (!isChunkedLoad) {
    // Inst-data kind: cap the innermost dim by what one instruction loads.
    instData[valShapeSize - 1] =
        std::min(static_cast<int>(consumerInstData[valShapeSize - 1]),
                 maxChunkSize * subgroupSize);
    return xegpu::LayoutAttr::get(context, instData);

    // Lane kind:
    laneLayout.back() = subgroupSize;
    laneData.back() =
        std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
    return xegpu::LayoutAttr::get(context, laneLayout, laneData);
  }

  assert(valShapeSize == 2 && "Chunked load must access a 2D tensor tile.");

  // Inst-data kind: lanes span dim 0; the chunk lives in dim 1.
  instData[0] = subgroupSize;
  instData[1] = std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, instData);

  // Lane kind:
  laneLayout[0] = subgroupSize;
  laneData[1] = std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
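// Worked example (non-chunked, inst-data kind): with subgroupSize = 16,
// maxChunkSize = 8, and consumer inst_data [1, 64] on a 2-D value, the
// innermost inst_data becomes min(64, 8 * 16) = 64; for chunked loads the
// roles flip, with instData[0] = subgroupSize and the chunk dim capped at
// maxChunkSize.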
/// Sets up the anchor layout for a load_gather operation.
xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
    xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
    xegpu::DistributeLayoutAttr consumerLayout,
    const xegpu::uArch::uArch *uArch) {
  const int subgroupSize = uArch->getSubgroupSize();
  int resShapeSize = resVecTy.getShape().size();
  auto context = resVecTy.getContext();
  auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::SpirvLoadGatherInstruction>(
  int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
  return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
                                      (chunkSize > 1), maxChunkSize,
                                      resShapeSize, subgroupSize);
/// Sets up the anchor layout for a load_matrix operation.
xegpu::DistributeLayoutAttr xegpu::setupLoadMatrixAnchorLayout(
    xegpu::LayoutKind layoutKind, VectorType resVecTy,
    xegpu::DistributeLayoutAttr consumerLayout,
    const xegpu::uArch::uArch *uArch) {
  const int subgroupSize = uArch->getSubgroupSize();
  int resShapeSize = resVecTy.getShape().size();
  auto context = resVecTy.getContext();
  auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
  const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadMatrixInstruction>(
  int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
  return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
                                      false, maxChunkSize, resShapeSize,
                                      subgroupSize);
/// Sets up the anchor layout for store_scatter and store_matrix operations.
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(
    xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
    bool isChunkedStore, int maxChunkSize, ArrayRef<int64_t> srcShape,
    int subgroupSize) {
  int srcShapeSize = srcShape.size();

         "subgroup layout assignment not supported for storeScatter.");

  // instData/laneLayout/laneData below start as all-ones vectors of size
  // srcShapeSize (initialization elided).
  if (!isChunkedStore) {
    // Inst-data kind:
    instData[srcShapeSize - 1] = subgroupSize;
    return xegpu::LayoutAttr::get(context, instData);

    // Lane kind:
    laneLayout[srcShapeSize - 1] = subgroupSize;
    return xegpu::LayoutAttr::get(context, laneLayout, laneData);
  }

  assert(srcShapeSize == 2 && "Chunked store must access a 2D tensor tile.");

  // Inst-data kind: lanes span dim 0; the chunk lives in dim 1.
  instData[0] = subgroupSize;
  instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, instData);

  // Lane kind:
  laneLayout[0] = subgroupSize;
  laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
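// Worked example (chunked, lane kind): storing vector<16x8xf32> with
// subgroupSize = 16 and maxChunkSize = 8 yields lane_layout = [16, 1] and
// lane_data = [1, 8], assuming the elided initializers default the remaining
// entries to 1: each lane owns one row's chunk of 8 elements.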
/// Sets up the anchor layout for a store_scatter operation.
xegpu::DistributeLayoutAttr
xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
                                     VectorType srcVecTy, int chunkSize,
                                     const xegpu::uArch::uArch *uArch) {
  const int subgroupSize = uArch->getSubgroupSize();
  auto srcShape = srcVecTy.getShape();
  auto context = srcVecTy.getContext();
  auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  const auto *uArchInstruction =
      dyn_cast<xegpu::uArch::SpirvStoreScatterInstruction>(
  int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
  return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
                                       maxChunkSize, srcShape, subgroupSize);
/// Sets up the anchor layout for a store_matrix operation.
xegpu::DistributeLayoutAttr
xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
                                    VectorType srcVecTy,
                                    const xegpu::uArch::uArch *uArch) {
  const int subgroupSize = uArch->getSubgroupSize();
  auto srcShape = srcVecTy.getShape();
  auto context = srcVecTy.getContext();
  auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  const auto *uArchInstruction = dyn_cast<xegpu::uArch::StoreMatrixInstruction>(
  int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
  return setupGenericStoreAnchorLayout(layoutKind, context, false,
                                       maxChunkSize, srcShape, subgroupSize);
template <typename RankedTy>
static xegpu::LayoutAttr
getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch,
                              std::optional<unsigned> packingSize = std::nullopt,
                              bool vnni = false) {
  assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
         "Expected 1D non-vnni or 2D vector.");
  assert(ty.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");

  auto context = ty.getContext();
  auto rank = ty.getRank();

  // laneLayout/laneData defaults (elided) give one element per lane.
  if (packingSize.has_value()) {
    unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
    int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
    laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
  }

  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
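// Worked example: vector<16x16xf16> with packingSize = 32 and vnni = true
// gives laneData[rank - 2] = 32 / 16 = 2, i.e. lane_data = [2, 1] (assuming
// unit defaults), matching the packed VNNI format expected for DPAS B
// operands.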
using LayoutRepresentation = std::pair<int64_t, int64_t>;

static SmallVector<LayoutRepresentation>
getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
                int64_t sgCount) {
  SmallVector<LayoutRepresentation> candidates;
  // Enumerate 2-D factorizations of sgCount whose subgroup tiles evenly
  // divide the workgroup shape and are multiples of the instruction data.
  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
    if (sgCount % sgLayout0)
      continue;
    int64_t sgLayout1 = sgCount / sgLayout0;
    int64_t sgData0 = wgShape[0] / sgLayout0;
    int64_t sgData1 = wgShape[1] / sgLayout1;
    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
        (sgData0 % instData[0] || sgData1 % instData[1]))
      continue;
    candidates.emplace_back(sgLayout0, sgLayout1);
  }
  // Prefer the most square factorization; break ties toward the smaller
  // leading dim (sort call header elided).
  int diffLhs = std::abs(lhs.first - lhs.second);
  int diffRhs = std::abs(rhs.first - rhs.second);
  if (diffLhs != diffRhs)
    return diffLhs < diffRhs;
  return lhs.first < rhs.first;
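  // Example: with sgCount = 8 and a shape admitting {1x8, 2x4, 4x2, 8x1}, the
  // comparator above orders them (2, 4), (4, 2), (1, 8), (8, 1): the most
  // square splits come first, and ties break toward the smaller leading dim.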
/// Sets up the anchor layouts for the dpas operands (A, B, and C/D).
std::optional<
    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
               xegpu::DistributeLayoutAttr>>
xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
                       VectorType bTy, VectorType cdTy,
                       xegpu::DistributeLayoutAttr consumerLayout,
                       const xegpu::uArch::uArch *uArch, int numSg) {
  auto context = aTy.getContext();
  const auto *uArchInstruction =

  // Computes per-operand inst_data honoring the uArch-supported M/N sizes.
  auto getInstDataVectors = [&]()

    const unsigned dataALen = aTy.getShape().front();
    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());

    const unsigned dataBLen = bTy.getShape().back();
    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());

    auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());

    // maxALen/maxBLen/maxCLen are the largest supported divisors (elided).
    if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
      return std::nullopt;

    instDataA[aTy.getRank() - 2] = maxALen;
    instDataA[aTy.getRank() - 1] = subgroupSize;
    instDataB[bTy.getRank() - 2] = subgroupSize;
    instDataB[bTy.getRank() - 1] = maxBLen;
    instDataCD[cdTy.getRank() - 2] = maxALen;
    instDataCD[cdTy.getRank() - 1] = maxCLen;
    return std::make_tuple(instDataA, instDataB, instDataCD);
971 "Number of subgroups must be provided for sg layout creation.");
972 auto instDataVecs = getInstDataVectors();
975 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
976 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
977 instDataCD.size() == 2 &&
978 "Sg layout creation expects valid 2D inst data");
980 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
981 if (consumerLayout && consumerLayout.isForWorkgroup()) {
983 consumerLayout.getEffectiveSgLayoutAsInt();
984 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
992 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1001 std::optional<LayoutRepresentation> bestPick;
1002 for (
auto &sgLayout : layoutsB) {
1003 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1005 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1006 bestPick = sgLayout;
1014 bestPick = sgLayout;
1020 return std::nullopt;
1022 static_cast<int>(bestPick->second)};
1024 static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1025 static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
1027 static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
1028 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1030 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1031 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1033 auto dpasALayout = xegpu::LayoutAttr::get(
1039 auto dpasBLayout = xegpu::LayoutAttr::get(
1045 auto dpasCDLayout = xegpu::LayoutAttr::get(
1050 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
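  // Worked example: aTy = vector<128x64xf16>, bTy = vector<64x128xf16>,
  // cdTy = vector<128x128xf32>, numSg = 8: if (4, 2) is valid for all three
  // operands (or matches the consumer's sg_layout), it yields
  // sgDataA = [32, 32], sgDataB = [16, 64], sgDataCD = [32, 64].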
  // Inst-data (subgroup) kind:
  auto instDataVecs = getInstDataVectors();
  if (!instDataVecs)
    return std::nullopt;
  auto [instDataA, instDataB, instDataCD] = *instDataVecs;
  return std::make_tuple(xegpu::LayoutAttr::get(context, instDataA),
                         xegpu::LayoutAttr::get(context, instDataB),
                         xegpu::LayoutAttr::get(context, instDataCD));

  // Lane kind: default 2D block-IO lane layouts; B is packed (VNNI).
  auto aLayout = getDefaultLaneLayout2DBlockIo(
      aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
  auto bLayout = getDefaultLaneLayout2DBlockIo(
      bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), /*vnni=*/true);
  auto cdLayout = getDefaultLaneLayout2DBlockIo(
      cdTy, uArch, uArchInstruction->getPackedFormatBitSizeB());
  return std::make_tuple(aLayout, bLayout, cdLayout);

  return std::nullopt;