#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
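// The DBGS() macro follows the usual LLVM debug-output pattern; a minimal
// usage sketch (the message text here is illustrative only):
//
//   LLVM_DEBUG(DBGS() << "visiting " << op->getName() << "\n");
//
// The output is enabled with `-debug-only=xegpu-subgroup-distribute`.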
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  void print(llvm::raw_ostream &os) const;
  size_t size() const { return layout.size(); }
  int64_t operator[](size_t idx) const;

void Layout::print(llvm::raw_ostream &os) const {
  os << llvm::interleaved_array(layout);
}

int64_t Layout::operator[](size_t idx) const {
  assert(idx < layout.size() && "Index out of bounds.");
  return layout[idx];
}
using LaneLayout = Layout;
using LaneData = Layout;
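// Terminology note (inferred from how the defaults below are constructed):
// `lane_layout` describes how the subgroup lanes are arranged over the value's
// shape, and `lane_data` describes how many contiguous elements each lane owns
// per dimension. For example, lane_layout = [1, 16] with lane_data = [1, 1]
// assigns one column of a 2D value to each of 16 lanes.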
  LaneLayout laneLayout;
  LaneData laneData;

  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << "lane_layout: ";
    laneLayout.print(os);
    os << ", lane_data: ";
    laneData.print(os);
  } else {
    os << "Not assigned.";
  }
}
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

// Join is never exercised here: the layout propagation below is a backward
// sparse analysis that only combines states through meet.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
LayoutInfo LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
  LaneLayout newLayout;
  LaneData newData;
  for (int64_t idx : permutation) {
    newLayout.layout.push_back(laneLayout.layout[idx]);
    newData.layout.push_back(laneData.layout[idx]);
  }
  return LayoutInfo(newLayout, newData);
}
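// Worked example (follows directly from the loop above): transposing a layout
// with lane_layout = [1, 16], lane_data = [1, 1] by permutation {1, 0} yields
// lane_layout = [16, 1], lane_data = [1, 1].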
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
}
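// Example (assuming the HW-dependent constant subgroupSize is 16): a 1D value
// gets lane_layout = [16], lane_data = [1]; a 2D value gets
// lane_layout = [1, 16], lane_data = [1, 1], i.e. the innermost dimension is
// split across the 16 lanes.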
static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  if (vectorTy.getRank() == 1)
    return getDefaultLayoutInfo(1);
  // Pack small element types so that each lane owns a full packed word.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (bitwidth < packedSizeInBitsForDefault)
    packingFactor = packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(LaneLayout({1, subgroupSize}),
                    LaneData({1, packingFactor}));
}
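// Worked example, assuming packedSizeInBitsForDefault is 32 and subgroupSize
// is 16 (both are HW-dependent constants declared elsewhere in this file): for
// vector<16x16xf16>, bitwidth = 16 < 32, so packingFactor = 2 and the default
// layout is lane_layout = [1, 16], lane_data = [1, 2].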
static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
                                              unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  LaneLayout layout({1, subgroupSize});
  // The B operand (operandNum == 1) must be packed (VNNI) when its element
  // type is narrower than packedSizeInBitsForDpasB.
  if (operandNum == 1 &&
      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
    LaneData data(
        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
    return LayoutInfo(layout, data);
  }
  return getDefaultLayoutInfo(vectorTy);
}
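// Example, assuming packedSizeInBitsForDpasB is 32: an f16 B operand (operand
// number 1) gets lane_layout = [1, 16] with lane_data = [2, 1], i.e. each lane
// holds two vertically adjacent f16 values packed into a 32-bit word (the
// VNNI layout). The A and accumulator operands fall back to the default
// layout.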
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {};

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
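// Note: this analysis runs backward. Layout requirements originate at ops with
// fixed expectations (stores, prefetches, dpas) and are met into the lattices
// of their producing operands, so layout information flows from consumers
// toward producers.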
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      // No layout propagation is needed for CreateNdDescOp; its operands are
      // scalar offsets and sizes.
      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      // All other ops: propagate each assigned result layout to every operand.
      .Default([&](Operation *) {
        for (const LayoutInfoLattice *r : results) {
          for (LayoutInfoLattice *operand : operands) {
            if (r->getValue().isAssigned())
              meet(operand, *r);
          }
        }
      });
  // Add a dependency from each result lattice to the program point after the
  // op so the op is revisited when a result layout changes.
  for (const LayoutInfoLattice *r : results) {
    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
  }
  return success();
}
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Assign the default layout to the tensor descriptor operand of the prefetch.
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultLayoutInfo(
      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must already be known.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only 2D -> 1D reductions are supported at this point.
  assert(resultLayout.getLayout().size() == 1 &&
         "Expected 1D layout for reduction result.");
  // The source operand gets the default 2D layout.
  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator gets the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must already be known.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the result layout to the source tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(operands[0],
                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(operands[1],
                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(operands[2],
                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
  }
}
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
  // Both the stored value and the tensor descriptor get the same layout.
  for (LayoutInfoLattice *operand : operands) {
    propagateIfChanged(operand, operand->meet(storeLayout));
  }
}
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // The layout of the loaded value is needed to propagate to the descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp may carry a transpose effect. It is not expected at this stage,
  // so emit a warning, but still handle it by transposing the layout.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
  }
  // Propagate the layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the transpose result is needed to propagate to the operand.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout =
      resultLayout.getTransposedLayout(transpose.getPermutation());
  // Propagate the transposed layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the bitcast result is needed to propagate to the operand.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // The lane layout is unchanged; only the lane data is scaled by the ratio of
  // the element bitwidths, along whichever dimension is currently distributed.
  const LaneLayout &newLaneLayout = resultLayout.getLayout();
  const LaneData &currData = resultLayout.getData();
  LaneData newLaneData;
  if (inElemTyBitWidth < outElemTyBitWidth) {
    // The source element type is narrower: each lane owns more source elements.
    int ratio = outElemTyBitWidth / inElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] * ratio})
                      : LaneData({currData[0] * ratio, 1});
  } else {
    // The source element type is wider: each lane owns fewer source elements.
    int ratio = inElemTyBitWidth / outElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] / ratio})
                      : LaneData({currData[0] / ratio, 1});
  }
  propagateIfChanged(operands[0],
                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
}
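// Worked example (follows directly from the code above): for a bitcast from
// vector<16x32xi16> to vector<16x16xi32> whose result layout is
// lane_layout = [1, 16], lane_data = [1, 1], we have 16 < 32, ratio = 2, and
// the source operand receives lane_data = [1, 2] with the same lane layout.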
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // The layout of the loaded value is needed to propagate to the descriptor.
  if (!valueLayout.isAssigned())
    return;

  LayoutInfo tensorDescLayout = valueLayout;
  if (load.getTranspose()) {
    // LoadGatherOp may carry a transpose effect. It is not expected at this
    // stage, so emit a warning, but still handle it.
    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // The mask operand gets the default 1D layout.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  // Propagate the layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
  // Propagate the layout to the mask operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // The layout of the descriptor is needed to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // The offsets operand gets the default 1D layout.
  LayoutInfo layout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // For a 2D StoreScatterOp, the first dimension of the tensor descriptor is
  // expected to match the subgroup size.
  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
  if (tdescShape.size() > 1)
    assert(tdescShape[0] == subgroupSize &&
           "Expected the first dimension of 2D tensor descriptor to be equal to "
           "subgroup size.");

  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
  LayoutInfo storeScatterLayout = valueLayout;
  if (storeScatter.getTranspose()) {
    // StoreScatterOp may carry a transpose effect. It is not expected at this
    // stage, so emit a warning, but still handle it.
    storeScatter.emitWarning("Transpose effect is not expected for "
                             "StoreScatterOp at LayoutInfoPropagation stage.");
    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // Propagate the value layout.
  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
  // Propagate the tensor descriptor layout.
  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
  // The mask operand gets the default 1D layout.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
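// The driver below wraps the analysis: it runs LayoutInfoPropagation over an
// operation (typically the module) through the generic data-flow solver and
// exposes per-value queries plus a debug-printing helper used by the pass
// later in this file.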
class RunLayoutInfoPropagation {
public:
  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};
LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the layout of each function argument.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      layout.print(os);
      os << "\n";
    }
    // Print the layout of each op result inside the function body.
    funcOp.walk([&](Operation *op) {
      // Skip control-flow ops.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
      funcOps.push_back(funcOp);
    }
    // Collect functions nested inside gpu.module ops as well.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
        funcOps.push_back(gpuFuncOp);
      }
    }
  }
  for (FunctionOpInterface funcOp : funcOps) {
    printFunctionResult(funcOp);
  }
}
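// The class below consumes the analysis result: it walks the IR and attaches
// xegpu::LayoutAttr attributes to values, either by folding the layout into
// xegpu::TensorDescType results or by tagging ops with temporary named
// attributes (operandLayoutNamePrefix / resultLayoutNamePrefix) that the
// distribution patterns read back and later discard.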
class LayoutAttrAssignment {
public:
  LayoutAttrAssignment(Operation *top,
                       function_ref<LayoutInfo(Value)> getLayout)
      : getAnalysisResult(getLayout), top(top) {}

  LogicalResult run();

private:
  LogicalResult assign(Operation *op);
  void assignToUsers(Value v, xegpu::LayoutAttr layout);
  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
  LogicalResult resolveConflicts();

  function_ref<LayoutInfo(Value)> getAnalysisResult;
  Operation *top;
};
void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
  for (OpOperand &user : v.getUses()) {
    Operation *owner = user.getOwner();
    unsigned operandNumber = user.getOperandNumber();
    // Temporary attribute name built from operandLayoutNamePrefix and the
    // operand number; it is removed again during distribution.
    std::string attrName =
        operandLayoutNamePrefix + std::to_string(operandNumber);
    owner->setAttr(attrName, layout);
  }
}
xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
  LayoutInfo layout = getAnalysisResult(v);
  if (!layout.isAssigned())
    return {};
  SmallVector<int, 2> laneLayout, laneData;
  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                             layout.getDataAsArrayRef())) {
    laneLayout.push_back(static_cast<int>(layout));
    laneData.push_back(static_cast<int>(data));
  }
  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
}
LogicalResult LayoutAttrAssignment::assign(Operation *op) {
  // For function ops, forward the layout of each argument to its users.
  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
    for (BlockArgument arg : func.getArguments()) {
      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
      if (layoutInfo)
        assignToUsers(arg, layoutInfo);
    }
    return success();
  }
  // Ops producing only scalar results need no layout.
  if (llvm::all_of(op->getResultTypes(),
                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
    return success();
  // Ops with multiple results where at least one is a tensor descriptor are
  // not handled.
  if (op->getNumResults() > 1 &&
      llvm::any_of(op->getResultTypes(), [](Type t) {
        return isa<xegpu::TensorDescType>(t);
      })) {
    LLVM_DEBUG(
        DBGS() << op->getName()
               << " op has more than one result and at least one is a tensor "
                  "descriptor. This case is not handled.\n");
    return failure();
  }
  // For a single tensor-descriptor result, fold the layout into the type.
  if (auto tensorDescTy =
          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
    if (!layoutInfo) {
      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
      return failure();
    }
    auto typeWithLayout = xegpu::TensorDescType::get(
        tensorDescTy.getContext(), tensorDescTy.getShape(),
        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
    op->getResult(0).setType(typeWithLayout);
    return success();
  }
  // Otherwise, attach the layout of each result as a temporary attribute and
  // forward it to the users.
  for (auto [i, r] : llvm::enumerate(op->getResults())) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
    if (!layoutInfo)
      continue;
    std::string attrName = resultLayoutNamePrefix + std::to_string(i);
    op->setAttr(attrName, layoutInfo);
    assignToUsers(r, layoutInfo);
  }
  return success();
}
LogicalResult LayoutAttrAssignment::run() {
  auto walkResult = top->walk([&](Operation *op) {
    if (failed(assign(op)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return failure();

  return resolveConflicts();
}

// TODO: Layout conflict resolution is not yet implemented.
LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();

  auto laneLayout = layout.getLaneLayout().asArrayRef();
  assert(originalType.getShape().size() >= laneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  SmallVector<int64_t> distributedShape(originalType.getShape());
  // Only the trailing `laneLayout.size()` dimensions are distributed; any
  // leading dimensions are kept as-is.
  unsigned distributionStart = originalType.getRank() - laneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart) {
      continue;
    }
    // The dimension must be divisible by the number of lanes assigned to it.
    if (dim % laneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / laneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
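// Worked example (follows from the division above): for originalType =
// vector<16x16xf16> and lane_layout = [1, 16], the distributed type is
// vector<16x1xf16>; a lane_layout of [2, 8] on the same type would give
// vector<8x2xf16>.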
// Values crossing the warp-op boundary may carry a type computed from the lane
// layout that differs from the type the distributed XeGPU op expects; this
// helper reconciles the two by inserting a cast.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If the types already match, nothing to do.
  if (orig.getType() == expected)
    return orig;
  // Vector types are reconciled with a vector.shape_cast.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // Tensor descriptor types are reconciled with an unrealized conversion cast.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                              expected, orig);
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
static SmallVector<NamedAttribute>
removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> newAttrs;
  for (NamedAttribute attr : attrs) {
    // Skip the temporary layout attributes added during layout assignment
    // (the exact filtering condition is elided in this fragment).
    // ...
    newAttrs.push_back(attr);
  }
  return newAttrs;
}
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
  if (layout == xegpu::LayoutAttr())
    return false;
  DenseI32ArrayAttr laneData = layout.getLaneData();
  if (!laneData || laneData.size() != 2)
    return false;
  return laneData.asArrayRef()[0] != 1;
}
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip it.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the body was already moved inside a warp op, skip it.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature.
    auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as the
    // original function.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = rewriter.create<gpu::LaneIdOp>(
        newGpuFunc.getLoc(), rewriter.getIndexType(),
        /*upperBound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original function with a YieldOp inside the
    // warp op body.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    rewriter.create<gpu::YieldOp>(origReturnOp.getLoc(),
                                  origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Return the results of the warp op from the new function.
    rewriter.setInsertionPointAfter(warpOp);
    rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
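// Rough effect of the pattern above (illustrative IR only; names, types and
// the subgroup size are placeholders, not taken from a real test):
//
//   // Before:
//   gpu.func @foo(%arg0: memref<8x16xf16>) {
//     ... ops using %arg0 ...
//     gpu.return
//   }
//
//   // After:
//   gpu.func @foo(%arg0: memref<8x16xf16>) {
//     %laneid = gpu.lane_id
//     gpu.warp_execute_on_lane_0(%laneid)[16] args(%arg0 : memref<8x16xf16>) {
//     ^bb0(%arg1: memref<8x16xf16>):
//       ... original ops, now using %arg1 ...
//     }
//     gpu.return
//   }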
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldValues;
    SmallVector<Type> newYieldTypes;
    for (Value operand : descOp->getOperands()) {
      newYieldValues.push_back(operand);
      newYieldTypes.push_back(operand.getType());
    }
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

    SmallVector<Value> newDescOperands;
    for (size_t i : newRetIndices) {
      newDescOperands.push_back(newWarpOp.getResult(i));
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    // The distributed tensor descriptor type carries no layout information.
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /*newYieldedValues=*/
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        /*newYieldedTypes=*/
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Create a new store op outside the warp op, using the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // The stored value must be resolved to the distributed type expected by
    // the store op itself.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor is distributed by dropping its layout.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    rewriter.create<xegpu::StoreNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
    rewriter.eraseOp(storeOp);
    return success();
  }
};
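// Sketch of the effect of this store pattern (illustrative only; shapes assume
// subgroupSize == 16 and lane_layout = [1, 16], lane_data = [1, 1], and the
// layout attribute on the descriptor type is omitted for brevity):
//
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16] ->
//              (vector<8x1xf16>, !xegpu.tensor_desc<8x16xf16>) {
//     ...
//     gpu.yield %val, %tdesc : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
//   }
//   // resolveDistributedTy may insert a shape cast / conversion cast here
//   // before the layout-free store is recreated outside the region:
//   xegpu.store_nd %0, %1 ...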
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /*newYieldedValues=*/loadOp.getTensorDesc(),
        /*newYieldedTypes=*/tensorDescTy, newRetIndices);

    // Create a new load op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
        newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter),
        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
    // Set the packed attribute if the layout asks for packed (VNNI) data.
    newLoadOp.setPacked(hasPackedLayout(layout));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the loaded value to the type expected by the warp op users and
    // replace all uses.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(subgroupOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // Temporary layout attribute names for the A and B operands and the
    // result, as written by LayoutAttrAssignment.
    std::string layoutAName =
        operandLayoutNamePrefix + std::to_string(0);
    std::string layoutBName =
        operandLayoutNamePrefix + std::to_string(1);
    std::string layoutCName = resultLayoutNamePrefix + std::to_string(0);

    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");
    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The accumulator operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");

    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    // Resolve the types coming out of the warp op against the types expected
    // by the new dpas op.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the result type and replace all uses.
    newDpasOp = resolveDistributedTy(
        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
    return success();
  }
};
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // The distributed tensor descriptor type carries no layout information.
    xegpu::TensorDescType newTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();

    SmallVector<Value, 3> newYieldValues;
    SmallVector<Type, 3> newYieldTypes;
    for (Value operand : updateOp->getOperands()) {
      newYieldValues.push_back(operand);
      if (isa<xegpu::TensorDescType>(operand.getType())) {
        newYieldTypes.push_back(newTensorDescTy);
      } else {
        newYieldTypes.push_back(operand.getType());
      }
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newUpdateOperands;
    for (size_t i : newRetIndices) {
      // The tensor descriptor operand loses its layout after distribution, so
      // its type has to be resolved; other operands are used as-is.
      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
        newUpdateOperands.push_back(resolveDistributedTy(
            newWarpOp.getResult(i), newTensorDescTy, rewriter));
      } else {
        newUpdateOperands.push_back(newWarpOp.getResult(i));
      }
    }
    // Create a new update op outside the warp op.
    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
    return success();
  }
};
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    // The tensor descriptor is distributed by dropping its layout.
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    rewriter.create<xegpu::PrefetchNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               UpdateNdOffsetDistribution>(patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // When only the analysis result is requested, print it and return early.
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  auto getPropagatedLayout = [&](Value val) {
    return analysis.getLayoutInfo(val);
  };

  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
  // propagation analysis result.
  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
  if (failed(layoutAssignment.run())) {
    signalPassFailure();
    return;
  }

  // Move each gpu function body inside a gpu.warp_execute_on_lane_0 region.
  {
    RewritePatternSet patterns(&getContext());
    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }
  // Hoist subgroup-uniform scalar code out of the warp regions before
  // distribution.
  getOperation()->walk([&](Operation *op) {
    if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
      vector::moveScalarUniformCode(warpOp);
    }
  });

  // Finally, distribute the ops inside the warp regions to individual lanes.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // Distribution map used by the upstream vector distribution patterns:
  // identity over the vector rank.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
  };
  // Lane-to-lane shuffles are not supported here.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
}
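// Overall flow of the pass, as wired up above: (1) run the backward
// LayoutInfoPropagation analysis, (2) materialize its result as layout
// attributes via LayoutAttrAssignment, (3) wrap each gpu.func body in a
// gpu.warp_execute_on_lane_0 region, and (4) greedily apply the XeGPU and
// upstream vector distribution patterns so that each lane ends up with its own
// slice of every vector value and tensor descriptor.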