32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// NOTE(review): fragment of the LayoutInfo helper struct — interior lines are
// elided from this chunk, so the comments below describe only what is visible.
// LayoutInfo wraps a nullable xegpu::DistributeLayoutAttr; "assigned" means
// the attribute is non-null.
83  xegpu::DistributeLayoutAttr storage =
nullptr;
86  LayoutInfo() =
default;
87  LayoutInfo(
const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
// Equality compares only the assignment state, not the layout contents —
// presumably all that the lattice framework needs here; TODO confirm.
91  bool operator==(
const LayoutInfo &other)
const {
92    return this->isAssigned() == other.isAssigned();
95  static LayoutInfo meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
97  static LayoutInfo
join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
101  bool isAssigned()
const {
return storage !=
nullptr; }
// True when the stored attribute is a SliceAttr (a sliced layout).
117  bool isSliceLayout()
const {
120    return isa<xegpu::SliceAttr>(storage);
126    return storage.getRank();
130  void set(
const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
// The getters below convert the attribute's effective int64_t parameter
// vectors (lane layout/data, inst data, sg layout/data, order) to int vectors.
136    return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137                               [](
int64_t val) { return static_cast<int>(val); });
143    return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144                               [](
int64_t val) { return static_cast<int>(val); });
150    return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151                               [](
int64_t val) { return static_cast<int>(val); });
157    return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158                               [](
int64_t val) { return static_cast<int>(val); });
164    return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165                               [](
int64_t val) { return static_cast<int>(val); });
// Order is optional: guarded by both assignment and getOrder() being present.
169    if (!isAssigned() || !storage.getOrder())
171    return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172                               [](
int64_t val) { return static_cast<int>(val); });
// Printing an unassigned LayoutInfo emits this placeholder text.
179    os <<
"Not assigned.";
// NOTE(review): meet of two LayoutInfo lattice values; body lines after the
// unassigned-lhs guard are elided from this chunk — presumably lhs is
// returned when assigned and rhs otherwise; TODO confirm against full file.
183LayoutInfo LayoutInfo::meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
184  if (!
lhs.isAssigned())
// join is intentionally unsupported: this analysis only propagates layouts
// backward via meet, so reaching join is a programming error.
190LayoutInfo LayoutInfo::join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
191  llvm_unreachable(
"Join should not be triggered by layout propagation.");
// NOTE(review): fragment of LayoutInfo's transpose — validates `permutation`
// (no duplicates, indices in [0, size)), then permutes every parameter vector
// that is present (lane layout/data, inst data, sg layout/data, order) and
// rebuilds a LayoutAttr. Several interior lines are elided from this chunk.
200  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201  bool hasDuplicates = seen.size() != permutation.size();
202  bool withinRange = llvm::all_of(permutation, [&](
int64_t idx) {
203    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
// Invalid permutations are a programmer error (assert), not a runtime path.
206  if (!withinRange || hasDuplicates) {
207    assert(
false &&
"Invalid permutation for transpose.");
218  for (
int64_t idx : permutation) {
219    if (getLaneLayout().size()) {
220      laneLayout.push_back(
static_cast<int32_t
>(getLaneLayout()[idx]));
221      laneData.push_back(
static_cast<int32_t
>(getLaneData()[idx]));
223    if (getInstData().size())
224      instData.push_back(
static_cast<int32_t
>(getInstData()[idx]));
225    if (getSgData().size()) {
226      sgLayout.push_back(
static_cast<int32_t
>(getSgLayout()[idx]));
227      sgData.push_back(
static_cast<int32_t
>(getSgData()[idx]));
229    if (getOrder().size()) {
230      order.push_back(
static_cast<int32_t
>(getOrder()[idx]));
233  auto orderAttr = order.size()
// Rebuild the layout attribute from whichever parameter set was populated;
// later assignments override earlier ones — TODO confirm precedence intent.
236  xegpu::LayoutAttr layoutAttr;
237  if (getLaneLayout().size())
239        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240  if (getInstData().size())
241    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242  if (getSgData().size())
243    layoutAttr = xegpu::LayoutAttr::get(
244        storage.getContext(),
249  return LayoutInfo(layoutAttr);
// Lattice wrapper so LayoutInfo can be used by the MLIR dataflow framework.
257struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
259  using Lattice::Lattice;
// NOTE(review): fragments of getDefaultSIMTLayoutInfo overloads — only 1-D and
// 2-D vectors are supported. 1-D: [subgroupSize] lanes, [1] data; 2-D:
// [1, subgroupSize] lanes, [1, 1] data. Interior lines are elided here.
272  assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
282                                          unsigned rank,
int subgroupSize) {
283  assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
285    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
287  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
// NOTE(review): SIMT layout for 2D-block IO ops. For bitwidths narrower than
// `packingSize`, multiple elements are packed per lane (packingFactor);
// 1-D types fall back to the default SIMT layout. A `uArch` parameter is
// referenced but its declaration line is elided from this chunk.
291template <
typename Ty>
292static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
294                                           unsigned packingSize) {
296  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297         "Expected 1D or 2D vector.");
299  assert(ty.getElementType().isIntOrFloat() &&
300         "Expected int or float element type.");
302  if (ty.getRank() == 1)
303    return getDefaultSIMTLayoutInfo(ty.getContext(), 1,
uArch);
305  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
// packingFactor = packingSize / bitwidth when the element is narrower than
// the packed format, else 1 (no packing).
306  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307  return LayoutInfo(xegpu::LayoutAttr::get(
308      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
// NOTE(review): declaration fragment of the backward sparse dataflow analysis
// that propagates XeGPU layouts. Per-op visitors are declared below; many
// declaration lines are elided from this chunk. Note the overloads taking
// xegpu::LoadMatrixOp / StoreMatrixOp under the *Gather/Scatter* visitor
// names (orig lines 386/390) — looks intentional for matrix ops lowered via
// the scatter path, but TODO confirm against the definitions.
320class LayoutInfoPropagation
327  void visitStoreNdOp(xegpu::StoreNdOp store,
331  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
335  void visitLoadNdOp(xegpu::LoadNdOp
load,
339  void visitLoadGatherOp(xegpu::LoadGatherOp
load,
343  void visitTransposeOp(vector::TransposeOp transpose,
347  void visitVectorBitcastOp(vector::BitCastOp bitcast,
351  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
355  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
359  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
363  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
367  void visitVectorBroadCastOp(vector::BroadcastOp
broadcast,
370  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
374  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
378  void visitLoadMatrixOp(xegpu::LoadMatrixOp
load,
382  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
386  void visitLoadGatherOp(xegpu::LoadMatrixOp
load,
390  void visitStoreScatterOp(xegpu::StoreMatrixOp store,
// True when `anchorLayout` already carries parameters of the kind this
// analysis instance propagates (see definition below).
394  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
401        layoutKind(layoutKind) {}
// Branch/call operands need no special handling in this analysis.
408  void visitBranchOperand(
OpOperand &operand)
override {};
410  void visitCallOperand(
OpOperand &operand)
override {};
416  void visitExternalCall(CallOpInterface call,
// Exit state: meet with an unassigned LayoutInfo (no layout demanded).
421  void setToExitState(LayoutInfoLattice *lattice)
override {
422    (
void)lattice->meet(LayoutInfo());
// Dispatch each visited op to its dedicated visitor via TypeSwitch. The
// default case forwards any assigned result layout to all tensor-desc /
// vector operands (generic element-wise behavior). Some interior lines
// (e.g. the TypeSwitch header and loop bodies' tails) are elided here.
427LogicalResult LayoutInfoPropagation::visitOperation(
428    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
429    ArrayRef<const LayoutInfoLattice *> results) {
432          [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
433      .Case([&](xegpu::StoreNdOp storeNdOp) {
434        visitStoreNdOp(storeNdOp, operands, results);
436      .Case([&](xegpu::StoreScatterOp storeScatterOp) {
437        visitStoreScatterOp(storeScatterOp, operands, results);
439      .Case([&](xegpu::LoadNdOp loadNdOp) {
440        visitLoadNdOp(loadNdOp, operands, results);
442      .Case([&](xegpu::LoadGatherOp loadGatherOp) {
443        visitLoadGatherOp(loadGatherOp, operands, results);
445      .Case([&](xegpu::CreateDescOp createDescOp) {
446        visitCreateDescOp(createDescOp, operands, results);
448      .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
449        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
451      .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
452        visitPrefetchNdOp(prefetchNdOp, operands, results);
454      .Case([&](vector::TransposeOp transposeOp) {
455        visitTransposeOp(transposeOp, operands, results);
457      .Case([&](vector::BitCastOp bitcastOp) {
458        visitVectorBitcastOp(bitcastOp, operands, results);
460      .Case([&](vector::MultiDimReductionOp reductionOp) {
461        visitVectorMultiReductionOp(reductionOp, operands, results);
463      .Case([&](vector::BroadcastOp broadcastOp) {
464        visitVectorBroadCastOp(broadcastOp, operands, results);
466      .Case([&](vector::ShapeCastOp shapeCastOp) {
467        visitShapeCastOp(shapeCastOp, operands, results);
469      .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
470        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
472      .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
473        visitLoadMatrixOp(loadMatrixOp, operands, results);
475      .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
476        visitStoreMatrixOp(storeMatrixOp, operands, results);
// Default: propagate each assigned result layout back to every operand of
// tensor-descriptor or vector type.
479      .Default([&](Operation *op) {
480        for (
const LayoutInfoLattice *resultInfo : results) {
481          if (!resultInfo->getValue().isAssigned())
483          for (
auto [operandInfo, operand] :
487            if (!isa<xegpu::TensorDescType, VectorType>(
488                    operand.get().getType()))
491            meet(operandInfo, *resultInfo);
// Returns true when `anchorLayout` (if non-null) carries non-empty parameters
// for the layout kind this analysis propagates: inst-data for InstData,
// lane layout+data for Lane, sg layout+data for Subgroup.
499bool LayoutInfoPropagation::hasParamsOfLayoutKind(
500    xegpu::DistributeLayoutAttr anchorLayout) {
501  if (anchorLayout ==
nullptr) {
504  if (layoutKind == xegpu::LayoutKind::InstData) {
505    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
507  if (layoutKind == xegpu::LayoutKind::Lane) {
508    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
509             anchorLayout.getEffectiveLaneDataAsInt().empty());
511  if (layoutKind == xegpu::LayoutKind::Subgroup) {
512    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
513             anchorLayout.getEffectiveSgDataAsInt().empty());
// NOTE(review): fragment of a helper that enumerates 2-D subgroup layout
// factorizations (sgLayout0 x sgLayout1 == sgCount) that evenly divide the
// workgroup shape and whose per-subgroup data is a multiple of instData.
// Candidates are then sorted to prefer the most "square" factorization,
// breaking ties by smaller first dimension.
529  for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
530    if (sgCount % sgLayout0)
532    int sgLayout1 = sgCount / sgLayout0;
533    int sgData0 = wgShape[0] / sgLayout0;
534    int sgData1 = wgShape[1] / sgLayout1;
535    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
536        (sgData0 % instData[0] || sgData1 % instData[1]))
538    candidates.emplace_back(sgLayout0, sgLayout1);
543  llvm::sort(candidates, [](
const std::pair<int, int> &
lhs,
544                            const std::pair<int, int> &
rhs) {
545    int diffLhs = std::abs(
lhs.first -
lhs.second);
546    int diffRhs = std::abs(
rhs.first -
rhs.second);
547    if (diffLhs != diffRhs)
548      return diffLhs < diffRhs;
549    return lhs.first <
rhs.first;
// NOTE(review): fragment — derives the number of subgroups from the GPU
// function's known block size (product of dimensions) divided by the
// subgroup size; bails out when the block size is unknown.
559  auto knownBlockSize = gpuFunc.getKnownBlockSize();
560  if (!knownBlockSize.has_value())
562  const int flatBlockSize = llvm::product_of(knownBlockSize.value());
563  return flatBlockSize / sgSize;
// Propagate a layout for PrefetchNdOp's tensor-descriptor operand. An anchor
// layout of the propagated kind wins; otherwise the layout is derived from
// the uArch 2D-block-prefetch instruction parameters (inst-data kind) or the
// SIMT block-IO layout (lane kind). Interior lines are elided in this chunk.
566void LayoutInfoPropagation::visitPrefetchNdOp(
567    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
568    ArrayRef<const LayoutInfoLattice *> results) {
570  LayoutInfo prefetchLayout;
571  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
572  if (hasParamsOfLayoutKind(anchorLayout)) {
573    prefetchLayout = LayoutInfo(anchorLayout);
577    auto tdescTy = prefetch.getTensorDescType();
582    const auto *uArchInstruction =
583        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
585            xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
588        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
590      prefetch.emitWarning(
"No known block params found for the element type.");
591    auto [bWidth, bHeight, bCount] = blockWHC.value();
592    SmallVector<int> instData;
// Instruction width/height are fitted to the innermost (and, for rank 2,
// second-innermost) tensor-descriptor dimensions.
594        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
596      prefetch.emitWarning(
597          "No suitable instruction multiple found for the given shape.");
598    if (tdescTy.getRank() == 1)
599      instData = {instWidth};
602          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
603      if (instHeight == -1)
604        prefetch.emitWarning(
605            "No suitable instruction multiple found for the given shape.");
606      instData = {instHeight, instWidth};
609    if (layoutKind == xegpu::LayoutKind::InstData)
611          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
613      prefetchLayout = getSIMTLayoutInfoBlockIO(
614          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
// Record the derived layout on the op, then propagate it to the operand.
616    prefetch.setLayoutAttr(
617        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
620  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
// Backward-propagate layouts through vector.multi_reduction: compute the
// required source/result layouts from the consumer layout and the reduction
// dims, then meet them into the source (operand 0) and accumulator
// (operand 1) lattices. Interior lines are elided in this chunk.
623void LayoutInfoPropagation::visitVectorMultiReductionOp(
624    vector::MultiDimReductionOp reduction,
625    ArrayRef<LayoutInfoLattice *> operands,
626    ArrayRef<const LayoutInfoLattice *> results) {
628  LayoutInfo resLayoutInfo = results[0]->getValue();
629  if (!resLayoutInfo.isAssigned())
632  VectorType sourceTy = reduction.getSourceVectorType();
633  SmallVector<int64_t> reductionDims(reduction.getReductionDims());
638  auto consumerLayoutAttr =
639      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
647      layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
653      requiredResLayoutAttr, reductionDims);
655  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
657  propagateIfChanged(operands[1],
658                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Backward-propagate through vector.broadcast. Mixed low-rank + unit-dim
// broadcasting is rejected with a warning; otherwise a source layout is
// derived from the result layout and met into operand 0. Some interior
// lines are elided in this chunk.
661void LayoutInfoPropagation::visitVectorBroadCastOp(
662    vector::BroadcastOp
broadcast, ArrayRef<LayoutInfoLattice *> operands,
663    ArrayRef<const LayoutInfoLattice *> results) {
665  LayoutInfo resLayoutInfo = results[0]->getValue();
666  if (!resLayoutInfo.isAssigned())
670  VectorType resultTy =
broadcast.getResultVectorType();
671  VectorType sourceTy = dyn_cast<VectorType>(
broadcast.getSourceType());
676  auto srcShape = sourceTy.getShape();
677  auto resShape = resultTy.getShape();
// dimDiff aligns source dims with the trailing result dims.
679  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
680  for (
size_t i = 0; i < srcShape.size(); i++)
681    if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
682      broadcast.emitWarning(
"broadcast must either from low-rank or same-rank "
683                            "with unit-dim, mixed scenario is not supported!");
685  auto resultLayoutAttr =
686      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
688  xegpu::DistributeLayoutAttr srcLayoutAttr =
691  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Backward-propagate through vector.shape_cast: derive the source layout
// from the result layout and the two shapes, then meet it into operand 0.
// Interior lines are elided in this chunk.
694void LayoutInfoPropagation::visitShapeCastOp(
695    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
696    ArrayRef<const LayoutInfoLattice *> results) {
698  LayoutInfo resLayoutInfo = results[0]->getValue();
699  if (!resLayoutInfo.isAssigned())
701  ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
702  ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
703  auto resultLayoutAttr =
704      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
706  xegpu::DistributeLayoutAttr srcLayoutAttr =
709  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// UpdateNdOffsetOp does not change the layout: forward the result's layout
// unchanged to the tensor-descriptor operand.
714void LayoutInfoPropagation::visitUpdateNdOffsetOp(
715    xegpu::UpdateNdOffsetOp updateNdOffset,
716    ArrayRef<LayoutInfoLattice *> operands,
717    ArrayRef<const LayoutInfoLattice *> results) {
719  LayoutInfo resultLayout = results[0]->getValue();
720  if (!resultLayout.isAssigned())
723  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
// Propagate layouts for DPAS (matrix-multiply) operands A, B and C/D. If the
// op already carries anchor layouts of the propagated kind they are reused
// (A and B are then asserted to be anchored too); otherwise the required
// layouts are computed (for Subgroup kind this also needs the consumer
// layout and subgroup count) and recorded back on the op. Interior lines
// are elided in this chunk.
727void LayoutInfoPropagation::visitDpasOp(
728    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
729    ArrayRef<const LayoutInfoLattice *> results) {
730  LayoutInfo dpasALayout;
731  LayoutInfo dpasBLayout;
732  LayoutInfo dpasCDLayout;
734  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
735  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
736    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
737    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
738    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
739           "Expected anchor layout for DPAS A operand.");
740    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
741           "Expected anchor layout for DPAS B operand.");
742    dpasALayout = LayoutInfo(anchorLayoutA);
743    dpasBLayout = LayoutInfo(anchorLayoutB);
744    dpasCDLayout = LayoutInfo(anchorLayoutCD);
749    VectorType aTy = dpas.getLhsType();
750    VectorType bTy = dpas.getRhsType();
751    VectorType cdTy = dpas.getResultType();
753    xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
754    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
// Subgroup-kind propagation needs the consumer layout and subgroup count.
758    if (layoutKind == xegpu::LayoutKind::Subgroup) {
759      LayoutInfo consumerLayout = results[0]->getValue();
760      if (!consumerLayout.isAssigned())
763          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
767            "Unable to determine the number of subgroups for the operation.");
770        numSg = numSgOrErr.value();
773        consumerLayoutAttr, uArch, numSg);
774    if (!layouts.has_value()) {
776          "Failed to determine required layouts for DPAS operands.");
780    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
// Record the computed layouts on the op as anchors for later passes.
782    dpas.setLayoutAAttr(requiredALayout);
783    dpas.setLayoutBAttr(requiredBLayout);
784    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
785    dpasALayout = LayoutInfo(requiredALayout);
786    dpasBLayout = LayoutInfo(requiredBLayout);
787    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
789  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
790  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
// The accumulator operand is optional.
791  if (operands.size() > 2)
792    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
// Propagate a layout for StoreNdOp. Anchor layout wins; otherwise the layout
// is derived from the uArch 2D-block-store instruction: inst-data kind uses
// the fitted instruction shape, lane kind uses the SIMT block-IO layout, and
// subgroup kind picks a workgroup factorization from the candidate list.
// The final layout is met into every operand (value and descriptor).
// Interior lines are elided in this chunk.
796void LayoutInfoPropagation::visitStoreNdOp(
797    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
798    ArrayRef<const LayoutInfoLattice *> results) {
799  LayoutInfo storeLayout;
800  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
801  if (hasParamsOfLayoutKind(anchorLayout)) {
802    storeLayout = LayoutInfo(anchorLayout);
807    const auto *uArchInstruction =
808        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
810            xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
811    VectorType dataTy = store.getValueType();
812    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
813        store.getValueType().getElementType());
815      store.emitWarning(
"No known block params found for the element type.");
816    auto [bWidth, bHeight, bCount] = blockWHC.value();
817    SmallVector<int> instData;
// Fit instruction width/height to the value's innermost dimensions.
819        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
822            "No suitable instruction multiple found for the given shape.");
823    if (dataTy.getRank() == 1)
824      instData = {instWidth};
827          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
828      if (instHeight == -1)
830            "No suitable instruction multiple found for the given shape.");
831      instData = {instHeight, instWidth};
834    if (layoutKind == xegpu::LayoutKind::InstData)
836          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
837    else if (layoutKind == xegpu::LayoutKind::Lane)
839          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
840                                   uArchInstruction->getPackedFormatBitSize());
// Subgroup kind: choose the best workgroup factorization candidate.
843      auto numSgOrErr =
getNumSg(store, sgSize);
846            "Unable to determine the number of subgroups for the operation.");
850          instData, numSgOrErr.value());
851      if (sgLayouts.empty()) {
853            "Unable to determine suitable subgroup layout for store value.");
856      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
857      SmallVector<int> sgData = {
858          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
859          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
860      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
868      dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
// Both the stored value and the tensor descriptor get the same layout.
872  for (LayoutInfoLattice *operand : operands)
873    propagateIfChanged(operand, operand->meet(storeLayout));
// Propagate the loaded value's layout back to LoadNdOp's tensor descriptor.
// A transpose attribute is unexpected at this stage (warned) and is folded
// into the propagated layout. Interior lines are elided in this chunk.
878void LayoutInfoPropagation::visitLoadNdOp(
879    xegpu::LoadNdOp
load, ArrayRef<LayoutInfoLattice *> operands,
880    ArrayRef<const LayoutInfoLattice *> results) {
881  LayoutInfo loadLayout;
882  xegpu::DistributeLayoutAttr anchorLayout =
load.getLayoutAttr();
883  if (hasParamsOfLayoutKind(anchorLayout)) {
884    loadLayout = LayoutInfo(anchorLayout);
887    LayoutInfo valueLayout = results[0]->getValue();
889    if (!valueLayout.isAssigned())
891    loadLayout = valueLayout;
895    if (
auto transpose =
load.getTranspose()) {
896      load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
897                       "LayoutInfoPropagation stage.");
898      loadLayout = valueLayout.transpose(transpose.value());
900    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
903  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// vector.transpose: permute the result layout by the op's permutation and
// propagate it to the source operand.
908void LayoutInfoPropagation::visitTransposeOp(
909    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
910    ArrayRef<const LayoutInfoLattice *> results) {
912  LayoutInfo resultLayout = results[0]->getValue();
913  if (!resultLayout.isAssigned())
915  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
917  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
// Backward-propagate through vector.bitcast: the source layout is derived
// from the result layout adjusted for the element bitwidth change.
// Interior lines are elided in this chunk.
922void LayoutInfoPropagation::visitVectorBitcastOp(
923    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
924    ArrayRef<const LayoutInfoLattice *> results) {
926  LayoutInfo resLayoutInfo = results[0]->getValue();
927  if (!resLayoutInfo.isAssigned())
930  auto srcVecType = bitcast.getSourceVectorType();
931  auto resVecType = bitcast.getResultVectorType();
933  auto consumerLayoutAttr =
934      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
939      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
943  int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
944  int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
948      requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
950  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Backward-propagate through vector.insert_strided_slice: derive required
// layouts for the inserted slice (operand 0) and the destination
// (operand 1) from the consumer layout and the two shapes.
// Interior lines are elided in this chunk.
953void LayoutInfoPropagation::visitInsertStridedSliceOp(
954    vector::InsertStridedSliceOp insertStridedSlice,
955    ArrayRef<LayoutInfoLattice *> operands,
956    ArrayRef<const LayoutInfoLattice *> results) {
958  LayoutInfo resLayoutInfo = results[0]->getValue();
959  if (!resLayoutInfo.isAssigned())
962  auto srcVecType = insertStridedSlice.getSourceVectorType();
963  auto resVecType = insertStridedSlice.getDestVectorType();
965  auto consumerLayoutAttr =
966      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
973      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
976                            requiredResLayoutAttr);
979      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
981  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
982  propagateIfChanged(operands[1],
983                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Propagate layouts for LoadGatherOp: the payload layout goes to the source
// only when it is a tensor descriptor; the mask (and optional offsets) get a
// 1-D subgroup-sized layout when chunkSize > 1, otherwise the payload
// layout. Interior lines are elided in this chunk.
// NOTE(review): the chunked-at-workgroup-level diagnostic below says
// "StoreScatterOp" inside the load visitor — possibly copy-pasted; confirm
// against the store visitor before relying on the message text.
988void LayoutInfoPropagation::visitLoadGatherOp(
989    xegpu::LoadGatherOp
load, ArrayRef<LayoutInfoLattice *> operands,
990    ArrayRef<const LayoutInfoLattice *> results) {
991  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
992  xegpu::DistributeLayoutAttr anchorLayoutAttr =
load.getLayoutAttr();
997  VectorType resVecTy =
load.getValueType();
998  int chunkSize =
load.getChunkSize().value_or(1);
1000  LayoutInfo resLayoutInfo = results[0]->getValue();
1001  if (!resLayoutInfo.isAssigned())
1003  auto consumerLayoutAttr =
1004      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1006  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1007    requiredAnchorLayoutAttr = anchorLayoutAttr;
1010      load.emitWarning(
"Not propagating, non-vector payload supplied.");
1014        layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1015  load.setLayoutAttr(requiredAnchorLayoutAttr);
1018  auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Chunked access: mask is 1-D over the subgroup, independent of the payload.
1021  if (chunkSize > 1) {
1022    if (layoutKind == xegpu::LayoutKind::InstData)
1024          xegpu::LayoutAttr::get(
load->getContext(), {subgroupSize});
1025    else if (layoutKind == xegpu::LayoutKind::Lane)
1027          xegpu::LayoutAttr::get(
load->getContext(), {subgroupSize}, {1});
1030          "chunked StoreScatterOp should not be used at workgroup level");
1033  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1034  auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1037  if (isa<xegpu::TensorDescType>(
load.getSourceType()))
1038    propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1040  propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1041  if (
load.getOffsets())
1042    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
// CreateDescOp: once the descriptor's layout is known, give the offsets
// operand (operand 1) a default 1-D SIMT layout. Interior lines are elided.
1047void LayoutInfoPropagation::visitCreateDescOp(
1048    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1049    ArrayRef<const LayoutInfoLattice *> results) {
1050  LayoutInfo descLayout = results[0]->getValue();
1052  if (!descLayout.isAssigned())
1058  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1060  propagateIfChanged(operands[1], operands[1]->meet(layout));
// Propagate layouts for StoreScatterOp: the value layout goes to operand 0
// (and to the destination when it is a tensor descriptor); the mask and
// optional offsets get a 1-D subgroup-sized layout when chunkSize > 1.
// Interior lines are elided in this chunk.
1065void LayoutInfoPropagation::visitStoreScatterOp(
1066    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1067    ArrayRef<const LayoutInfoLattice *> results) {
1069  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1070  xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1075  VectorType srcVecTy = storeScatter.getValueType();
1076  int chunkSize = storeScatter.getChunkSize().value_or(1);
1078  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1079    requiredAnchorLayoutAttr = anchorLayoutAttr;
1082      storeScatter.emitWarning(
"Not propagating, non-vector payload supplied.");
1086        layoutKind, srcVecTy, chunkSize, uArch);
1087  storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1090  LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1091  auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Chunked access: mask is 1-D over the subgroup, independent of the payload.
1094  if (chunkSize > 1) {
1095    if (layoutKind == xegpu::LayoutKind::InstData)
1097          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
1098    else if (layoutKind == xegpu::LayoutKind::Lane)
1099      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
1100                                              {subgroupSize}, {1});
1103          "chunked StoreScatterOp should not be used at workgroup level");
1106  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1109  propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1111  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1112    propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1114  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1115  if (storeScatter.getOffsets())
1116    propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
// LoadMatrixOp: when no anchor layout of the propagated kind is present,
// derive a required layout from the (asserted 2-D) result vector and the
// consumer layout, and record it on the op. Interior lines are elided.
// NOTE(review): the assert message says "store matrix" in the load visitor —
// looks copy-pasted; confirm before relying on the message text.
1119void LayoutInfoPropagation::visitLoadMatrixOp(
1120    xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1121    ArrayRef<const LayoutInfoLattice *> results) {
1123  LayoutInfo resLayoutInfo = results[0]->getValue();
1124  if (!resLayoutInfo.isAssigned())
1127  auto consumerLayoutAttr =
1128      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1130  xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1134  if (!hasParamsOfLayoutKind(anchorLayout)) {
1135    VectorType resVecTy =
1136        llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1137    assert(resVecTy.getRank() == 2 &&
"Expecting 2D vector for store matrix.");
1142        layoutKind, resVecTy, consumerLayoutAttr, uArch);
1143    loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
// StoreMatrixOp: reuse the anchor layout when present, otherwise derive a
// required layout from the (asserted 2-D) stored vector, record it on the
// op, and meet it into the data operand. Interior lines are elided.
1148void LayoutInfoPropagation::visitStoreMatrixOp(
1149    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1150    ArrayRef<const LayoutInfoLattice *> results) {
1151  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1153  if (hasParamsOfLayoutKind(anchorLayout)) {
1154    layout = LayoutInfo(anchorLayout);
1156    VectorType srcVecTy =
1157        llvm::cast<VectorType>(storeMatrix.getData().getType());
1158    assert(srcVecTy.getRank() == 2 &&
"Expecting 2D vector for store matrix.");
1162    auto requiredAnchorLayoutAttr =
1164    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1165    layout = LayoutInfo(requiredAnchorLayoutAttr);
1168  propagateIfChanged(operands[0], operands[0]->meet(layout));
// Driver that loads LayoutInfoPropagation into a DataFlowSolver and exposes
// per-value layout queries and a printer. Interior lines are elided here.
1177class RunLayoutInfoPropagation {
1183    SymbolTableCollection symbolTable;
1185    solver.
load<LayoutInfoPropagation>(symbolTable, layoutKind);
1189  LayoutInfo getLayoutInfo(Value val);
1191  void printAnalysisResult(llvm::raw_ostream &os);
1194  DataFlowSolver solver;
// Look up the lattice state the solver computed for `val`.
1199LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1200  auto *state = solver.
lookupState<LayoutInfoLattice>(val);
1203  return state->getValue();
// Debug printer: for each function (collected from a ModuleOp directly and
// from nested gpu.module ops), print the layout of every block argument and
// of every op result, skipping control-flow ops. Interior lines are elided.
1207void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1208  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1209    os <<
"function: " << funcOp.getName() <<
":\n";
1211    for (BlockArgument arg : funcOp.getArguments()) {
1212      LayoutInfo layout = getLayoutInfo(arg);
1213      os <<
"argument: " << arg <<
"\n";
1219    funcOp.walk([&](Operation *op) {
// Control-flow ops carry no interesting layouts of their own.
1225      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1231      for (
auto [i, r] : llvm::enumerate(op->
getResults())) {
1232        LayoutInfo layout = getLayoutInfo(r);
1233        os <<
"layout for result #" << i <<
": ";
1240  SmallVector<FunctionOpInterface> funcOps;
1241  if (
auto modOp = dyn_cast<ModuleOp>(
target)) {
1242    for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
1243      funcOps.push_back(funcOp);
// Also collect functions nested inside gpu.module ops.
1246    for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1247      for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1248        funcOps.push_back(gpuFuncOp);
1252  for (FunctionOpInterface funcOp : funcOps)
1253    printFunctionResult(funcOp);
// Find the CreateNdDescOp that (transitively) defines `tdescValue`: either a
// direct defining op, or — for a loop block argument — recurse through the
// tied loop init operand. Interior lines are elided in this chunk.
1265static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1267  auto definingOp = tdescValue.
getDefiningOp<xegpu::CreateNdDescOp>();
1272  if (
auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1273    auto *parentOp = arg.getOwner()->getParentOp();
1274    if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1275      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1277        return getDefiningCreateNdDescOp(tiedInit->
get());
// Resolves layout disagreements between producers and consumers under
// `parentOp`, via the two per-operand resolvers declared below. A `builder`
// member is initialized from the parent op's context (declaration elided).
1284struct ResolveLayoutConflicts {
1285  ResolveLayoutConflicts(Operation *parentOp)
1286      : parentOp(parentOp), builder(parentOp->
getContext()) {}
1287  LogicalResult run();
1290  Operation *parentOp;
1292  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1293  LogicalResult resolveVectorConsumer(OpOperand &operand);
// Walk all ops under parentOp and resolve each operand: tensor descriptors
// on anchor-layout ops go through resolveTensorDescConsumer, vectors through
// resolveVectorConsumer. The walk is interrupted (-> failure) on any
// resolver error. Interior lines are elided in this chunk.
1298LogicalResult ResolveLayoutConflicts::run() {
1301  auto r = parentOp->
walk([&](Operation *op) -> WalkResult {
1304      Type operandType = operand.get().getType();
1305      if (isa<xegpu::AnchorLayoutInterface>(op) &&
1306          isa<xegpu::TensorDescType>(operandType)) {
1307        auto res = resolveTensorDescConsumer(operand);
1309          DBGS() <<
"Failed to resolve tensor descriptor consumer: " << *op
1315      if (isa<VectorType>(operandType)) {
1316        auto res = resolveVectorConsumer(operand);
1318          DBGS() <<
"Failed to resolve vector consumer: " << *op <<
"\n";
1326  return r.wasInterrupted() ? failure() :
success();
// Resolve a vector operand whose producer and consumer disagree on layout by
// inserting an xegpu.convert_layout between them; no-op when the layouts
// already match. Missing producer layout on a >1-D vector only warns.
// Interior lines are elided in this chunk.
1330ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1331  Value vectorValue = operand.
get();
1332  Operation *consumerOp = operand.
getOwner();
1335  if (!producerLayout) {
1336    if (
auto vectorTy = dyn_cast<VectorType>(vectorValue.
getType());
1337        vectorTy && vectorTy.getRank() > 1)
1338      consumerOp->
emitWarning(
"Expected layout for non-1D vectors.");
1343  if (!consumerLayout)
1345                                   "No consumer layout found for vector operand.");
// Layouts agree: nothing to do.
1348  if (consumerLayout.isEqualTo(producerLayout))
// Otherwise bridge the mismatch with a convert_layout and re-point the use.
1353  auto convertOp = xegpu::ConvertLayoutOp::create(
1354      builder, consumerOp->
getLoc(), vectorValue.
getType(), vectorValue,
1355      producerLayout, consumerLayout);
1358  operand.
set(convertOp.getResult());
// Resolve a tensor-descriptor operand whose current layout conflicts with
// the consumer's anchor layout: locate the defining CreateNdDescOp and clone
// it with a descriptor type carrying the expected layout. Scattered
// descriptors are not supported. Interior lines are elided in this chunk.
1363ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1364  Operation *consumerOp = operand.
getOwner();
1365  Value tdescValue = operand.
get();
1366  auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1367  auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.
getType());
1368  assert(anchorOp && currTDescType &&
1369         "Expected anchor layout op and tensor descriptor consumer.");
1371  if (currTDescType.isScattered()) {
1372    DBGS() <<
"Scattered tensor descriptor not supported: " << tdescValue
1376  Attribute currLayout = currTDescType.getLayout();
1377  Attribute expectedLayout = anchorOp.getAnchorLayout();
// Only act on a genuine conflict (both layouts present and different).
1380  if (expectedLayout && currLayout && expectedLayout != currLayout) {
1382    auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1383    if (!conflictingCreateNdOp) {
1384      DBGS() <<
"Unable to find defining CreateNdDescOp for tensor descriptor: "
1385             << tdescValue <<
"\n";
// Clone the descriptor creation with the expected layout on its type.
1390    auto newTensorDescType = xegpu::TensorDescType::get(
1391        conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1392        currTDescType.getElementType(), currTDescType.getEncoding(),
1394    xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1395        builder, consumerOp->
getLoc(), newTensorDescType,
1396        conflictingCreateNdOp->getOperands(),
1397        conflictingCreateNdOp->getAttrs());
// NOTE(review): fragment of a result-update loop whose enclosing function
// header is elided from this chunk. For each non-control-flow op result of
// vector/tensor-desc type, it fetches the analyzed layout (warning when a
// used result has none) and, for tensor descriptors, rewrites the result
// type to embed the layout.
1415    if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1422      if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1425      xegpu::DistributeLayoutAttr layout = getLayoutOfValue(
result);
1426      if (!layout &&
result.getNumUses() > 0) {
1427        op->
emitWarning(
"op has users but no layout assigned for its result");
1432      if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1433        auto typeWithLayout = xegpu::TensorDescType::get(
1434            tensorDescTy.getContext(), tensorDescTy.getShape(),
1435            tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1436        result.setType(typeWithLayout);
1470 mlir::RegionBranchTerminatorOpInterface terminator,
1473 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1478 branchOp.getSuccessorOperandInputMapping(mapping,
1480 for (
const auto &[successorOperand, successorInputs] : mapping) {
1481 for (
Value successorInput : successorInputs) {
1482 Type inputType = successorInput.getType();
1484 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1486 xegpu::DistributeLayoutAttr successorInputLayout =
1487 getLayoutOfValue(successorInput);
1488 xegpu::DistributeLayoutAttr successorOperandLayout =
1489 getLayoutOfValue(successorOperand->get());
1492 if (!successorOperandLayout) {
1493 LLVM_DEBUG(
DBGS() <<
"No layout assigned for forwarded operand in "
1494 "branch terminator: "
1495 << successorOperand->get() <<
"\n");
1499 if (successorInputLayout &&
1500 successorInputLayout != successorOperandLayout) {
1501 LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
1502 "operand forwarded as the argument: "
1503 << successorInputLayout <<
" vs "
1504 << successorOperandLayout <<
"\n");
1508 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1509 auto newTdescTy = xegpu::TensorDescType::get(
1510 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1511 tdescTy.getEncoding(), successorOperandLayout);
1512 successorInput.setType(newTdescTy);
1517 if (
auto result = dyn_cast<OpResult>(successorInput))
1526 mlir::FunctionOpInterface funcOp,
1532 if (!isa<FunctionType>(funcOp.getFunctionType()))
1537 Type argType = arg.getType();
1538 newArgTypes.push_back(argType);
1539 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1541 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1543 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
1544 <<
" but got none.\n");
1547 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1548 auto newTdescTy = xegpu::TensorDescType::get(
1549 tensorDescTy.getContext(), tensorDescTy.getShape(),
1550 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1551 arg.setType(newTdescTy);
1552 newArgTypes.back() = newTdescTy;
1557 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1558 funcOp.getResultTypes()));
1563struct XeGPUPropagateLayoutPass final
1564 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1565 XeGPUPropagateLayoutPass() =
default;
1566 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
1567 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
1568 : XeGPUPropagateLayoutBase(std::move(
options)) {}
1569 void runOnOperation()
override;
1576 RunLayoutInfoPropagation analysis(
target, layoutKind);
1579 auto &os = llvm::outs();
1580 analysis.printAnalysisResult(os);
1584 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::DistributeLayoutAttr {
1585 LayoutInfo layout = analysis.getLayoutInfo(val);
1586 if (!layout.isAssigned())
1588 if (
auto opResult = dyn_cast<OpResult>(val)) {
1590 Operation *defOp = opResult.getDefiningOp();
1591 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1592 auto anchorLayout = anchorOp.getAnchorLayout();
1593 if (anchorLayout !=
nullptr)
1594 return anchorLayout;
1596 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1598 if (requiredResLayoutAttr !=
nullptr)
1599 return requiredResLayoutAttr;
1601 xegpu::DistributeLayoutAttr layoutAttr =
1602 cast<xegpu::DistributeLayoutAttr>(layout.get());
1603 if (layout.isSliceLayout())
1604 return cast<xegpu::SliceAttr>(layoutAttr);
1606 return cast<xegpu::LayoutAttr>(layoutAttr);
1614 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1616 getXeGPULayoutForValue);
1618 .Case([&](mlir::FunctionOpInterface funcOp) {
1620 getXeGPULayoutForValue);
1623 r =
updateOp(builder, op, getXeGPULayoutForValue);
1626 op.
emitError(
"Failed to update operation with the layout.");
1632 if (walkResult.wasInterrupted())
1639 ResolveLayoutConflicts resolver(
target);
1640 return resolver.run();
1643void XeGPUPropagateLayoutPass::runOnOperation() {
1645 if (this->layoutKind ==
"lane") {
1647 }
else if (this->layoutKind ==
"inst") {
1649 }
else if (this->layoutKind ==
"subgroup") {
1650 layoutKind = xegpu::LayoutKind::Subgroup;
1652 getOperation()->emitError(
"Unsupported layout kind option: " +
1654 signalPassFailure();
1659 this->printOnly))) {
1660 signalPassFailure();
1665 signalPassFailure();
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, bool printOnly=false)
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const