32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// LayoutInfo: lattice value used by the layout-propagation analysis.
// It wraps an xegpu::DistributeLayoutAttr; a null `storage` means the
// layout has not been assigned yet.
// NOTE(review): this listing is a fragmentary paste (embedded original
// line numbers, interior lines elided); comments annotate only what is
// visible here.
83  xegpu::DistributeLayoutAttr storage =
nullptr;
// Default-constructed state is "unassigned"; constructing from an
// attribute assigns that layout.
86  LayoutInfo() =
default;
87  LayoutInfo(
const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
// Equality compares only the assignment state, not the layout contents
// (sufficient for the meet-semilattice used by this analysis).
91  bool operator==(
const LayoutInfo &other)
const {
92    return this->isAssigned() == other.isAssigned();
// Lattice operations required by the dataflow framework.
95  static LayoutInfo meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
97  static LayoutInfo
join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
// True once a concrete layout attribute has been stored.
101  bool isAssigned()
const {
return storage !=
nullptr; }
// True when the stored layout is an xegpu::SliceAttr.
117  bool isSliceLayout()
const {
120    return isa<xegpu::SliceAttr>(storage);
126    return storage.getRank();
130  void set(
const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
// The accessors below project the effective lane/inst/subgroup
// parameters of the stored layout, narrowing each int64_t to int.
136    return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137                               [](
int64_t val) { return static_cast<int>(val); });
143    return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144                               [](
int64_t val) { return static_cast<int>(val); });
150    return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151                               [](
int64_t val) { return static_cast<int>(val); });
157    return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158                               [](
int64_t val) { return static_cast<int>(val); });
164    return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165                               [](
int64_t val) { return static_cast<int>(val); });
// Order is optional on the layout; guard before dereferencing it.
169    if (!isAssigned() || !storage.getOrder())
171    return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172                               [](
int64_t val) { return static_cast<int>(val); });
// Printing an unassigned layout (used by the analysis-result dump).
179  os <<
"Not assigned.";
// meet: lattice meet for the backward propagation. Only the guard on
// `lhs.isAssigned()` is visible here; the remainder of the body is
// elided in this listing.
183LayoutInfo LayoutInfo::meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
184  if (!
lhs.isAssigned())
// join: intentionally unsupported — this analysis only uses meet().
190LayoutInfo LayoutInfo::join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
191  llvm_unreachable(
"Join should not be triggered by layout propagation.");
// Transpose helper: validates `permutation` (no duplicates, indices in
// range), then permutes each layout parameter vector and rebuilds a
// LayoutAttr. NOTE(review): fragmentary listing — declarations of the
// local result vectors and some closing braces are elided.
200  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201  bool hasDuplicates = seen.size() != permutation.size();
202  bool withinRange = llvm::all_of(permutation, [&](
int64_t idx) {
203    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
// Invalid permutations are programmer errors, hence the assert.
206  if (!withinRange || hasDuplicates) {
207    assert(
false &&
"Invalid permutation for transpose.");
// Apply the permutation to whichever parameter kinds are present.
218  for (
int64_t idx : permutation) {
219    if (getLaneLayout().size()) {
220      laneLayout.push_back(
static_cast<int32_t
>(getLaneLayout()[idx]));
221      laneData.push_back(
static_cast<int32_t
>(getLaneData()[idx]));
223    if (getInstData().size())
224      instData.push_back(
static_cast<int32_t
>(getInstData()[idx]));
225    if (getSgData().size()) {
226      sgLayout.push_back(
static_cast<int32_t
>(getSgLayout()[idx]));
227      sgData.push_back(
static_cast<int32_t
>(getSgData()[idx]));
229    if (getOrder().size()) {
230      order.push_back(
static_cast<int32_t
>(getOrder()[idx]));
// Rebuild the attribute; the last matching parameter kind wins
// (lane, then inst, then subgroup) per the visible assignments.
233  auto orderAttr = order.size()
236  xegpu::LayoutAttr layoutAttr;
237  if (getLaneLayout().size())
239      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240  if (getInstData().size())
241    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242  if (getSgData().size())
243    layoutAttr = xegpu::LayoutAttr::get(
244      storage.getContext(),
249  return LayoutInfo(layoutAttr);
// LayoutInfoLattice: dataflow lattice element carrying a LayoutInfo.
257struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
259  using Lattice::Lattice;
// (Fragment of a default-layout helper: only rank validation visible.)
272  assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
// getSIMTLayoutInfoBlockIO: default SIMT layout for 2D-block IO ops.
// For sub-`packingSize` element types, several elements are packed per
// lane (packingFactor); 1D types fall back to the default layout.
282template <
typename Ty>
283static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
285                                           unsigned packingSize) {
287  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
288         "Expected 1D or 2D vector.");
290  assert(ty.getElementType().isIntOrFloat() &&
291         "Expected int or float element type.");
293  if (ty.getRank() == 1)
294    return getDefaultSIMTLayoutInfo(ty.getContext(), 1,
uArch);
296  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
297  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
// Lane layout {1, subgroupSize}; lane data {1, packingFactor}.
298  return LayoutInfo(xegpu::LayoutAttr::get(
299      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
// LayoutInfoPropagation: backward sparse dataflow analysis that walks
// consumers-to-producers assigning xegpu layouts. One visit* method per
// anchor/shape-changing op. NOTE(review): fragmentary class listing.
311class LayoutInfoPropagation
315  unsigned indexBitWidth;
319  void visitStoreNdOp(xegpu::StoreNdOp store,
323  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
327  void visitLoadNdOp(xegpu::LoadNdOp
load,
331  void visitLoadGatherOp(xegpu::LoadGatherOp
load,
335  void visitTransposeOp(vector::TransposeOp transpose,
339  void visitVectorBitcastOp(vector::BitCastOp bitcast,
343  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
347  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
351  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
355  void visitVectorReductionOp(vector::ReductionOp reduction,
359  void visitVectorBroadCastOp(vector::BroadcastOp
broadcast,
362  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
366  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
370  void visitLoadMatrixOp(xegpu::LoadMatrixOp
load,
374  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
// NOTE(review): these two overloads take LoadMatrixOp/StoreMatrixOp but
// are named like the gather/scatter visitors — confirm intent upstream.
378  void visitLoadGatherOp(xegpu::LoadMatrixOp
load,
382  void visitStoreScatterOp(xegpu::StoreMatrixOp store,
386  void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
390  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
397        layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
// Branch/call operands carry no layout information for this analysis.
404  void visitBranchOperand(
OpOperand &operand)
override {};
406  void visitCallOperand(
OpOperand &operand)
override {};
412  void visitExternalCall(CallOpInterface call,
// Exit state is the unassigned layout.
417  void setToExitState(LayoutInfoLattice *lattice)
override {
418    (
void)lattice->meet(LayoutInfo());
// visitOperation: TypeSwitch dispatch from an op to its dedicated
// visitor. The default case forwards each assigned result layout to all
// vector/tensor-desc operands (generic element-wise behavior).
423LogicalResult LayoutInfoPropagation::visitOperation(
424    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
425    ArrayRef<const LayoutInfoLattice *> results) {
428          [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
429      .Case([&](xegpu::StoreNdOp storeNdOp) {
430        visitStoreNdOp(storeNdOp, operands, results);
432      .Case([&](xegpu::StoreScatterOp storeScatterOp) {
433        visitStoreScatterOp(storeScatterOp, operands, results);
435      .Case([&](xegpu::LoadNdOp loadNdOp) {
436        visitLoadNdOp(loadNdOp, operands, results);
438      .Case([&](xegpu::LoadGatherOp loadGatherOp) {
439        visitLoadGatherOp(loadGatherOp, operands, results);
441      .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
442        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
444      .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
445        visitPrefetchNdOp(prefetchNdOp, operands, results);
447      .Case([&](vector::TransposeOp transposeOp) {
448        visitTransposeOp(transposeOp, operands, results);
450      .Case([&](vector::BitCastOp bitcastOp) {
451        visitVectorBitcastOp(bitcastOp, operands, results);
453      .Case([&](vector::MultiDimReductionOp reductionOp) {
454        visitVectorMultiReductionOp(reductionOp, operands, results);
456      .Case([&](vector::ReductionOp reductionOp) {
457        visitVectorReductionOp(reductionOp, operands, results);
459      .Case([&](vector::BroadcastOp broadcastOp) {
460        visitVectorBroadCastOp(broadcastOp, operands, results);
462      .Case([&](vector::ShapeCastOp shapeCastOp) {
463        visitShapeCastOp(shapeCastOp, operands, results);
465      .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
466        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
468      .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
469        visitLoadMatrixOp(loadMatrixOp, operands, results);
471      .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
472        visitStoreMatrixOp(storeMatrixOp, operands, results);
474      .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
475        visitConvertLayoutOp(convertLayoutOp, operands, results);
// Default: propagate every assigned result layout backward to the
// vector / tensor-descriptor operands only.
478      .Default([&](Operation *op) {
479        for (
const LayoutInfoLattice *resultInfo : results) {
480          if (!resultInfo->getValue().isAssigned())
482          for (
auto [operandInfo, operand] :
486            if (!isa<xegpu::TensorDescType, VectorType>(
487                    operand.get().getType()))
490            meet(operandInfo, *resultInfo);
// hasParamsOfLayoutKind: true when `anchorLayout` already carries the
// parameters relevant to the analysis' active layoutKind (InstData,
// Lane, or Subgroup). A null layout trivially has none.
498bool LayoutInfoPropagation::hasParamsOfLayoutKind(
499    xegpu::DistributeLayoutAttr anchorLayout) {
500  if (anchorLayout ==
nullptr) {
503  if (layoutKind == xegpu::LayoutKind::InstData) {
504    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
// Lane kind requires both lane-layout and lane-data to be present.
506  if (layoutKind == xegpu::LayoutKind::Lane) {
507    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
508             anchorLayout.getEffectiveLaneDataAsInt().empty());
// Subgroup kind requires both sg-layout and sg-data to be present.
510  if (layoutKind == xegpu::LayoutKind::Subgroup) {
511    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
512             anchorLayout.getEffectiveSgDataAsInt().empty());
// (Fragment) Enumerate factorizations sgLayout0 x sgLayout1 of the
// subgroup count that evenly tile wgShape and whose per-subgroup tile
// is a multiple of instData in both dims.
528  for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
529    if (sgCount % sgLayout0)
531    int sgLayout1 = sgCount / sgLayout0;
532    int sgData0 = wgShape[0] / sgLayout0;
533    int sgData1 = wgShape[1] / sgLayout1;
534    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
535        (sgData0 % instData[0] || sgData1 % instData[1]))
537    candidates.emplace_back(sgLayout0, sgLayout1);
// Prefer near-square factorizations; tie-break on the first factor so
// the ordering is deterministic.
542  llvm::sort(candidates, [](
const std::pair<int, int> &
lhs,
543                            const std::pair<int, int> &
rhs) {
544    int diffLhs = std::abs(
lhs.first -
lhs.second);
545    int diffRhs = std::abs(
rhs.first -
rhs.second);
546    if (diffLhs != diffRhs)
547      return diffLhs < diffRhs;
548    return lhs.first <
rhs.first;
// (Fragment) Number of subgroups = known flat block size / subgroup
// size; unavailable when the launch bounds are unknown.
558  auto knownBlockSize = gpuFunc.getKnownBlockSize();
559  if (!knownBlockSize.has_value())
561  const int flatBlockSize = llvm::product_of(knownBlockSize.value());
562  return flatBlockSize / sgSize;
// visitPrefetchNdOp: choose a layout for the prefetched tensor-desc.
// An explicit anchor layout wins; otherwise derive inst-data from the
// uArch 2D-block-prefetch instruction, or a SIMT block-IO layout for
// Lane kind. The chosen layout is written back onto the op and met into
// the tensor-desc operand lattice.
565void LayoutInfoPropagation::visitPrefetchNdOp(
566    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
567    ArrayRef<const LayoutInfoLattice *> results) {
569  LayoutInfo prefetchLayout;
570  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
571  if (hasParamsOfLayoutKind(anchorLayout)) {
572    prefetchLayout = LayoutInfo(anchorLayout);
576    auto tdescTy = prefetch.getTensorDescType();
581    const auto *uArchInstruction =
582        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
584                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
// Query legal block width/height/count for this element type.
587        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
589      prefetch.emitWarning(
"No known block params found for the element type.");
590    auto [bWidth, bHeight, bCount] = blockWHC.value();
591    SmallVector<int> instData;
// Fit the innermost dim to a legal block width.
593        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
595      prefetch.emitWarning(
596          "No suitable instruction multiple found for the given shape.");
597    if (tdescTy.getRank() == 1)
598      instData = {instWidth};
// 2D case: also fit the second-innermost dim to a legal block height.
601          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
602      if (instHeight == -1)
603        prefetch.emitWarning(
604            "No suitable instruction multiple found for the given shape.");
605      instData = {instHeight, instWidth};
608    if (layoutKind == xegpu::LayoutKind::InstData)
610          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
612      prefetchLayout = getSIMTLayoutInfoBlockIO(
613          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
// Record the decision on the op as its anchor layout.
615    prefetch.setLayoutAttr(
616        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
619  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
// visitVectorMultiReductionOp: from the consumer (result) layout,
// compute the required source layout given the reduction dims, then
// propagate to the source (operand 0) and accumulator (operand 1).
// For Subgroup kind, the subgroup count is taken into account.
622void LayoutInfoPropagation::visitVectorMultiReductionOp(
623    vector::MultiDimReductionOp reduction,
624    ArrayRef<LayoutInfoLattice *> operands,
625    ArrayRef<const LayoutInfoLattice *> results) {
626  Type resultTy = reduction.getDestType();
628  LayoutInfo resLayoutInfo = results[0]->getValue();
630  xegpu::DistributeLayoutAttr consumerLayoutAttr;
// Nothing to do until the consumer layout is known.
632  if (!resLayoutInfo.isAssigned())
635        dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
638  VectorType sourceTy = reduction.getSourceVectorType();
639  SmallVector<int64_t> reductionDims(reduction.getReductionDims());
645  if (layoutKind == xegpu::LayoutKind::Subgroup) {
647    if (succeeded(numSgOrErr))
648      numSg = numSgOrErr.value();
657      layoutKind, sourceTy, consumerLayoutAttr, reductionDims, numSg, uArch);
663                                        requiredResLayoutAttr, reductionDims);
665  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
667  propagateIfChanged(operands[1],
668                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// visitVectorReductionOp: scalar reduction of a 1-D vector. Propagates
// a derived source layout to the vector operand and, when present, the
// result layout to the accumulator operand.
671void LayoutInfoPropagation::visitVectorReductionOp(
672    vector::ReductionOp reduction, ArrayRef<LayoutInfoLattice *> operands,
673    ArrayRef<const LayoutInfoLattice *> results) {
675  VectorType sourceTy = reduction.getSourceVectorType();
680  auto requiredResLayoutAttr =
685  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// The accumulator is optional on vector.reduction.
686  if (reduction.getAcc())
687    propagateIfChanged(operands[1],
688                       operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// visitVectorBroadCastOp: derive the source layout from the broadcast
// result layout (comparing src/result shapes) and propagate it to the
// source operand. Requires the result layout to be assigned first.
691void LayoutInfoPropagation::visitVectorBroadCastOp(
692    vector::BroadcastOp
broadcast, ArrayRef<LayoutInfoLattice *> operands,
693    ArrayRef<const LayoutInfoLattice *> results) {
695  LayoutInfo resLayoutInfo = results[0]->getValue();
696  if (!resLayoutInfo.isAssigned())
700  VectorType resultTy =
broadcast.getResultVectorType();
// Scalar sources yield a null VectorType here (dyn_cast).
701  VectorType sourceTy = dyn_cast<VectorType>(
broadcast.getSourceType());
706  auto srcShape = sourceTy.getShape();
707  auto resShape = resultTy.getShape();
709  auto resultLayoutAttr =
710      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
712  xegpu::DistributeLayoutAttr srcLayoutAttr =
715  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// visitShapeCastOp: map the result layout back through the shape cast
// (using src/result shapes) and propagate to the source operand.
718void LayoutInfoPropagation::visitShapeCastOp(
719    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
720    ArrayRef<const LayoutInfoLattice *> results) {
722  LayoutInfo resLayoutInfo = results[0]->getValue();
723  if (!resLayoutInfo.isAssigned())
725  ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
726  ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
727  auto resultLayoutAttr =
728      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
730  xegpu::DistributeLayoutAttr srcLayoutAttr =
733  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// visitUpdateNdOffsetOp: the op only shifts offsets, so the result's
// layout is forwarded unchanged to the input tensor descriptor.
738void LayoutInfoPropagation::visitUpdateNdOffsetOp(
739    xegpu::UpdateNdOffsetOp updateNdOffset,
740    ArrayRef<LayoutInfoLattice *> operands,
741    ArrayRef<const LayoutInfoLattice *> results) {
743  LayoutInfo resultLayout = results[0]->getValue();
744  if (!resultLayout.isAssigned())
747  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
// visitDpasOp: assign layouts to the DPAS A/B/C-D operands. If the op
// already anchors a C/D layout, A and B anchors are required too and
// are reused verbatim. Otherwise the required layouts are computed
// (for Subgroup kind: from the consumer layout and subgroup count),
// written back as anchors, and met into the operand lattices.
751void LayoutInfoPropagation::visitDpasOp(
752    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
753    ArrayRef<const LayoutInfoLattice *> results) {
754  LayoutInfo dpasALayout;
755  LayoutInfo dpasBLayout;
756  LayoutInfo dpasCDLayout;
758  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
759  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
760    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
761    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
762    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
763           "Expected anchor layout for DPAS A operand.");
764    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
765           "Expected anchor layout for DPAS B operand.");
766    dpasALayout = LayoutInfo(anchorLayoutA);
767    dpasBLayout = LayoutInfo(anchorLayoutB);
768    dpasCDLayout = LayoutInfo(anchorLayoutCD);
773    VectorType aTy = dpas.getLhsType();
774    VectorType bTy = dpas.getRhsType();
775    VectorType cdTy = dpas.getResultType();
777    xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
778    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
// Subgroup kind needs the consumer layout and the subgroup count.
782    if (layoutKind == xegpu::LayoutKind::Subgroup) {
783      LayoutInfo consumerLayout = results[0]->getValue();
784      if (!consumerLayout.isAssigned())
787          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
791            "Unable to determine the number of subgroups for the operation.");
794      numSg = numSgOrErr.value();
797        consumerLayoutAttr, numSg, uArch);
798    if (!layouts.has_value()) {
800          "Failed to determine required layouts for DPAS operands.");
804    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
// Record the decision as anchors on the op itself.
806    dpas.setLayoutAAttr(requiredALayout);
807    dpas.setLayoutBAttr(requiredBLayout);
808    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
809    dpasALayout = LayoutInfo(requiredALayout);
810    dpasBLayout = LayoutInfo(requiredBLayout);
811    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
813  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
814  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
// The accumulator (C) operand is optional.
815  if (operands.size() > 2)
816    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
// visitStoreNdOp: pick the layout for an nd-store. Anchor layout wins;
// otherwise derive inst-data from the uArch 2D-block-store instruction,
// a SIMT block-IO layout for Lane kind, or (Subgroup kind) a subgroup
// factorization of the value shape. The result is anchored on the op
// and met into every operand lattice (value and tensor-desc).
820void LayoutInfoPropagation::visitStoreNdOp(
821    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
822    ArrayRef<const LayoutInfoLattice *> results) {
823  LayoutInfo storeLayout;
824  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
825  if (hasParamsOfLayoutKind(anchorLayout)) {
826    storeLayout = LayoutInfo(anchorLayout);
831    const auto *uArchInstruction =
832        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
834                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
835    VectorType dataTy = store.getValueType();
836    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
837        store.getValueType().getElementType());
839      store.emitWarning(
"No known block params found for the element type.");
840    auto [bWidth, bHeight, bCount] = blockWHC.value();
841    SmallVector<int> instData;
// Fit innermost dim to a legal block width.
843        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
846          "No suitable instruction multiple found for the given shape.");
847    if (dataTy.getRank() == 1)
848      instData = {instWidth};
851          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
852      if (instHeight == -1)
854            "No suitable instruction multiple found for the given shape.");
855      instData = {instHeight, instWidth};
858    if (layoutKind == xegpu::LayoutKind::InstData)
860          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
861    else if (layoutKind == xegpu::LayoutKind::Lane)
863          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
864                                   uArchInstruction->getPackedFormatBitSize());
// Subgroup kind: choose an sg factorization of the value shape.
867      auto numSgOrErr =
getNumSg(store, sgSize);
870            "Unable to determine the number of subgroups for the operation.");
874                                 instData, numSgOrErr.value());
875      if (sgLayouts.empty()) {
877            "Unable to determine suitable subgroup layout for store value.");
// First candidate is the preferred (most square) factorization.
880      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
881      SmallVector<int> sgData = {
882          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
883          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
884      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
892        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
// Both the stored value and the tensor descriptor share this layout.
896  for (LayoutInfoLattice *operand : operands)
897    propagateIfChanged(operand, operand->meet(storeLayout));
// visitLoadNdOp: forward the load result's layout to its tensor-desc
// operand (anchor layout wins when present). A transpose effect on the
// load is unexpected at this stage and triggers a warning, but is
// still honored by transposing the propagated layout.
902void LayoutInfoPropagation::visitLoadNdOp(
903    xegpu::LoadNdOp
load, ArrayRef<LayoutInfoLattice *> operands,
904    ArrayRef<const LayoutInfoLattice *> results) {
905  LayoutInfo loadLayout;
906  xegpu::DistributeLayoutAttr anchorLayout =
load.getLayoutAttr();
907  if (hasParamsOfLayoutKind(anchorLayout)) {
908    loadLayout = LayoutInfo(anchorLayout);
911    LayoutInfo valueLayout = results[0]->getValue();
913    if (!valueLayout.isAssigned())
915    loadLayout = valueLayout;
919    if (
auto transpose =
load.getTranspose()) {
920      load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
921                       "LayoutInfoPropagation stage.");
922      loadLayout = valueLayout.transpose(transpose.value());
// Anchor the decision on the op.
924    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
927  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// visitConvertLayoutOp: the op's input-layout attribute is the anchor
// for its operand; propagate it directly.
932void LayoutInfoPropagation::visitConvertLayoutOp(
933    xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
934    ArrayRef<const LayoutInfoLattice *> results) {
935  xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
936  LayoutInfo convertLayout(anchorLayout);
938  propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
// visitTransposeOp: invert the transpose on the result layout (via the
// op's permutation) to obtain the source layout, then propagate it.
943void LayoutInfoPropagation::visitTransposeOp(
944    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
945    ArrayRef<const LayoutInfoLattice *> results) {
947  LayoutInfo resultLayout = results[0]->getValue();
948  if (!resultLayout.isAssigned())
950  auto consumerLayoutAttr =
951      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
953      consumerLayoutAttr, transpose.getPermutation());
955  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// visitVectorBitcastOp: adjust the result layout for the element
// bit-width change of the bitcast (in/out bit widths) to obtain the
// source layout, then propagate to the source operand.
960void LayoutInfoPropagation::visitVectorBitcastOp(
961    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
962    ArrayRef<const LayoutInfoLattice *> results) {
964  LayoutInfo resLayoutInfo = results[0]->getValue();
965  if (!resLayoutInfo.isAssigned())
968  auto srcVecType = bitcast.getSourceVectorType();
969  auto resVecType = bitcast.getResultVectorType();
971  auto consumerLayoutAttr =
972      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
977      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
981  int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
982  int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
986      requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
988  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// visitInsertStridedSliceOp: from the destination (result) layout,
// compute the required layouts of the inserted slice (operand 0) and
// the destination vector (operand 1), then propagate both.
991void LayoutInfoPropagation::visitInsertStridedSliceOp(
992    vector::InsertStridedSliceOp insertStridedSlice,
993    ArrayRef<LayoutInfoLattice *> operands,
994    ArrayRef<const LayoutInfoLattice *> results) {
996  LayoutInfo resLayoutInfo = results[0]->getValue();
997  if (!resLayoutInfo.isAssigned())
1000  auto srcVecType = insertStridedSlice.getSourceVectorType();
1001  auto resVecType = insertStridedSlice.getDestVectorType();
1003  auto consumerLayoutAttr =
1004      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1005  const uArch *uArch =
1011      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1013                                 requiredResLayoutAttr);
1016      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
1017  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1018  propagateIfChanged(operands[1],
1019                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// visitLoadGatherOp: choose the payload layout for a gather load
// (anchor layout wins; otherwise derived from the result layout, chunk
// size, and uArch) and anchor it on the op. A mask layout is derived
// from the payload layout/chunk size; the tensor-desc source (if any)
// gets the payload layout, while offsets and mask get the mask layout.
1024void LayoutInfoPropagation::visitLoadGatherOp(
1025    xegpu::LoadGatherOp
load, ArrayRef<LayoutInfoLattice *> operands,
1026    ArrayRef<const LayoutInfoLattice *> results) {
1027  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1028  xegpu::DistributeLayoutAttr anchorLayoutAttr =
load.getLayoutAttr();
1032  VectorType resVecTy =
load.getValueType();
// chunkSize defaults to 1 when the op carries no chunk-size attribute.
1033  int chunkSize =
load.getChunkSize().value_or(1);
1035  LayoutInfo resLayoutInfo = results[0]->getValue();
1036  if (!resLayoutInfo.isAssigned())
1038  auto consumerLayoutAttr =
1039      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1041  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1042    requiredAnchorLayoutAttr = anchorLayoutAttr;
1045      load.emitWarning(
"Not propagating, non-vector payload supplied.");
1049        layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1050    load.setLayoutAttr(requiredAnchorLayoutAttr);
// Subgroup distribution with chunked access is not supported here.
1053  assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1055      requiredAnchorLayoutAttr, chunkSize);
1056  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1057  auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// Only tensor-desc sources receive the payload layout.
1060  if (isa<xegpu::TensorDescType>(
load.getSourceType()))
1061    propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1063  propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1064  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
// visitStoreScatterOp: mirror of visitLoadGatherOp for scatter stores.
// The payload layout (anchor or derived from value type/chunk size) is
// anchored on the op and propagated to the value (and tensor-desc dest
// if present); offsets and mask get the derived mask layout.
1069void LayoutInfoPropagation::visitStoreScatterOp(
1070    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1071    ArrayRef<const LayoutInfoLattice *> results) {
1073  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1074  xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1078  VectorType srcVecTy = storeScatter.getValueType();
1079  int chunkSize = storeScatter.getChunkSize().value_or(1);
1081  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1082    requiredAnchorLayoutAttr = anchorLayoutAttr;
1085      storeScatter.emitWarning(
"Not propagating, non-vector payload supplied.");
1089        layoutKind, srcVecTy, chunkSize, uArch);
1090    storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1093  LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// Subgroup distribution with chunked access is not supported here.
1094  assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1096      requiredAnchorLayoutAttr, chunkSize);
1097  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1100  propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
// Destination gets the payload layout only when it is a tensor-desc.
1102  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1103    propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1105  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1106  propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
// visitLoadMatrixOp: derive (or reuse) an anchor layout for the loaded
// matrix from the consumer layout and anchor it on the op.
1109void LayoutInfoPropagation::visitLoadMatrixOp(
1110    xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1111    ArrayRef<const LayoutInfoLattice *> results) {
1113  LayoutInfo resLayoutInfo = results[0]->getValue();
1114  if (!resLayoutInfo.isAssigned())
1117  auto consumerLayoutAttr =
1118      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1120  xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
// Only compute a layout when no usable anchor exists yet.
1124  if (!hasParamsOfLayoutKind(anchorLayout)) {
1125    VectorType resVecTy =
1126        llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1131        layoutKind, resVecTy, consumerLayoutAttr, uArch);
1132    loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
// visitStoreMatrixOp: anchor layout wins; otherwise derive one from
// the stored value type, anchor it, and propagate to the data operand.
1137void LayoutInfoPropagation::visitStoreMatrixOp(
1138    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1139    ArrayRef<const LayoutInfoLattice *> results) {
1140  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1142  if (hasParamsOfLayoutKind(anchorLayout)) {
1143    layout = LayoutInfo(anchorLayout);
1145    VectorType srcVecTy =
1146        llvm::cast<VectorType>(storeMatrix.getData().getType());
1150    auto requiredAnchorLayoutAttr =
1152    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1153    layout = LayoutInfo(requiredAnchorLayoutAttr);
1156  propagateIfChanged(operands[0], operands[0]->meet(layout));
// RunLayoutInfoPropagation: driver that loads the analysis into a
// DataFlowSolver, exposes per-value layout lookup, and can print the
// analysis result for debugging / test checking.
1165class RunLayoutInfoPropagation {
1170                           unsigned indexBitWidth)
1172    SymbolTableCollection symbolTable;
1174    solver.
load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
1178  LayoutInfo getLayoutInfo(Value val);
1180  void printAnalysisResult(llvm::raw_ostream &os);
1183  DataFlowSolver solver;
// Lookup the lattice state computed for `val` (unassigned if absent).
1188LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1189  auto *state = solver.
lookupState<LayoutInfoLattice>(val);
1192  return state->getValue();
// Dump, per function, the layout of each argument and each op result.
1196void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1197  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1198    os <<
"function: " << funcOp.getName() <<
":\n";
1200    for (BlockArgument arg : funcOp.getArguments()) {
1201      LayoutInfo layout = getLayoutInfo(arg);
1202      os <<
"argument: " << arg <<
"\n";
1208    funcOp.walk([&](Operation *op) {
// Skip control-flow ops; their results carry no interesting layout.
1214      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1220      for (
auto [i, r] : llvm::enumerate(op->
getResults())) {
1221        LayoutInfo layout = getLayoutInfo(r);
1222        os <<
"layout for result #" << i <<
": ";
// Collect functions from a plain module and from nested gpu.modules.
1229  SmallVector<FunctionOpInterface> funcOps;
1230  if (
auto modOp = dyn_cast<ModuleOp>(
target)) {
1231    for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
1232      funcOps.push_back(funcOp);
1235    for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1236      for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1237        funcOps.push_back(gpuFuncOp);
1241  for (FunctionOpInterface funcOp : funcOps)
1242    printFunctionResult(funcOp);
// getDefiningCreateNdDescOp: walk back to the CreateNdDescOp that
// produced `tdescValue`, looking through loop iter-args by recursing
// on the tied loop init value.
1254static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1256  auto definingOp = tdescValue.
getDefiningOp<xegpu::CreateNdDescOp>();
// Block arguments: follow loop-carried values to their init operand.
1261  if (
auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1262    auto *parentOp = arg.getOwner()->getParentOp();
1263    if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1264      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1266        return getDefiningCreateNdDescOp(tiedInit->
get());
// ResolveLayoutConflicts: post-propagation IR fixup that walks
// `parentOp` and reconciles producer/consumer layout mismatches by
// inserting xegpu.convert_layout ops or re-creating tensor descriptors.
1273struct ResolveLayoutConflicts {
1274  ResolveLayoutConflicts(Operation *parentOp)
1275      : parentOp(parentOp), builder(parentOp->
getContext()) {}
1276  LogicalResult run();
1279  Operation *parentOp;
1281  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1282  LogicalResult resolveVectorConsumer(OpOperand &operand);
1283  LogicalResult assignResultLayout(OpResult &
result);
// run: walk all ops; reductions with scalar results get an explicit
// result layout; anchor-op tensor-desc operands and all vector
// operands are resolved individually. Interrupt => failure.
1288LogicalResult ResolveLayoutConflicts::run() {
1291  auto r = parentOp->
walk([&](Operation *op) -> WalkResult {
1295    if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
1297      if (
result.getType().isIntOrFloat()) {
1298        auto res = assignResultLayout(
result);
1300          DBGS() <<
"Failed to resolve vector consumer for multi-reduction "
1309      Type operandType = operand.get().getType();
1310      if (isa<xegpu::AnchorLayoutInterface>(op) &&
1311          isa<xegpu::TensorDescType>(operandType)) {
1312        auto res = resolveTensorDescConsumer(operand);
1314          DBGS() <<
"Failed to resolve tensor descriptor consumer: " << *op
1320      if (isa<VectorType>(operandType)) {
1321        auto res = resolveVectorConsumer(operand);
1323          DBGS() <<
"Failed to resolve vector consumer: " << *op <<
"\n";
1331  return r.wasInterrupted() ? failure() :
success();
// assignResultLayout: insert a convert_layout after the producer and
// reroute all other uses of `result` through it.
1334LogicalResult ResolveLayoutConflicts::assignResultLayout(OpResult &
result) {
1335  Operation *producerOp =
result.getDefiningOp();
1339  auto convertOp = xegpu::ConvertLayoutOp::create(
1342  result.replaceAllUsesExcept(convertOp.getResult(), convertOp);
// resolveVectorConsumer: when the producer and consumer disagree on a
// vector value's layout, insert an xegpu.convert_layout between them
// and repoint the consumer's operand at the converted value.
1347ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1348  Value vectorValue = operand.
get();
1349  Operation *consumerOp = operand.
getOwner();
// Missing producer layout is tolerated for 1-D vectors; warn otherwise.
1352  if (!producerLayout) {
1353    if (
auto vectorTy = dyn_cast<VectorType>(vectorValue.
getType());
1354        vectorTy && vectorTy.getRank() > 1)
1355      consumerOp->
emitWarning(
"Expected layout for non-1D vectors.");
1360  if (!consumerLayout)
1362        "No consumer layout found for vector operand.");
// Nothing to do when both sides already agree.
1365  if (consumerLayout.isEqualTo(producerLayout))
1370  auto convertOp = xegpu::ConvertLayoutOp::create(
1371      builder, consumerOp->
getLoc(), vectorValue.
getType(), vectorValue,
1372      producerLayout, consumerLayout);
1375  operand.
set(convertOp.getResult());
// resolveTensorDescConsumer: when an anchor op expects a different
// layout than the tensor descriptor currently carries, clone the
// originating CreateNdDescOp with a retyped tensor-desc (same shape/
// element type/encoding, expected layout) for the consumer to use.
1380ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1381  Operation *consumerOp = operand.
getOwner();
1382  Value tdescValue = operand.
get();
1383  auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1384  auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.
getType());
1385  assert(anchorOp && currTDescType &&
1386         "Expected anchor layout op and tensor descriptor consumer.");
1387  Attribute currLayout = currTDescType.getLayout();
1388  Attribute expectedLayout = anchorOp.getAnchorLayout();
// Only act on a genuine conflict (both layouts present and unequal).
1391  if (expectedLayout && currLayout && expectedLayout != currLayout) {
1393    auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1394    if (!conflictingCreateNdOp) {
1395      DBGS() <<
"Unable to find defining CreateNdDescOp for tensor descriptor: "
1396             << tdescValue <<
"\n";
// Rebuild the descriptor type with the expected layout attached.
1401    auto newTensorDescType = xegpu::TensorDescType::get(
1402        conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1403        currTDescType.getElementType(), currTDescType.getEncoding(),
1405    xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1406        builder, consumerOp->
getLoc(), newTensorDescType,
1407        conflictingCreateNdOp->getOperands(),
1408        conflictingCreateNdOp->getAttrs());
// NOTE(review): the fragments below belong to IR-update helpers whose
// enclosing function headers are elided in this listing; comments mark
// only what is visible.
// (Fragment) Write computed layouts back onto op results: warn when a
// used result has no layout; embed the layout into tensor-desc types.
1426  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1433    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1436    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(
result);
1437    if (!layout &&
result.getNumUses() > 0) {
1438      op->
emitWarning(
"op has users but no layout assigned for its result");
1443    if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1444      auto typeWithLayout = xegpu::TensorDescType::get(
1445          tensorDescTy.getContext(), tensorDescTy.getShape(),
1446          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1447      result.setType(typeWithLayout);
// (Fragment) Region-branch terminators: reconcile the layout of each
// forwarded operand with its successor input, retyping tensor-desc
// successor inputs with the operand's layout and flagging conflicts.
1481    mlir::RegionBranchTerminatorOpInterface terminator,
1484  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1489  branchOp.getSuccessorOperandInputMapping(mapping,
1491  for (
const auto &[successorOperand, successorInputs] : mapping) {
1492    for (
Value successorInput : successorInputs) {
1493      Type inputType = successorInput.getType();
1495      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1497      xegpu::DistributeLayoutAttr successorInputLayout =
1498          getLayoutOfValue(successorInput);
1499      xegpu::DistributeLayoutAttr successorOperandLayout =
1500          getLayoutOfValue(successorOperand->get());
// Missing operand layout: nothing to forward, just log.
1503      if (!successorOperandLayout) {
1504        LLVM_DEBUG(
DBGS() <<
"No layout assigned for forwarded operand in "
1505                           "branch terminator: "
1506                   << successorOperand->get() <<
"\n");
// Conflicting layouts across the branch are reported, not repaired.
1510      if (successorInputLayout &&
1511          successorInputLayout != successorOperandLayout) {
1512        LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
1513                           "operand forwarded as the argument: "
1514                   << successorInputLayout <<
" vs "
1515                   << successorOperandLayout <<
"\n");
1519      if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1520        auto newTdescTy = xegpu::TensorDescType::get(
1521            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1522            tdescTy.getEncoding(), successorOperandLayout);
1523        successorInput.setType(newTdescTy);
1528      if (
auto result = dyn_cast<OpResult>(successorInput))
1537 mlir::FunctionOpInterface funcOp,
1543 if (!isa<FunctionType>(funcOp.getFunctionType()))
1548 Type argType = arg.getType();
1549 newArgTypes.push_back(argType);
1550 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1552 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1554 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
1555 <<
" but got none.\n");
1558 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1559 auto newTdescTy = xegpu::TensorDescType::get(
1560 tensorDescTy.getContext(), tensorDescTy.getShape(),
1561 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1562 arg.setType(newTdescTy);
1563 newArgTypes.back() = newTdescTy;
1568 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1569 funcOp.getResultTypes()));
1574struct XeGPUPropagateLayoutPass final
1575 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1576 XeGPUPropagateLayoutPass() =
default;
1577 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
1578 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
1579 : XeGPUPropagateLayoutBase(std::move(
options)) {}
1580 void runOnOperation()
override;
1587 unsigned indexBitWidth,
bool printOnly) {
1588 RunLayoutInfoPropagation analysis(
target, layoutKind, indexBitWidth);
1591 auto &os = llvm::outs();
1592 analysis.printAnalysisResult(os);
1596 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::DistributeLayoutAttr {
1597 LayoutInfo layout = analysis.getLayoutInfo(val);
1598 if (
auto opResult = dyn_cast<OpResult>(val)) {
1599 Operation *defOp = opResult.getDefiningOp();
1600 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1601 auto anchorLayout = anchorOp.getAnchorLayout();
1602 if (anchorLayout !=
nullptr)
1603 return anchorLayout;
1605 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1607 if (requiredResLayoutAttr !=
nullptr)
1608 return requiredResLayoutAttr;
1610 if (!layout.isAssigned())
1612 xegpu::DistributeLayoutAttr layoutAttr =
1613 cast<xegpu::DistributeLayoutAttr>(layout.get());
1614 if (layout.isSliceLayout())
1615 return cast<xegpu::SliceAttr>(layoutAttr);
1617 return cast<xegpu::LayoutAttr>(layoutAttr);
1625 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1627 getXeGPULayoutForValue);
1629 .Case([&](mlir::FunctionOpInterface funcOp) {
1631 getXeGPULayoutForValue);
1634 r =
updateOp(builder, op, getXeGPULayoutForValue);
1637 op.
emitError(
"Failed to update operation with the layout.");
1643 if (walkResult.wasInterrupted())
1650 ResolveLayoutConflicts resolver(
target);
1651 return resolver.run();
1654void XeGPUPropagateLayoutPass::runOnOperation() {
1656 getOperation()->walk([](
Operation *op) {
1659 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1660 attrsToRemove.push_back(namedAttr.getName());
1662 for (
auto attrName : attrsToRemove)
1666 if (this->layoutKind ==
"lane") {
1667 layoutKind = xegpu::LayoutKind::Lane;
1668 }
else if (this->layoutKind ==
"inst") {
1669 layoutKind = xegpu::LayoutKind::InstData;
1670 }
else if (this->layoutKind ==
"subgroup") {
1671 layoutKind = xegpu::LayoutKind::Subgroup;
1673 getOperation()->emitError(
"Unsupported layout kind option: " +
1675 signalPassFailure();
1680 this->indexBitWidth, this->printOnly))) {
1681 signalPassFailure();
1686 signalPassFailure();
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const