#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
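/// The lattice value used by the layout propagation analysis. It wraps a
/// xegpu::DistributeLayoutAttr; a null attribute means "not assigned yet".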
struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values compare equal iff both are assigned or both are not;
  // the concrete layout is refined monotonically by meet.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }
  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  bool isAssigned() const { return storage != nullptr; }

  xegpu::DistributeLayoutAttr get() const { return storage; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
  bool isSliceLayout() const {
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const { return storage.getRank(); }
  void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
  SmallVector<int> getLaneLayout() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getInstData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getSgLayout() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getSgData() const {
    if (!isAssigned())
      return {};
    return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
                               [](int64_t val) { return static_cast<int>(val); });
  }
  SmallVector<int> getOrder() const {
    if (!isAssigned() || !storage.getOrder())
      return {};
    return llvm::map_to_vector(storage.getOrder().asArrayRef(),
                               [](int64_t val) { return static_cast<int>(val); });
  }
  void print(llvm::raw_ostream &os) const {
    if (isAssigned())
      os << storage;
    else
      os << "Not assigned.";
  }
};
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  // Check that the permutation is valid: no duplicates, all indices in range.
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });
  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return LayoutInfo();
  }

  SmallVector<int32_t> laneLayout, laneData, instData, sgLayout, sgData, order;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
    if (getSgData().size()) {
      sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
      sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
    }
    if (getOrder().size()) {
      order.push_back(static_cast<int32_t>(getOrder()[idx]));
    }
  }
  auto orderAttr = order.size()
                       ? DenseI32ArrayAttr::get(storage.getContext(), order)
                       : DenseI32ArrayAttr();
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  if (getSgData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), sgLayout, sgData,
                                        orderAttr);
  return LayoutInfo(layoutAttr);
}
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
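/// Helper to get the default SIMT layout for a 1D or 2D vector: distribute
/// the innermost dimension across the subgroup lanes, one element per lane.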
static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           const xegpu::uArch::uArch *uArch) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
  return LayoutInfo(
      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(MLIRContext *ctx, unsigned rank,
                                           int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
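/// Helper to get the SIMT layout for 2D block IO (block load/store/prefetch)
/// operations. Elements narrower than `packingSize` bits are packed so that
/// each lane owns a packingSize-bit chunk.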
template <typename Ty>
static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize) {
  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  assert(ty.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // 1D vectors use the default 1D layout.
  if (ty.getRank() == 1)
    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  return LayoutInfo(xegpu::LayoutAttr::get(
      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
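/// Backward sparse dataflow analysis that propagates DistributeLayoutAttr
/// requirements from layout anchors (DPAS, loads/stores, etc.) to the
/// producers of their operands.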
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  xegpu::LayoutKind layoutKind;

  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void
  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
                          ArrayRef<LayoutInfoLattice *> operands,
                          ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadMatrixOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreMatrixOp store,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        xegpu::LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {}

  void visitCallOperand(OpOperand &operand) override {}

  void visitNonControlFlowArguments(RegionSuccessor &successor,
                                    ArrayRef<BlockArgument> arguments) override {
  }

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  }

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
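/// Dispatch to the op-specific visitor; for unhandled ops, forward each
/// assigned result layout to all vector and tensor-descriptor operands.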
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case([&](xegpu::DpasOp dpasOp) {
        visitDpasOp(dpasOp, operands, results);
      })
      .Case([&](xegpu::StoreNdOp storeNdOp) {
        visitStoreNdOp(storeNdOp, operands, results);
      })
      .Case([&](xegpu::StoreScatterOp storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case([&](xegpu::LoadNdOp loadNdOp) {
        visitLoadNdOp(loadNdOp, operands, results);
      })
      .Case([&](xegpu::LoadGatherOp loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case([&](xegpu::CreateDescOp createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case([&](vector::TransposeOp transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case([&](vector::BitCastOp bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case([&](vector::MultiDimReductionOp reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case([&](vector::BroadcastOp broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case([&](vector::ShapeCastOp shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
      })
      .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
        visitLoadMatrixOp(loadMatrixOp, operands, results);
      })
      .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
        visitStoreMatrixOp(storeMatrixOp, operands, results);
      })
      // All other ops: propagate result layouts to vector/tdesc operands.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
bool LayoutInfoPropagation::hasParamsOfLayoutKind(
    xegpu::DistributeLayoutAttr anchorLayout) {
  if (anchorLayout == nullptr) {
    return false;
  }
  if (layoutKind == xegpu::LayoutKind::InstData) {
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
  } else if (layoutKind == xegpu::LayoutKind::Lane) {
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
  } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
             anchorLayout.getEffectiveSgDataAsInt().empty());
  }
  return false;
}
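/// Enumerate all 2D subgroup layouts (sgLayout0 x sgLayout1) that factor
/// `sgCount` and evenly tile `wgShape` in multiples of `instData`, sorted so
/// that the most square-like candidates come first.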
static SmallVector<std::pair<int, int>>
getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int> instData,
                int64_t sgCount) {
  SmallVector<std::pair<int, int>> candidates;
  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
    if (sgCount % sgLayout0)
      continue;
    int sgLayout1 = sgCount / sgLayout0;
    int sgData0 = wgShape[0] / sgLayout0;
    int sgData1 = wgShape[1] / sgLayout1;
    // Both the workgroup shape and the per-subgroup data must divide evenly.
    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
        (sgData0 % instData[0] || sgData1 % instData[1]))
      continue;
    candidates.emplace_back(sgLayout0, sgLayout1);
  }
  // Prefer layouts that are as close to square as possible.
  llvm::sort(candidates, [](const std::pair<int, int> &lhs,
                            const std::pair<int, int> &rhs) {
    int diffLhs = std::abs(lhs.first - lhs.second);
    int diffRhs = std::abs(rhs.first - rhs.second);
    if (diffLhs != diffRhs)
      return diffLhs < diffRhs;
    return lhs.first < rhs.first;
  });
  return candidates;
}
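/// Compute the number of subgroups in the launch: the flattened known block
/// size of the enclosing GPU function divided by the subgroup size.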
FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
  auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
  if (!gpuFunc)
    return failure();
  auto knownBlockSize = gpuFunc.getKnownBlockSize();
  if (!knownBlockSize.has_value())
    return failure();
  const int flatBlockSize = llvm::product_of(knownBlockSize.value());
  return flatBlockSize / sgSize;
}
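/// The visit* hooks below share one pattern: if the op already carries an
/// anchor layout of the requested kind, reuse it; otherwise derive one from
/// the uArch instruction parameters and record it on the op.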
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo prefetchLayout;
  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    prefetchLayout = LayoutInfo(anchorLayout);
  } else {
    auto tdescTy = prefetch.getTensorDescType();
    // Query the block-prefetch parameters for the target uArch.
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(prefetch).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
    auto blockWHC =
        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
    if (!blockWHC)
      prefetch.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (tdescTy.getRank() == 1) {
      instData = {instWidth};
    } else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        prefetch.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == xegpu::LayoutKind::InstData)
      prefetchLayout =
          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
    else
      prefetchLayout = getSIMTLayoutInfoBlockIO(
          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
    prefetch.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
  }
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  VectorType sourceTy = reduction.getSourceVectorType();
  SmallVector<int64_t> reductionDims(reduction.getReductionDims());
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(reduction).value_or(""));
  // The result carries a SliceAttr of the source layout.
  xegpu::SliceAttr requiredResLayoutAttr =
      xegpu::setupMultiReductionResultLayout(
          layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
  // Infer and propagate the source layout from the result layout.
  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferMultiReductionSourceLayout(requiredResLayoutAttr,
                                             reductionDims);
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
  // The accumulator gets the same layout as the result.
  propagateIfChanged(operands[1],
                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  // Only vector-to-vector broadcasts carry a source layout.
  if (!sourceTy)
    return;
  auto srcShape = sourceTy.getShape();
  auto resShape = resultTy.getShape();
  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
  for (size_t i = 0; i < srcShape.size(); i++)
    if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
      broadcast.emitWarning(
          "broadcast must be either from a lower rank or from the same rank "
          "with unit dims; the mixed scenario is not supported!");
  auto resultLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
  ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
  auto resultLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the result layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;
  // Honor anchor layouts if they are already attached to the op.
  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();
    VectorType cdTy = dpas.getResultType();
    xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
        requiredBLayout;
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(dpas).value_or(""));
    int numSg = 1;
    // At subgroup level, the consumer layout and the subgroup count are
    // needed to pick the DPAS layouts.
    if (layoutKind == xegpu::LayoutKind::Subgroup) {
      LayoutInfo consumerLayout = results[0]->getValue();
      if (!consumerLayout.isAssigned())
        return;
      consumerLayoutAttr =
          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
      auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
      if (failed(numSgOrErr)) {
        dpas.emitError(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      numSg = numSgOrErr.value();
    }
    auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
                                          consumerLayoutAttr, uArch, numSg);
    if (!layouts.has_value()) {
      dpas.emitError("Failed to determine required layouts for DPAS operands.");
      return;
    }
    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
    dpas.setLayoutAAttr(requiredALayout);
    dpas.setLayoutBAttr(requiredBLayout);
    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
    dpasALayout = LayoutInfo(requiredALayout);
    dpasBLayout = LayoutInfo(requiredBLayout);
    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
  }
  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2)
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(store).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC)
      store.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1) {
      instData = {instWidth};
    } else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }
    if (layoutKind == xegpu::LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else if (layoutKind == xegpu::LayoutKind::Lane)
      storeLayout =
          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
                                   uArchInstruction->getPackedFormatBitSize());
    else {
      // Subgroup level: pick a subgroup layout that evenly tiles the value.
      const int sgSize = uArch->getSubgroupSize();
      auto numSgOrErr = getNumSg(store, sgSize);
      if (failed(numSgOrErr)) {
        store.emitError(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      auto sgLayouts = getValidLayouts(dataTy.getShape(),
                                       instData, numSgOrErr.value());
      if (sgLayouts.empty()) {
        store.emitError(
            "Unable to determine suitable subgroup layout for store value.");
        return;
      }
      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
      SmallVector<int> sgData = {
          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
          dataTy.getContext(), sgLayout, sgData));
    }
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }
  // Both the value and the tensor descriptor/offsets get the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {
    // The layout of the result must be present.
    LayoutInfo valueLayout = results[0]->getValue();
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    // A transpose effect changes the layout; it is not expected here.
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operand.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the permuted layout to the transpose source.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  auto srcVecType = bitcast.getSourceVectorType();
  auto resVecType = bitcast.getResultVectorType();
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(bitcast).value_or(""));
  xegpu::DistributeLayoutAttr requiredResLayoutAttr =
      xegpu::setupBitCastResultLayout(
          layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
  // The source layout follows from the element-type bitwidth ratio.
  int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
  xegpu::DistributeLayoutAttr srcLayoutAttr = xegpu::inferBitCastSourceLayout(
      requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
void LayoutInfoPropagation::visitInsertStridedSliceOp(
    vector::InsertStridedSliceOp insertStridedSlice,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  auto srcVecType = insertStridedSlice.getSourceVectorType();
  auto resVecType = insertStridedSlice.getDestVectorType();
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  const auto *uArch = xegpu::uArch::getUArch(
      xegpu::getChipStr(insertStridedSlice).value_or(""));
  xegpu::DistributeLayoutAttr requiredResLayoutAttr =
      xegpu::setupInsertStridedSliceResultLayout(
          layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
  xegpu::setTemporaryLayout(insertStridedSlice->getOpResult(0),
                            requiredResLayoutAttr);
  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferInsertStridedSliceSourceLayout(
          requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
  propagateIfChanged(operands[1],
                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
  xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(load).value_or(""));
  const int subgroupSize = uArch->getSubgroupSize();
  VectorType resVecTy = load.getValueType();
  int chunkSize = load.getChunkSize().value_or(1);
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
    requiredAnchorLayoutAttr = anchorLayoutAttr;
  } else {
    if (!resVecTy) {
      load.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
        layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
    load.setLayoutAttr(requiredAnchorLayoutAttr);
  }
  // The mask is 1D; with chunked accesses the payload layout is 2D, so the
  // mask layout must be rebuilt.
  auto maskLayoutAttr = requiredAnchorLayoutAttr;
  if (chunkSize > 1) {
    if (layoutKind == xegpu::LayoutKind::InstData)
      maskLayoutAttr =
          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
    else if (layoutKind == xegpu::LayoutKind::Lane)
      maskLayoutAttr =
          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
    else
      llvm_unreachable(
          "chunked LoadGatherOp should not be used at workgroup level");
  }
  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
  auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
  // Propagate only to tensor-descriptor sources; the mask and offsets get
  // the mask layout.
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
  propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
}
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(createDesc).value_or(""));
  // For the offsets operand propagate the default 1D layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
  xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(storeScatter).value_or(""));
  const int subgroupSize = uArch->getSubgroupSize();
  VectorType srcVecTy = storeScatter.getValueType();
  int chunkSize = storeScatter.getChunkSize().value_or(1);
  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
    requiredAnchorLayoutAttr = anchorLayoutAttr;
  } else {
    if (!srcVecTy) {
      storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
        layoutKind, srcVecTy, chunkSize, uArch);
    storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
  }
  LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
  // The mask is 1D; with chunked accesses the payload layout is 2D, so the
  // mask layout must be rebuilt.
  auto maskLayoutAttr = requiredAnchorLayoutAttr;
  if (chunkSize > 1) {
    if (layoutKind == xegpu::LayoutKind::InstData)
      maskLayoutAttr =
          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
    else if (layoutKind == xegpu::LayoutKind::Lane)
      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
                                              {subgroupSize}, {1});
    else
      llvm_unreachable(
          "chunked StoreScatterOp should not be used at workgroup level");
  }
  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
  propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
  // Propagate to tensor-descriptor destinations; mask and offsets get the
  // mask layout.
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
}
void LayoutInfoPropagation::visitLoadMatrixOp(
    xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo resLayoutInfo = results[0]->getValue();
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
  const auto *uArch =
      xegpu::uArch::getUArch(xegpu::getChipStr(loadMatrixOp).value_or(""));
  // Set an anchor layout if the op does not carry one already.
  if (!hasParamsOfLayoutKind(anchorLayout)) {
    VectorType resVecTy =
        llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
    assert(resVecTy.getRank() == 2 && "Expecting 2D vector for load matrix.");
    xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr =
        xegpu::setupLoadMatrixAnchorLayout(layoutKind, resVecTy,
                                           consumerLayoutAttr, uArch);
    loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
  }
}
void LayoutInfoPropagation::visitStoreMatrixOp(
    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
  LayoutInfo layout;
  if (hasParamsOfLayoutKind(anchorLayout)) {
    layout = LayoutInfo(anchorLayout);
  } else {
    VectorType srcVecTy =
        llvm::cast<VectorType>(storeMatrix.getData().getType());
    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
    const auto *uArch =
        xegpu::uArch::getUArch(xegpu::getChipStr(storeMatrix).value_or(""));
    auto requiredAnchorLayoutAttr =
        xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
    layout = LayoutInfo(requiredAnchorLayoutAttr);
  }
  propagateIfChanged(operands[0], operands[0]->meet(layout));
}
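/// Driver that sets up a DataFlowSolver, runs the LayoutInfoPropagation
/// analysis on `target`, and exposes the per-value results.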
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
      : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};
LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the layout assigned to each function argument.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Print the layout assigned to each op result.
    funcOp.walk([&](Operation *op) {
      // Skip ops without results and control-flow ops.
      if (op->getResults().empty())
        return;
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    // Collect the functions nested inside GPU modules as well.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
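/// Post-propagation fixup: walks the IR and reconciles operands whose types
/// or attributes disagree with the layout expected by their consumers.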
struct ResolveLayoutConflicts {
  ResolveLayoutConflicts(Operation *parentOp)
      : parentOp(parentOp), builder(parentOp->getContext()) {}

  LogicalResult run();

private:
  Operation *parentOp;
  OpBuilder builder;

  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
  LogicalResult resolveVectorConsumer(OpOperand &operand);
};
LogicalResult ResolveLayoutConflicts::run() {
  auto r = parentOp->walk([&](Operation *op) -> WalkResult {
    for (OpOperand &operand : op->getOpOperands()) {
      Type operandType = operand.get().getType();
      if (isa<xegpu::AnchorLayoutInterface>(op) &&
          isa<xegpu::TensorDescType>(operandType)) {
        auto res = resolveTensorDescConsumer(operand);
        if (failed(res))
          return WalkResult::interrupt();
      }
      if (isa<VectorType>(operandType)) {
        auto res = resolveVectorConsumer(operand);
        if (failed(res))
          return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  return r.wasInterrupted() ? failure() : success();
}
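/// Helper to get the defining CreateNdDescOp of a tensor descriptor value,
/// looking through loop-carried block arguments.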
static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
  auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
  if (definingOp)
    return definingOp;
  // The descriptor may be a loop-carried block argument; chase the tied
  // loop init value.
  if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
    auto *parentOp = arg.getOwner()->getParentOp();
    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
      if (tiedInit)
        return getDefiningCreateNdDescOp(tiedInit->get());
    }
  }
  return {};
}
LogicalResult
ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
  // Conservatively accept vector-typed operands.
  return success();
}
LogicalResult
ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
  Operation *consumerOp = operand.getOwner();
  Value tdescValue = operand.get();
  auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
  auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
  assert(anchorOp && currTDescType &&
         "Expected anchor layout op and tensor descriptor consumer.");
  if (currTDescType.isScattered()) {
    DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
           << "\n";
    return failure();
  }
  Attribute currLayout = currTDescType.getLayout();
  Attribute expectedLayout = anchorOp.getAnchorLayout();
  // On a layout conflict, recreate the defining CreateNdDescOp with the
  // layout the anchor op expects.
  if (expectedLayout && currLayout && expectedLayout != currLayout) {
    auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
    if (!conflictingCreateNdOp) {
      DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
             << tdescValue << "\n";
      return failure();
    }
    builder.setInsertionPointAfter(conflictingCreateNdOp);
    auto newTensorDescType = xegpu::TensorDescType::get(
        conflictingCreateNdOp.getContext(), currTDescType.getShape(),
        currTDescType.getElementType(), currTDescType.getEncoding(),
        expectedLayout);
    xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
        builder, consumerOp->getLoc(), newTensorDescType,
        conflictingCreateNdOp->getOperands(),
        conflictingCreateNdOp->getAttrs());
    consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
  }
  return success();
}
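/// Update an operation with the layouts of its results: tensor descriptor
/// results get the layout folded into their type, vector results keep it as
/// an attribute on the defining op.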
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops are handled by updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed for vector and tensor descriptor types only.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // Tensor descriptor types carry the layout in the type; rebuild the type.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // Vector results keep the layout as an attribute.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
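/// Region ops like scf.for need special handling because they have blocks
/// inside: block arguments and forwarded operands of vector or tensor
/// descriptor type must end up with matching layouts.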
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
  if (!branchOp)
    return success();
  // Map each forwarded operand of the terminator to the successor inputs it
  // feeds.
  RegionBranchSuccessorMapping mapping;
  branchOp.getSuccessorOperandInputMapping(mapping,
                                           RegionBranchPoint::parent());
  for (const auto &[successorOperand, successorInputs] : mapping) {
    for (Value successorInput : successorInputs) {
      Type inputType = successorInput.getType();
      // Layouts are needed for vector and tensor descriptor types only.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand->get());
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand->get() << "\n");
        return failure();
      }
      // The region input layout must match the forwarded operand layout.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // Tensor descriptor types carry the layout in the type.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // For vector inputs that are op results, record the layout on the op.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
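/// Update the function arguments and results with the layouts computed by
/// the analysis.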
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function signature with the new argument types.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
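/// Run layout propagation on `target` and rewrite the IR with the resulting
/// layouts; with `printOnly` set, just dump the analysis result.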
LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
                                      xegpu::LayoutKind layoutKind,
                                      bool printOnly) {
  // Run the layout propagation analysis.
  RunLayoutInfoPropagation analysis(target, layoutKind);
  // Print the analysis result and exit if requested.
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return success();
  }
  // Helper to convert the analysis result to a DistributeLayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    if (auto opResult = dyn_cast<OpResult>(val)) {
      // Prefer the anchor layout if the defining op carries one.
      Operation *defOp = opResult.getDefiningOp();
      if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
        auto anchorLayout = anchorOp.getAnchorLayout();
        if (anchorLayout != nullptr)
          return anchorLayout;
      }
      // Otherwise fall back to a temporary layout recorded on the result.
      xegpu::DistributeLayoutAttr requiredResLayoutAttr =
          xegpu::getTemporaryLayout(opResult);
      if (requiredResLayoutAttr != nullptr)
        return requiredResLayoutAttr;
    }
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };
  // Walk the IR and attach the computed layouts.
  auto walkResult = target->walk([&](Operation *op) -> WalkResult {
    LogicalResult r = success();
    TypeSwitch<Operation *>(op)
        .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
          r = updateControlFlowOps(builder, branchTermOp,
                                   getXeGPULayoutForValue);
        })
        .Case([&](mlir::FunctionOpInterface funcOp) {
          r = updateFunctionOpInterface(builder, funcOp,
                                        getXeGPULayoutForValue);
        })
        .Default([&](Operation *op) {
          r = updateOp(builder, op, getXeGPULayoutForValue);
        });
    if (failed(r)) {
      op->emitError("Failed to update operation with the layout.");
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();
  return success();
}
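/// Entry point for the conflict-resolution step that runs after layout
/// propagation.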
LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
  ResolveLayoutConflicts resolver(target);
  return resolver.run();
}
void XeGPUPropagateLayoutPass::runOnOperation() {
  // Map the string pass option to the LayoutKind enum.
  xegpu::LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = xegpu::LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = xegpu::LayoutKind::InstData;
  } else if (this->layoutKind == "subgroup") {
    layoutKind = xegpu::LayoutKind::Subgroup;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  OpBuilder builder(&getContext());
  if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
                                     this->printOnly))) {
    signalPassFailure();
    return;
  }
  if (failed(xegpu::resolveLayoutConflicts(getOperation())))
    signalPassFailure();
}