32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
83 xegpu::DistributeLayoutAttr storage =
nullptr;
86 LayoutInfo() =
default;
87 LayoutInfo(
const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
91 bool operator==(
const LayoutInfo &other)
const {
92 return this->isAssigned() == other.isAssigned();
95 static LayoutInfo meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
97 static LayoutInfo
join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs);
101 bool isAssigned()
const {
return storage !=
nullptr; }
117 bool isSliceLayout()
const {
120 return isa<xegpu::SliceAttr>(storage);
126 return storage.getRank();
130 void set(
const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](
int64_t val) { return static_cast<int>(val); });
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](
int64_t val) { return static_cast<int>(val); });
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](
int64_t val) { return static_cast<int>(val); });
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](
int64_t val) { return static_cast<int>(val); });
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](
int64_t val) { return static_cast<int>(val); });
169 if (!isAssigned() || !storage.getOrder())
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](
int64_t val) { return static_cast<int>(val); });
179 os <<
"Not assigned.";
// Meet operator for the LayoutInfo lattice (backward dataflow).
// NOTE(review): interior lines are elided in this view — only the visible
// guard on lhs is documented; the merge result itself is not visible here.
183LayoutInfo LayoutInfo::meet(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
// If lhs carries no layout yet, the (elided) body presumably adopts rhs —
// TODO confirm against the full source.
 184  if (!
lhs.isAssigned())
// Join is deliberately unsupported: this analysis is strictly backward and
// only ever combines states through meet(); reaching join indicates a bug.
190LayoutInfo LayoutInfo::join(
const LayoutInfo &
lhs,
const LayoutInfo &
rhs) {
 191  llvm_unreachable(
"Join should not be triggered by layout propagation.");
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](
int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
206 if (!withinRange || hasDuplicates) {
207 assert(
false &&
"Invalid permutation for transpose.");
218 for (
int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(
static_cast<int32_t
>(getLaneLayout()[idx]));
221 laneData.push_back(
static_cast<int32_t
>(getLaneData()[idx]));
223 if (getInstData().size())
224 instData.push_back(
static_cast<int32_t
>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(
static_cast<int32_t
>(getSgLayout()[idx]));
227 sgData.push_back(
static_cast<int32_t
>(getSgData()[idx]));
229 if (getOrder().size()) {
230 order.push_back(
static_cast<int32_t
>(getOrder()[idx]));
233 auto orderAttr = order.size()
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
249 return LayoutInfo(layoutAttr);
257struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
272 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
275 xegpu::LayoutAttr::get(ctx, {
uArch->getSubgroupSize()}, {1}));
278 xegpu::LayoutAttr::get(ctx, {1,
uArch->getSubgroupSize()}, {1, 1}));
282 unsigned rank,
int subgroupSize) {
283 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
291template <
typename Ty>
292static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
293 const xegpu::uArch::uArch *uArch,
294 unsigned packingSize) {
296 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297 "Expected 1D or 2D vector.");
299 assert(ty.getElementType().isIntOrFloat() &&
300 "Expected int or float element type.");
302 if (ty.getRank() == 1)
303 return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
305 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307 return LayoutInfo(xegpu::LayoutAttr::get(
308 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
320class LayoutInfoPropagation
321 :
public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
324 unsigned indexBitWidth;
325 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
326 ArrayRef<const LayoutInfoLattice *> results);
328 void visitStoreNdOp(xegpu::StoreNdOp store,
329 ArrayRef<LayoutInfoLattice *> operands,
330 ArrayRef<const LayoutInfoLattice *> results);
332 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
333 ArrayRef<LayoutInfoLattice *> operands,
334 ArrayRef<const LayoutInfoLattice *> results);
336 void visitLoadNdOp(xegpu::LoadNdOp
load,
337 ArrayRef<LayoutInfoLattice *> operands,
338 ArrayRef<const LayoutInfoLattice *> results);
340 void visitLoadGatherOp(xegpu::LoadGatherOp
load,
341 ArrayRef<LayoutInfoLattice *> operands,
342 ArrayRef<const LayoutInfoLattice *> results);
344 void visitTransposeOp(vector::TransposeOp transpose,
345 ArrayRef<LayoutInfoLattice *> operands,
346 ArrayRef<const LayoutInfoLattice *> results);
348 void visitVectorBitcastOp(vector::BitCastOp bitcast,
349 ArrayRef<LayoutInfoLattice *> operands,
350 ArrayRef<const LayoutInfoLattice *> results);
352 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
353 ArrayRef<LayoutInfoLattice *> operands,
354 ArrayRef<const LayoutInfoLattice *> results);
356 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
357 ArrayRef<LayoutInfoLattice *> operands,
358 ArrayRef<const LayoutInfoLattice *> results);
360 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
361 ArrayRef<LayoutInfoLattice *> operands,
362 ArrayRef<const LayoutInfoLattice *> results);
364 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
365 ArrayRef<LayoutInfoLattice *> operands,
366 ArrayRef<const LayoutInfoLattice *> results);
368 void visitVectorBroadCastOp(vector::BroadcastOp
broadcast,
369 ArrayRef<LayoutInfoLattice *> operands,
370 ArrayRef<const LayoutInfoLattice *> results);
371 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
372 ArrayRef<LayoutInfoLattice *> operands,
373 ArrayRef<const LayoutInfoLattice *> results);
375 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
376 ArrayRef<LayoutInfoLattice *> operands,
377 ArrayRef<const LayoutInfoLattice *> results);
379 void visitLoadMatrixOp(xegpu::LoadMatrixOp
load,
380 ArrayRef<LayoutInfoLattice *> operands,
381 ArrayRef<const LayoutInfoLattice *> results);
383 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
384 ArrayRef<LayoutInfoLattice *> operands,
385 ArrayRef<const LayoutInfoLattice *> results);
387 void visitLoadGatherOp(xegpu::LoadMatrixOp
load,
388 ArrayRef<LayoutInfoLattice *> operands,
389 ArrayRef<const LayoutInfoLattice *> results);
391 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
392 ArrayRef<LayoutInfoLattice *> operands,
393 ArrayRef<const LayoutInfoLattice *> results);
395 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
398 LayoutInfoPropagation(DataFlowSolver &solver,
399 SymbolTableCollection &symbolTable,
401 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
402 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
406 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
407 ArrayRef<const LayoutInfoLattice *> results)
override;
409 void visitBranchOperand(OpOperand &operand)
override {};
411 void visitCallOperand(OpOperand &operand)
override {};
414 visitNonControlFlowArguments(RegionSuccessor &successor,
415 ArrayRef<BlockArgument> arguments)
override {};
417 void visitExternalCall(CallOpInterface call,
418 ArrayRef<LayoutInfoLattice *> operands,
419 ArrayRef<const LayoutInfoLattice *> results)
override {
422 void setToExitState(LayoutInfoLattice *lattice)
override {
423 (void)lattice->meet(LayoutInfo());
// Central dispatcher of the sparse backward analysis: routes each visited
// operation to its per-op visitor via llvm::TypeSwitch. The Default case
// generically propagates an already-assigned result layout back to
// tensor-desc / vector operands. NOTE(review): several lines (e.g. the
// TypeSwitch header and closing cases) are elided in this view.
428LogicalResult LayoutInfoPropagation::visitOperation(
 429    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
 430    ArrayRef<const LayoutInfoLattice *> results) {
 433      [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
 434      .Case([&](xegpu::StoreNdOp storeNdOp) {
 435        visitStoreNdOp(storeNdOp, operands, results);
 437      .Case([&](xegpu::StoreScatterOp storeScatterOp) {
 438        visitStoreScatterOp(storeScatterOp, operands, results);
 440      .Case([&](xegpu::LoadNdOp loadNdOp) {
 441        visitLoadNdOp(loadNdOp, operands, results);
 443      .Case([&](xegpu::LoadGatherOp loadGatherOp) {
 444        visitLoadGatherOp(loadGatherOp, operands, results);
 446      .Case([&](xegpu::CreateDescOp createDescOp) {
 447        visitCreateDescOp(createDescOp, operands, results);
 449      .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
 450        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
 452      .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
 453        visitPrefetchNdOp(prefetchNdOp, operands, results);
 455      .Case([&](vector::TransposeOp transposeOp) {
 456        visitTransposeOp(transposeOp, operands, results);
 458      .Case([&](vector::BitCastOp bitcastOp) {
 459        visitVectorBitcastOp(bitcastOp, operands, results);
 461      .Case([&](vector::MultiDimReductionOp reductionOp) {
 462        visitVectorMultiReductionOp(reductionOp, operands, results);
 464      .Case([&](vector::BroadcastOp broadcastOp) {
 465        visitVectorBroadCastOp(broadcastOp, operands, results);
 467      .Case([&](vector::ShapeCastOp shapeCastOp) {
 468        visitShapeCastOp(shapeCastOp, operands, results);
 470      .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
 471        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
 473      .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
 474        visitLoadMatrixOp(loadMatrixOp, operands, results);
 476      .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
 477        visitStoreMatrixOp(storeMatrixOp, operands, results);
// Fallback: for unhandled ops, copy each assigned result layout back onto
// operands whose type participates in layout (tensor-desc or vector only).
 480      .Default([&](Operation *op) {
 481        for (
const LayoutInfoLattice *resultInfo : results) {
 482          if (!resultInfo->getValue().isAssigned())
 484          for (
auto [operandInfo, operand] :
 488            if (!isa<xegpu::TensorDescType, VectorType>(
 489                    operand.get().getType()))
 492              meet(operandInfo, *resultInfo);
// Returns true iff `anchorLayout` already carries non-empty parameters for
// the layout kind this propagation targets (InstData, Lane, or Subgroup).
// A null attribute is handled first (that branch's body is elided here).
500bool LayoutInfoPropagation::hasParamsOfLayoutKind(
 501    xegpu::DistributeLayoutAttr anchorLayout) {
 502  if (anchorLayout ==
nullptr) {
// InstData kind only needs inst_data to be present.
 505  if (layoutKind == xegpu::LayoutKind::InstData) {
 506    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
// Lane kind requires both lane_layout and lane_data.
 508  if (layoutKind == xegpu::LayoutKind::Lane) {
 509    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
 510             anchorLayout.getEffectiveLaneDataAsInt().empty());
// Subgroup kind requires both sg_layout and sg_data.
 512  if (layoutKind == xegpu::LayoutKind::Subgroup) {
 513    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
 514             anchorLayout.getEffectiveSgDataAsInt().empty());
530 for (
int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
531 if (sgCount % sgLayout0)
533 int sgLayout1 = sgCount / sgLayout0;
534 int sgData0 = wgShape[0] / sgLayout0;
535 int sgData1 = wgShape[1] / sgLayout1;
536 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
537 (sgData0 % instData[0] || sgData1 % instData[1]))
539 candidates.emplace_back(sgLayout0, sgLayout1);
544 llvm::sort(candidates, [](
const std::pair<int, int> &
lhs,
545 const std::pair<int, int> &
rhs) {
546 int diffLhs = std::abs(
lhs.first -
lhs.second);
547 int diffRhs = std::abs(
rhs.first -
rhs.second);
548 if (diffLhs != diffRhs)
549 return diffLhs < diffRhs;
550 return lhs.first <
rhs.first;
560 auto knownBlockSize = gpuFunc.getKnownBlockSize();
561 if (!knownBlockSize.has_value())
563 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
564 return flatBlockSize / sgSize;
// Propagates a layout to the tensor-descriptor operand of PrefetchNdOp.
// Prefers an existing anchor layout; otherwise derives one from the uArch's
// 2D-block-prefetch instruction parameters. NOTE(review): interior lines
// (instWidth/instHeight computation helpers, else-branches) are elided.
567void LayoutInfoPropagation::visitPrefetchNdOp(
 568    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
 569    ArrayRef<const LayoutInfoLattice *> results) {
 571  LayoutInfo prefetchLayout;
 572  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
// An anchor layout of the right kind wins outright.
 573  if (hasParamsOfLayoutKind(anchorLayout)) {
 574    prefetchLayout = LayoutInfo(anchorLayout);
 578    auto tdescTy = prefetch.getTensorDescType();
// Look up the block-prefetch instruction description for this uArch.
 583    const auto *uArchInstruction =
 584        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
 586            xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
 589        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
 591      prefetch.emitWarning(
"No known block params found for the element type.");
 592    auto [bWidth, bHeight, bCount] = blockWHC.value();
 593    SmallVector<int> instData;
// Fit the innermost dim to the hardware block width.
 595        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
 597      prefetch.emitWarning(
 598          "No suitable instruction multiple found for the given shape.");
 599    if (tdescTy.getRank() == 1)
 600      instData = {instWidth};
 603          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
 604      if (instHeight == -1)
 605        prefetch.emitWarning(
 606            "No suitable instruction multiple found for the given shape.");
 607      instData = {instHeight, instWidth};
 610    if (layoutKind == xegpu::LayoutKind::InstData)
 612        LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
 614      prefetchLayout = getSIMTLayoutInfoBlockIO(
 615          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
// Record the chosen layout on the op and push it to the tdesc operand.
 617  prefetch.setLayoutAttr(
 618      dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
 621  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
// Backward rule for vector.multi_reduction: derive source and accumulator
// layouts from the consumer layout of the result (elided helper calls
// compute the required layouts from layoutKind/uArch).
624void LayoutInfoPropagation::visitVectorMultiReductionOp(
 625    vector::MultiDimReductionOp reduction,
 626    ArrayRef<LayoutInfoLattice *> operands,
 627    ArrayRef<const LayoutInfoLattice *> results) {
// Nothing to do until the result has a layout assigned.
 629  LayoutInfo resLayoutInfo = results[0]->getValue();
 630  if (!resLayoutInfo.isAssigned())
 633  VectorType sourceTy = reduction.getSourceVectorType();
 634  SmallVector<int64_t> reductionDims(reduction.getReductionDims());
 639  auto consumerLayoutAttr =
 640      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 648      layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
 654      requiredResLayoutAttr, reductionDims);
// operands[0] = source vector, operands[1] = accumulator.
 656  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 658  propagateIfChanged(operands[1],
 659                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Backward rule for vector.broadcast: compute the source layout from the
// result's consumer layout. NOTE(review): the non-vector-source early exit
// and the srcLayoutAttr derivation call are elided in this view.
662void LayoutInfoPropagation::visitVectorBroadCastOp(
 663    vector::BroadcastOp
broadcast, ArrayRef<LayoutInfoLattice *> operands,
 664    ArrayRef<const LayoutInfoLattice *> results) {
// Need an assigned result layout to propagate anything.
 666  LayoutInfo resLayoutInfo = results[0]->getValue();
 667  if (!resLayoutInfo.isAssigned())
 671  VectorType resultTy =
broadcast.getResultVectorType();
 672  VectorType sourceTy = dyn_cast<VectorType>(
broadcast.getSourceType());
 677  auto srcShape = sourceTy.getShape();
 678  auto resShape = resultTy.getShape();
// Rank difference between result and source (leading broadcast dims).
 680  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
// Unit-dim sources are only supported when produced by a shape_cast
// (assert below; srcOp presumably the source's defining op — elided).
 685  [[maybe_unused]]
bool hasUnitDim =
 686      llvm::any_of(srcShape, [](int64_t dim) {
return dim == 1; });
 688      hasUnitDim && isa<vector::ShapeCastOp>(srcOp) &&
 689      "When broadcasting from unit-dim, the producer op must be shape_cast!");
 692  auto resultLayoutAttr =
 693      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 695  xegpu::DistributeLayoutAttr srcLayoutAttr =
 698  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Backward rule for vector.shape_cast: map the result's layout onto the
// source shape (the mapping helper call is elided in this view).
701void LayoutInfoPropagation::visitShapeCastOp(
 702    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
 703    ArrayRef<const LayoutInfoLattice *> results) {
 705  LayoutInfo resLayoutInfo = results[0]->getValue();
 706  if (!resLayoutInfo.isAssigned())
 708  ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
 709  ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
 710  auto resultLayoutAttr =
 711      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 713  xegpu::DistributeLayoutAttr srcLayoutAttr =
 716  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// UpdateNdOffsetOp does not change the tensor descriptor's layout, so the
// result layout is forwarded unchanged to the tdesc operand.
721void LayoutInfoPropagation::visitUpdateNdOffsetOp(
 722    xegpu::UpdateNdOffsetOp updateNdOffset,
 723    ArrayRef<LayoutInfoLattice *> operands,
 724    ArrayRef<const LayoutInfoLattice *> results) {
 726  LayoutInfo resultLayout = results[0]->getValue();
 727  if (!resultLayout.isAssigned())
 730  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
// Assigns layouts to DPAS A/B/C-D operands. Priority: (1) explicit anchor
// layouts already on the op; (2) computed required layouts (the helper call
// producing `layouts` is elided). For the Subgroup kind the consumer layout
// and subgroup count feed the computation. NOTE(review): several lines
// (early returns, the helper invocation itself) are elided in this view.
734void LayoutInfoPropagation::visitDpasOp(
 735    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
 736    ArrayRef<const LayoutInfoLattice *> results) {
 737  LayoutInfo dpasALayout;
 738  LayoutInfo dpasBLayout;
 739  LayoutInfo dpasCDLayout;
// Case 1: anchors already present — A and B anchors must accompany CD.
 741  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
 742  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
 743    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
 744    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
 745    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
 746           "Expected anchor layout for DPAS A operand.");
 747    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
 748           "Expected anchor layout for DPAS B operand.");
 749    dpasALayout = LayoutInfo(anchorLayoutA);
 750    dpasBLayout = LayoutInfo(anchorLayoutB);
 751    dpasCDLayout = LayoutInfo(anchorLayoutCD);
// Case 2: derive layouts from operand/result vector types.
 756  VectorType aTy = dpas.getLhsType();
 757  VectorType bTy = dpas.getRhsType();
 758  VectorType cdTy = dpas.getResultType();
 760  xegpu::DistributeLayoutAttr consumerLayoutAttr =
nullptr;
 761  xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
// Subgroup distribution additionally needs the consumer layout and the
// number of subgroups in the workgroup.
 765  if (layoutKind == xegpu::LayoutKind::Subgroup) {
 766    LayoutInfo consumerLayout = results[0]->getValue();
 767    if (!consumerLayout.isAssigned())
 770        dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
 774          "Unable to determine the number of subgroups for the operation.");
 777    numSg = numSgOrErr.value();
 780      consumerLayoutAttr, uArch, numSg);
 781  if (!layouts.has_value()) {
 783        "Failed to determine required layouts for DPAS operands.");
 787  std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
// Anchor the computed layouts on the op for later stages.
 789  dpas.setLayoutAAttr(requiredALayout);
 790  dpas.setLayoutBAttr(requiredBLayout);
 791  dpas.setLayoutCdAttr(requiredCDLayoutAttr);
 792  dpasALayout = LayoutInfo(requiredALayout);
 793  dpasBLayout = LayoutInfo(requiredBLayout);
 794  dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
// Push layouts to A, B and (when present) the accumulator operand.
 796  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
 797  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
 798  if (operands.size() > 2)
 799    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
// Assigns a layout for StoreNdOp and propagates it to BOTH the value and
// the tensor-descriptor operands (the trailing loop covers all operands).
// Priority: anchor layout, then InstData / Lane / Subgroup derivation from
// the uArch's 2D-block-store instruction. NOTE(review): many interior lines
// (instWidth computation, else-branches, the final LayoutAttr::get args)
// are elided in this view.
803void LayoutInfoPropagation::visitStoreNdOp(
 804    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
 805    ArrayRef<const LayoutInfoLattice *> results) {
 806  LayoutInfo storeLayout;
 807  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
 808  if (hasParamsOfLayoutKind(anchorLayout)) {
 809    storeLayout = LayoutInfo(anchorLayout);
// Query uArch block-store parameters for this element type.
 814    const auto *uArchInstruction =
 815        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
 817            xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
 818    VectorType dataTy = store.getValueType();
 819    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
 820        store.getValueType().getElementType());
 822      store.emitWarning(
"No known block params found for the element type.");
 823    auto [bWidth, bHeight, bCount] = blockWHC.value();
 824    SmallVector<int> instData;
 826        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
 829          "No suitable instruction multiple found for the given shape.");
 830    if (dataTy.getRank() == 1)
 831      instData = {instWidth};
 834          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
 835      if (instHeight == -1)
 837            "No suitable instruction multiple found for the given shape.");
 838      instData = {instHeight, instWidth};
// Choose the layout parameterization matching the requested kind.
 841    if (layoutKind == xegpu::LayoutKind::InstData)
 843          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
 844    else if (layoutKind == xegpu::LayoutKind::Lane)
 846          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
 847                                   uArchInstruction->getPackedFormatBitSize());
// Subgroup kind: pick an sg_layout candidate and divide the shape by it.
 850      auto numSgOrErr =
getNumSg(store, sgSize);
 853            "Unable to determine the number of subgroups for the operation.");
 857          instData, numSgOrErr.value());
 858      if (sgLayouts.empty()) {
 860            "Unable to determine suitable subgroup layout for store value.");
 863      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
 864      SmallVector<int> sgData = {
 865          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
 866          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
 867      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
 875      dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
// Both the stored value and the tdesc get the same layout.
 879  for (LayoutInfoLattice *operand : operands)
 880    propagateIfChanged(operand, operand->meet(storeLayout));
// Propagates the loaded value's layout back to LoadNdOp's tensor
// descriptor, preferring an anchor layout when present. A transpose
// attribute is tolerated (with a warning) by transposing the value layout.
885void LayoutInfoPropagation::visitLoadNdOp(
 886    xegpu::LoadNdOp
load, ArrayRef<LayoutInfoLattice *> operands,
 887    ArrayRef<const LayoutInfoLattice *> results) {
 888  LayoutInfo loadLayout;
 889  xegpu::DistributeLayoutAttr anchorLayout =
load.getLayoutAttr();
 890  if (hasParamsOfLayoutKind(anchorLayout)) {
 891    loadLayout = LayoutInfo(anchorLayout);
// Otherwise adopt the layout required by the value's consumers.
 894    LayoutInfo valueLayout = results[0]->getValue();
 896    if (!valueLayout.isAssigned())
 898    loadLayout = valueLayout;
// Transpose is not expected at this stage; compensate by transposing.
 902    if (
auto transpose =
load.getTranspose()) {
 903      load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
 904                       "LayoutInfoPropagation stage.");
 905      loadLayout = valueLayout.transpose(transpose.value());
 907  load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
 910  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// Backward rule for vector.transpose: permute the consumer (result) layout
// by the transpose permutation to obtain the source layout (the helper
// producing srcLayoutAttr is elided in this view).
915void LayoutInfoPropagation::visitTransposeOp(
 916    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
 917    ArrayRef<const LayoutInfoLattice *> results) {
 919  LayoutInfo resultLayout = results[0]->getValue();
 920  if (!resultLayout.isAssigned())
 922  auto consumerLayoutAttr =
 923      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
 925      consumerLayoutAttr, transpose.getPermutation());
 927  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Backward rule for vector.bitcast: scale the consumer layout by the ratio
// of element bit widths to get the source layout (helper calls elided).
932void LayoutInfoPropagation::visitVectorBitcastOp(
 933    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
 934    ArrayRef<const LayoutInfoLattice *> results) {
 936  LayoutInfo resLayoutInfo = results[0]->getValue();
 937  if (!resLayoutInfo.isAssigned())
 940  auto srcVecType = bitcast.getSourceVectorType();
 941  auto resVecType = bitcast.getResultVectorType();
 943  auto consumerLayoutAttr =
 944      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 949      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
// Bit widths drive how lane/inst data is rescaled across the cast.
 953  int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
 954  int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
 958      requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
 960  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Backward rule for vector.insert_strided_slice: the destination operand
// takes the (required) result layout; the inserted source takes a layout
// derived from the result layout and the two shapes (helpers elided).
963void LayoutInfoPropagation::visitInsertStridedSliceOp(
 964    vector::InsertStridedSliceOp insertStridedSlice,
 965    ArrayRef<LayoutInfoLattice *> operands,
 966    ArrayRef<const LayoutInfoLattice *> results) {
 968  LayoutInfo resLayoutInfo = results[0]->getValue();
 969  if (!resLayoutInfo.isAssigned())
 972  auto srcVecType = insertStridedSlice.getSourceVectorType();
 973  auto resVecType = insertStridedSlice.getDestVectorType();
 975  auto consumerLayoutAttr =
 976      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 983      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
 985      requiredResLayoutAttr);
 988      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
// operands[0] = inserted source, operands[1] = destination vector.
 989  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 990  propagateIfChanged(operands[1],
 991                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
// Propagates layouts for LoadGatherOp: the payload layout goes to a
// tensor-desc source (if any); mask and offsets get a 1-D per-lane layout
// when chunkSize > 1 (chunked access), otherwise the payload layout.
// NOTE(review): several interior lines (subgroupSize definition, the
// required-layout helper, llvm_unreachable header) are elided here.
996void LayoutInfoPropagation::visitLoadGatherOp(
 997    xegpu::LoadGatherOp
load, ArrayRef<LayoutInfoLattice *> operands,
 998    ArrayRef<const LayoutInfoLattice *> results) {
 999  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
 1000  xegpu::DistributeLayoutAttr anchorLayoutAttr =
load.getLayoutAttr();
 1005  VectorType resVecTy =
load.getValueType();
// chunkSize > 1 means each lane loads a contiguous chunk.
 1006  int chunkSize =
load.getChunkSize().value_or(1);
 1008  LayoutInfo resLayoutInfo = results[0]->getValue();
 1009  if (!resLayoutInfo.isAssigned())
 1011  auto consumerLayoutAttr =
 1012      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
// Anchor layout wins; otherwise derive from the result layout.
 1014  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
 1015    requiredAnchorLayoutAttr = anchorLayoutAttr;
 1018      load.emitWarning(
"Not propagating, non-vector payload supplied.");
 1022        layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
 1023  load.setLayoutAttr(requiredAnchorLayoutAttr);
 1026  auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Chunked loads distribute the mask/offsets one element per lane.
 1029  if (chunkSize > 1) {
 1030    if (layoutKind == xegpu::LayoutKind::InstData)
 1032        xegpu::LayoutAttr::get(
load->getContext(), {subgroupSize});
 1033    else if (layoutKind == xegpu::LayoutKind::Lane)
 1035        xegpu::LayoutAttr::get(
load->getContext(), {subgroupSize}, {1});
 1038          "chunked StoreScatterOp should not be used at workgroup level");
 1041  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
 1042  auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// Only a tensor-desc source receives the payload layout.
 1045  if (isa<xegpu::TensorDescType>(
load.getSourceType()))
 1046    propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
 1048  propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
 1049  if (
load.getOffsets())
 1050    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
// For CreateDescOp, once the descriptor result has a layout, assign the
// offsets operand (operands[1]) a default 1-D SIMT layout.
1055void LayoutInfoPropagation::visitCreateDescOp(
 1056    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
 1057    ArrayRef<const LayoutInfoLattice *> results) {
 1058  LayoutInfo descLayout = results[0]->getValue();
 1060  if (!descLayout.isAssigned())
 1066  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
 1068  propagateIfChanged(operands[1], operands[1]->meet(layout));
// Propagates layouts for StoreScatterOp: the payload (and a tensor-desc
// destination) get the required/anchor layout; mask and offsets get a 1-D
// per-lane layout when chunkSize > 1. Mirrors visitLoadGatherOp.
// NOTE(review): interior lines (subgroupSize, helper call header) elided.
1073void LayoutInfoPropagation::visitStoreScatterOp(
 1074    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
 1075    ArrayRef<const LayoutInfoLattice *> results) {
 1077  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
 1078  xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
 1083  VectorType srcVecTy = storeScatter.getValueType();
 1084  int chunkSize = storeScatter.getChunkSize().value_or(1);
// Anchor layout wins; otherwise derive a layout for the stored value.
 1086  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
 1087    requiredAnchorLayoutAttr = anchorLayoutAttr;
 1090      storeScatter.emitWarning(
"Not propagating, non-vector payload supplied.");
 1094        layoutKind, srcVecTy, chunkSize, uArch);
 1095  storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
 1098  LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
 1099  auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Chunked stores distribute mask/offsets one element per lane.
 1102  if (chunkSize > 1) {
 1103    if (layoutKind == xegpu::LayoutKind::InstData)
 1105          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
 1106    else if (layoutKind == xegpu::LayoutKind::Lane)
 1107      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
 1108                                              {subgroupSize}, {1});
 1111          "chunked StoreScatterOp should not be used at workgroup level");
 1114  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
// operands: [0]=value, [1]=dest, [2]=mask, [3]=offsets (when present).
 1117  propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
 1119  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
 1120    propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
 1122  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
 1123  if (storeScatter.getOffsets())
 1124    propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
// For LoadMatrixOp: if no anchor layout of the targeted kind is present,
// derive one from the 2-D result vector and the consumer layout, then
// anchor it on the op (the derivation helper call is elided here).
1127void LayoutInfoPropagation::visitLoadMatrixOp(
 1128    xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
 1129    ArrayRef<const LayoutInfoLattice *> results) {
 1131  LayoutInfo resLayoutInfo = results[0]->getValue();
 1132  if (!resLayoutInfo.isAssigned())
 1135  auto consumerLayoutAttr =
 1136      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 1138  xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
 1142  if (!hasParamsOfLayoutKind(anchorLayout)) {
 1143    VectorType resVecTy =
 1144        llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
 1145    assert(resVecTy.getRank() == 2 &&
"Expecting 2D vector for store matrix.");
 1150        layoutKind, resVecTy, consumerLayoutAttr, uArch);
 1151    loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
// For StoreMatrixOp: use the anchor layout when present, otherwise derive
// one from the 2-D source vector, anchor it on the op, and propagate it to
// the data operand (operands[0]).
1156void LayoutInfoPropagation::visitStoreMatrixOp(
 1157    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
 1158    ArrayRef<const LayoutInfoLattice *> results) {
 1159  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
 1161  if (hasParamsOfLayoutKind(anchorLayout)) {
 1162    layout = LayoutInfo(anchorLayout);
 1164    VectorType srcVecTy =
 1165        llvm::cast<VectorType>(storeMatrix.getData().getType());
 1166    assert(srcVecTy.getRank() == 2 &&
"Expecting 2D vector for store matrix.");
 1170    auto requiredAnchorLayoutAttr =
 1172    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
 1173    layout = LayoutInfo(requiredAnchorLayoutAttr);
 1176  propagateIfChanged(operands[0], operands[0]->meet(layout));
1185class RunLayoutInfoPropagation {
1190 unsigned indexBitWidth)
1192 SymbolTableCollection symbolTable;
1194 solver.
load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
1198 LayoutInfo getLayoutInfo(Value val);
1200 void printAnalysisResult(llvm::raw_ostream &os);
1203 DataFlowSolver solver;
1208LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1209 auto *state = solver.
lookupState<LayoutInfoLattice>(val);
1212 return state->getValue();
1216void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1217 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1218 os <<
"function: " << funcOp.getName() <<
":\n";
1220 for (BlockArgument arg : funcOp.getArguments()) {
1221 LayoutInfo layout = getLayoutInfo(arg);
1222 os <<
"argument: " << arg <<
"\n";
1228 funcOp.walk([&](Operation *op) {
1234 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1240 for (
auto [i, r] : llvm::enumerate(op->
getResults())) {
1241 LayoutInfo layout = getLayoutInfo(r);
1242 os <<
"layout for result #" << i <<
": ";
1249 SmallVector<FunctionOpInterface> funcOps;
1250 if (
auto modOp = dyn_cast<ModuleOp>(
target)) {
1251 for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
1252 funcOps.push_back(funcOp);
1255 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1256 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1257 funcOps.push_back(gpuFuncOp);
1261 for (FunctionOpInterface funcOp : funcOps)
1262 printFunctionResult(funcOp);
// Walks back to the CreateNdDescOp that (transitively) defines a tensor
// descriptor value: direct defining op first, then — for a loop block
// argument — recurse through the tied loop init operand.
1274static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
 1276  auto definingOp = tdescValue.
getDefiningOp<xegpu::CreateNdDescOp>();
// Block arguments of loop-like ops trace back through the tied init value.
 1281  if (
auto arg = dyn_cast<BlockArgument>(tdescValue)) {
 1282    auto *parentOp = arg.getOwner()->getParentOp();
 1283    if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
 1284      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
 1286        return getDefiningCreateNdDescOp(tiedInit->
get());
1293struct ResolveLayoutConflicts {
1294 ResolveLayoutConflicts(Operation *parentOp)
1295 : parentOp(parentOp), builder(parentOp->
getContext()) {}
1296 LogicalResult run();
1299 Operation *parentOp;
1301 LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1302 LogicalResult resolveVectorConsumer(OpOperand &operand);
1307LogicalResult ResolveLayoutConflicts::run() {
1310 auto r = parentOp->
walk([&](Operation *op) -> WalkResult {
1313 Type operandType = operand.get().getType();
1314 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1315 isa<xegpu::TensorDescType>(operandType)) {
1316 auto res = resolveTensorDescConsumer(operand);
1318 DBGS() <<
"Failed to resolve tensor descriptor consumer: " << *op
1324 if (isa<VectorType>(operandType)) {
1325 auto res = resolveVectorConsumer(operand);
1327 DBGS() <<
"Failed to resolve vector consumer: " << *op <<
"\n";
1335 return r.wasInterrupted() ? failure() :
success();
1339ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1340 Value vectorValue = operand.
get();
1341 Operation *consumerOp = operand.
getOwner();
1344 if (!producerLayout) {
1345 if (
auto vectorTy = dyn_cast<VectorType>(vectorValue.
getType());
1346 vectorTy && vectorTy.getRank() > 1)
1347 consumerOp->
emitWarning(
"Expected layout for non-1D vectors.");
1352 if (!consumerLayout)
1354 "No consumer layout found for vector operand.");
1357 if (consumerLayout.isEqualTo(producerLayout))
1362 auto convertOp = xegpu::ConvertLayoutOp::create(
1363 builder, consumerOp->
getLoc(), vectorValue.
getType(), vectorValue,
1364 producerLayout, consumerLayout);
1367 operand.
set(convertOp.getResult());
1372ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1373 Operation *consumerOp = operand.
getOwner();
1374 Value tdescValue = operand.
get();
1375 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1376 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.
getType());
1377 assert(anchorOp && currTDescType &&
1378 "Expected anchor layout op and tensor descriptor consumer.");
1380 if (currTDescType.isScattered()) {
1381 DBGS() <<
"Scattered tensor descriptor not supported: " << tdescValue
1385 Attribute currLayout = currTDescType.getLayout();
1386 Attribute expectedLayout = anchorOp.getAnchorLayout();
1389 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1391 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1392 if (!conflictingCreateNdOp) {
1393 DBGS() <<
"Unable to find defining CreateNdDescOp for tensor descriptor: "
1394 << tdescValue <<
"\n";
1399 auto newTensorDescType = xegpu::TensorDescType::get(
1400 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1401 currTDescType.getElementType(), currTDescType.getEncoding(),
1403 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1404 builder, consumerOp->
getLoc(), newTensorDescType,
1405 conflictingCreateNdOp->getOperands(),
1406 conflictingCreateNdOp->getAttrs());
1424 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1431 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1434 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(
result);
1435 if (!layout &&
result.getNumUses() > 0) {
1436 op->
emitWarning(
"op has users but no layout assigned for its result");
1441 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1442 auto typeWithLayout = xegpu::TensorDescType::get(
1443 tensorDescTy.getContext(), tensorDescTy.getShape(),
1444 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1445 result.setType(typeWithLayout);
1479 mlir::RegionBranchTerminatorOpInterface terminator,
1482 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1487 branchOp.getSuccessorOperandInputMapping(mapping,
1489 for (
const auto &[successorOperand, successorInputs] : mapping) {
1490 for (
Value successorInput : successorInputs) {
1491 Type inputType = successorInput.getType();
1493 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1495 xegpu::DistributeLayoutAttr successorInputLayout =
1496 getLayoutOfValue(successorInput);
1497 xegpu::DistributeLayoutAttr successorOperandLayout =
1498 getLayoutOfValue(successorOperand->get());
1501 if (!successorOperandLayout) {
1502 LLVM_DEBUG(
DBGS() <<
"No layout assigned for forwarded operand in "
1503 "branch terminator: "
1504 << successorOperand->get() <<
"\n");
1508 if (successorInputLayout &&
1509 successorInputLayout != successorOperandLayout) {
1510 LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
1511 "operand forwarded as the argument: "
1512 << successorInputLayout <<
" vs "
1513 << successorOperandLayout <<
"\n");
1517 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1518 auto newTdescTy = xegpu::TensorDescType::get(
1519 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1520 tdescTy.getEncoding(), successorOperandLayout);
1521 successorInput.setType(newTdescTy);
1526 if (
auto result = dyn_cast<OpResult>(successorInput))
1535 mlir::FunctionOpInterface funcOp,
1541 if (!isa<FunctionType>(funcOp.getFunctionType()))
1546 Type argType = arg.getType();
1547 newArgTypes.push_back(argType);
1548 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1550 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1552 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
1553 <<
" but got none.\n");
1556 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1557 auto newTdescTy = xegpu::TensorDescType::get(
1558 tensorDescTy.getContext(), tensorDescTy.getShape(),
1559 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1560 arg.setType(newTdescTy);
1561 newArgTypes.back() = newTdescTy;
1566 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1567 funcOp.getResultTypes()));
1572struct XeGPUPropagateLayoutPass final
1573 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1574 XeGPUPropagateLayoutPass() =
default;
1575 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
1576 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
1577 : XeGPUPropagateLayoutBase(std::move(
options)) {}
1578 void runOnOperation()
override;
1585 unsigned indexBitWidth,
bool printOnly) {
1586 RunLayoutInfoPropagation analysis(
target, layoutKind, indexBitWidth);
1589 auto &os = llvm::outs();
1590 analysis.printAnalysisResult(os);
1594 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::DistributeLayoutAttr {
1595 LayoutInfo layout = analysis.getLayoutInfo(val);
1596 if (!layout.isAssigned())
1598 if (
auto opResult = dyn_cast<OpResult>(val)) {
1600 Operation *defOp = opResult.getDefiningOp();
1601 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1602 auto anchorLayout = anchorOp.getAnchorLayout();
1603 if (anchorLayout !=
nullptr)
1604 return anchorLayout;
1606 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1608 if (requiredResLayoutAttr !=
nullptr)
1609 return requiredResLayoutAttr;
1611 xegpu::DistributeLayoutAttr layoutAttr =
1612 cast<xegpu::DistributeLayoutAttr>(layout.get());
1613 if (layout.isSliceLayout())
1614 return cast<xegpu::SliceAttr>(layoutAttr);
1616 return cast<xegpu::LayoutAttr>(layoutAttr);
1624 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1626 getXeGPULayoutForValue);
1628 .Case([&](mlir::FunctionOpInterface funcOp) {
1630 getXeGPULayoutForValue);
1633 r =
updateOp(builder, op, getXeGPULayoutForValue);
1636 op.
emitError(
"Failed to update operation with the layout.");
1642 if (walkResult.wasInterrupted())
1649 ResolveLayoutConflicts resolver(
target);
1650 return resolver.run();
1653void XeGPUPropagateLayoutPass::runOnOperation() {
1655 if (this->layoutKind ==
"lane") {
1657 }
else if (this->layoutKind ==
"inst") {
1659 }
else if (this->layoutKind ==
"subgroup") {
1660 layoutKind = xegpu::LayoutKind::Subgroup;
1662 getOperation()->emitError(
"Unsupported layout kind option: " +
1664 signalPassFailure();
1669 this->indexBitWidth, this->printOnly))) {
1670 signalPassFailure();
1675 signalPassFailure();
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const