#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

/// HW dependent constants.
/// TODO: These constants should be queried from the uArch information.
constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
/// If DPAS A or B operands have low precision element types they must be
/// packed according to the following sizes.
constexpr unsigned packedSizeInBitsForDefault = 16; // Packing for DPAS A.
constexpr unsigned packedSizeInBitsForDpasB = 32;   // Packing for DPAS B.
namespace {

/// Helper struct: a thin wrapper over a vector of integers. It is used to
/// represent both the lane layout and the lane data of a LayoutInfo.
struct Layout {
  SmallVector<int64_t, 3> layout;
  Layout() = default;
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  void print(llvm::raw_ostream &os) const;
  size_t size() const { return layout.size(); }
  int64_t operator[](size_t idx) const;
};

void Layout::print(llvm::raw_ostream &os) const {
  os << llvm::interleaved_array(layout);
}

int64_t Layout::operator[](size_t idx) const {
  assert(idx < layout.size() && "Index out of bounds.");
  return layout[idx];
}

using LaneLayout = Layout;
using LaneData = Layout;
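// Illustrative note: a LaneLayout of [1, 16] together with a LaneData of
// [1, 1] means the 16 lanes of the subgroup are laid out along the second
// (innermost) dimension and each lane owns a single element per distributed
// chunk. For example, Layout({1, 16}).print(llvm::dbgs()) prints "[1, 16]".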
/// LayoutInfo is the lattice value tracked for each SSA value. It is either
/// unassigned (empty) or holds a lane layout and lane data vector.
struct LayoutInfo {
private:
  LaneLayout laneLayout;
  LaneData laneData;

public:
  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  // Two lattice values are equal if they are both assigned or both
  // unassigned; the concrete layouts are not compared.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << "lane_layout: ";
    laneLayout.print(os);
    os << ", lane_data: ";
    laneData.print(os);
  } else {
    os << "Not assigned.";
  }
}
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, join is never expected to be invoked.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Construct a new layout with the lane layout and lane data transposed
/// according to the given permutation.
LayoutInfo LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
  LaneLayout newLayout;
  LaneData newData;
  for (int64_t idx : permutation) {
    newLayout.layout.push_back(laneLayout.layout[idx]);
    newData.layout.push_back(laneData.layout[idx]);
  }
  return LayoutInfo(newLayout, newData);
}
/// Lattice holding the LayoutInfo for each SSA value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
/// Returns the default layout for a vector of the given rank: all lanes of
/// the subgroup are laid out along the innermost dimension and each lane
/// owns one element (times a packing factor for small element types).
static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
}

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultLayoutInfo(1);
  // Packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (bitwidth < packedSizeInBitsForDefault)
    packingFactor = packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(LaneLayout({1, subgroupSize}),
                    LaneData({1, packingFactor}));
}
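// Illustrative examples (assuming subgroupSize == 16 and
// packedSizeInBitsForDefault == 16, as defined above):
//   vector<16x16xf32> -> lane_layout: [1, 16], lane_data: [1, 1]
//   vector<16x32xi8>  -> lane_layout: [1, 16], lane_data: [1, 2]
//   vector<16xf32>    -> lane_layout: [16],    lane_data: [1]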
/// Helper to get the layout for DPAS operand `operandNum` (0 = A, 1 = B,
/// 2 = accumulator). The B operand requires packed (VNNI) lane data.
static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
                                              unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  LaneLayout layout({1, subgroupSize});
  // For the B operand, lane data must be packed with a granularity of
  // `packedSizeInBitsForDpasB`.
  if (operandNum == 1 &&
      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
    LaneData data(
        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
    return LayoutInfo(layout, data);
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultLayoutInfo(vectorTy);
}
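// Illustrative example (assuming subgroupSize == 16 and
// packedSizeInBitsForDpasB == 32, as defined above): for a DPAS B operand of
// type vector<16x16xbf16>, operandNum == 1 and the element bitwidth is 16, so
// the resulting layout is lane_layout: [1, 16], lane_data: [2, 1] (two 16-bit
// values packed per 32-bit register).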
/// Backward dataflow analysis that propagates LayoutInfo from anchor
/// operations (DPAS, stores, prefetches) towards their producers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);
  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);
  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);
  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      // No layout propagation is needed for CreateNdDescOp operands; they are
      // scalars (offsets, sizes, etc.).
      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      // All other ops: propagate the layout of each result to all operands.
      .Default([&](Operation *) {
        for (const LayoutInfoLattice *r : results) {
          for (LayoutInfoLattice *operand : operands) {
            // Propagate only if the result layout is assigned.
            if (r->getValue().isAssigned())
              meet(operand, *r);
          }
        }
      });
  // Add a dependency from each result to the program point after the op so
  // the analysis is re-triggered when a result lattice changes.
  for (const LayoutInfoLattice *r : results) {
    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
  }
  return success();
}
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Propagate the default layout to the tensor descriptor operand, based on
  // the descriptor shape and element type.
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultLayoutInfo(
      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only 2D -> 1D reductions are expected at this point.
  assert(resultLayout.getLayout().size() == 1 &&
         "Expected 1D layout for reduction result.");
  // The source operand is 2D, so assign it the default 2D layout.
  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
/// Propagate the layout of the result tensor descriptor to the source tensor
/// descriptor operand of UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layouts for the DPAS A, B, and accumulator operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(operands[0],
                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(operands[1],
                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(operands[2],
                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
  }
}
/// Set the same layout for both the value and the tensor descriptor operands
/// of StoreNdOp, derived from the stored value type.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands) {
    propagateIfChanged(operand, operand->meet(storeLayout));
  }
}
/// Propagate the layout of the loaded value to the tensor descriptor operand
/// of LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has a transpose effect, but transpose effects are not expected
  // at this stage of the propagation. Emit a warning if one is seen.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
/// For TransposeOp, the result layout is transposed according to the
/// permutation and propagated to the source operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout =
      resultLayout.getTransposedLayout(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
/// For BitCastOp, the lane data of the source layout is scaled by the ratio
/// between the source and result element bitwidths; the lane layout itself is
/// unchanged.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();

  // The lane layout does not change.
  const LaneLayout &newLaneLayout = resultLayout.getLayout();
  const LaneData &currData = resultLayout.getData();
  LaneData newLaneData;
  // It's a widening bitcast.
  if (inElemTyBitWidth < outElemTyBitWidth) {
    int ratio = outElemTyBitWidth / inElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] * ratio})
                      : LaneData({currData[0] * ratio, 1});
  } else {
    // It's a narrowing bitcast.
    int ratio = inElemTyBitWidth / outElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] / ratio})
                      : LaneData({currData[0] / ratio, 1});
  }

  propagateIfChanged(operands[0],
                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
}
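// Illustrative example: for
//   %r = vector.bitcast %src : vector<16x16xi16> to vector<16x8xi32>
// with the result layout lane_layout: [1, 16], lane_data: [1, 1], the source
// operand %src is assigned lane_layout: [1, 16], lane_data: [1, 2]
// (ratio = 32 / 16 = 2).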
/// Propagate the layout of the result to the tensor descriptor and mask
/// operands of LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;

  LayoutInfo tensorDescLayout = valueLayout;
  if (load.getTranspose()) {
    // LoadGatherOp has a transpose effect, but transpose effects are not
    // expected at this stage of the propagation. Emit a warning.
    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // The mask operand gets the default 1D layout.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
  // Propagate the new layout to the mask operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}
/// Propagate the layout to the offsets operand of CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // The offsets operand gets the default 1D layout.
  LayoutInfo layout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
/// Set the layouts for the value, tensor descriptor, and mask operands of
/// StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // A 2D tensor descriptor is only valid if its first dimension matches the
  // subgroup size (i.e. one row per lane).
  auto tdescShape = storeScatter.getTensorDescType().getShape();
  if (tdescShape.size() > 1)
    assert(tdescShape[0] == subgroupSize &&
           "Expected the first dimension of 2D tensor descriptor to be equal to "
           "subgroup size.");

  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
  LayoutInfo storeScatterLayout = valueLayout;
  if (storeScatter.getTranspose()) {
    // StoreScatterOp has a transpose effect, but transpose effects are not
    // expected at this stage of the propagation. Emit a warning.
    storeScatter.emitWarning("Transpose effect is not expected for "
                             "StoreScatterOp at LayoutInfoPropagation stage.");
    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // Propagate the value layout.
  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
  // Propagate the tensor descriptor layout.
  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
  // The mask operand gets the default 1D layout.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
/// Helper class that runs the LayoutInfoPropagation analysis (together with
/// the supporting dead-code and constant-propagation analyses) on an
/// operation and exposes the resulting layouts.
class RunLayoutInfoPropagation {
public:
  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the layout of each block argument.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Print the layout of each op result.
    funcOp.walk([&](Operation *op) {
      // Skip ops without results and control-flow ops.
      if (op->getResults().empty())
        return;
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      os << "op    : " << *op << "\n";
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
      funcOps.push_back(funcOp);
    }
    // Collect functions nested inside GPU modules as well.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
        funcOps.push_back(gpuFuncOp);
      }
    }
  }
  for (FunctionOpInterface funcOp : funcOps) {
    printFunctionResult(funcOp);
  }
}
/// Assigns the layouts computed by the analysis as temporary xegpu::LayoutAttr
/// attributes on values and their users, so the distribution patterns can
/// consume them later.
class LayoutAttrAssignment {
public:
  LayoutAttrAssignment(Operation *top,
                       function_ref<LayoutInfo(Value)> getLayout)
      : getAnalysisResult(getLayout), top(top) {}

  LogicalResult run();

private:
  LogicalResult assign(Operation *op);
  void assignToUsers(Value v, xegpu::LayoutAttr layout);
  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
  LogicalResult resolveConflicts();

  function_ref<LayoutInfo(Value)> getAnalysisResult;
  Operation *top;
};

/// Attach the layout attribute to every user of the value, keyed by the
/// operand it is used as.
void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
  for (OpOperand &user : v.getUses()) {
    Operation *owner = user.getOwner();
    std::string attrName = xegpu::getLayoutName(user);
    owner->setAttr(attrName, layout);
  }
}

/// Convert the analysis result for a value into an xegpu::LayoutAttr.
xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
  LayoutInfo layout = getAnalysisResult(v);
  if (!layout.isAssigned())
    return {};
  SmallVector<int, 2> laneLayout, laneData;
  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                             layout.getDataAsArrayRef())) {
    laneLayout.push_back(static_cast<int>(layout));
    laneData.push_back(static_cast<int>(data));
  }
  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
}
/// Assign layout attributes for an op (or the arguments of a function) and
/// propagate them to the users.
LogicalResult LayoutAttrAssignment::assign(Operation *op) {
  // For function ops, propagate the argument layouts to their users.
  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
    for (BlockArgument arg : func.getArguments()) {
      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
      if (layoutInfo)
        assignToUsers(arg, layoutInfo);
    }
    return success();
  }
  // If the op has no results, there is nothing to do.
  if (op->getNumResults() == 0)
    return success();
  // If all the results are scalars, there is nothing to do.
  if (llvm::all_of(op->getResultTypes(),
                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
    return success();
  // Ops with more than one result where at least one result is a tensor
  // descriptor are not handled.
  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
        return isa<xegpu::TensorDescType>(t);
      })) {
    LLVM_DEBUG(
        DBGS() << op->getName()
               << " op has more than one result and at least one is a tensor "
                  "descriptor. This case is not handled.\n");
    return failure();
  }
  // If the result is a tensor descriptor, embed the layout into the
  // descriptor type itself.
  if (auto tensorDescTy =
          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
    if (!layoutInfo) {
      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
      return failure();
    }
    // Clone the op, attach the layout to the result tensor descriptor, and
    // remove the original op.
    OpBuilder builder(op);
    Operation *newOp = builder.clone(*op);
    auto newTensorDescTy = xegpu::TensorDescType::get(
        tensorDescTy.getContext(), tensorDescTy.getShape(),
        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
    newOp->getResult(0).setType(newTensorDescTy);
    op->replaceAllUsesWith(newOp->getResults());
    op->erase();
    return success();
  }
  // Otherwise, attach the layout to the op itself and to the users of each
  // result.
  for (OpResult r : op->getOpResults()) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
    if (layoutInfo) {
      std::string attrName = xegpu::getLayoutName(r);
      op->setAttr(attrName, layoutInfo);
      assignToUsers(r, layoutInfo);
    }
  }
  return success();
}

/// Walk the IR and attempt to assign layout attributes to all ops.
LogicalResult LayoutAttrAssignment::run() {
  auto walkResult = top->walk([&](Operation *op) {
    if (failed(assign(op)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return failure();

  return resolveConflicts();
}

/// TODO: Layout conflict resolution is not implemented yet; it must ensure
/// that the assigned layouts are mutually consistent.
LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
/// Given a subgroup-level vector type and a lane layout, compute the vector
/// type owned by a single lane (the SIMT-distributed type). Only the trailing
/// dimensions covered by the lane layout are distributed.
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();

  auto laneLayout = layout.getLaneLayout().asArrayRef();
  assert(originalType.getShape().size() >= laneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  SmallVector<int64_t> distributedShape = llvm::to_vector(originalType.getShape());
  // Only the last `laneLayout.size()` dimensions are distributed; any leading
  // dimensions are kept as-is.
  unsigned distributionStart = originalType.getRank() - laneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart) {
      continue;
    }
    // The dimension must be divisible by the corresponding lane layout entry.
    if (dim % laneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / laneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
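// Illustrative examples:
//   originalType = vector<16x16xf32>, lane_layout = [1, 16]
//     -> vector<16x1xf32> (only the innermost dimension is divided).
//   originalType = vector<32x16xf32>, lane_layout = [16] (1D layout)
//     -> vector<32x1xf32> (the leading dimension is not distributed).
//   originalType = vector<16x15xf32>, lane_layout = [1, 16]
//     -> failure(), because 15 is not divisible by 16.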
/// Helper to reconcile a distributed value with the type its users expect.
/// Vector mismatches are bridged with vector.shape_cast; tensor descriptor
/// mismatches with builtin.unrealized_conversion_cast.
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If the types already match, return the original value.
  if (orig.getType() == expected)
    return orig;
  // If the types are vector types, use a shape cast to reconcile them.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If the types are tensor descriptor types, bridge them with an
  // unrealized conversion cast.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                              expected, orig);
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
/// Helper to drop the temporary xegpu::LayoutAttr attributes (attached during
/// layout assignment) from an op's attribute list when building the
/// distributed op.
static SmallVector<NamedAttribute>
removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> newAttrs;
  for (NamedAttribute attr : attrs) {
    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
      newAttrs.push_back(attr);
  }
  return newAttrs;
}
/// Helper to check if a layout requires a packed (VNNI) load, i.e. the lane
/// data is 2D and its first dimension is larger than 1.
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
  if (layout == xegpu::LayoutAttr())
    return false;
  DenseI32ArrayAttr laneData = layout.getLaneData();
  if (!laneData || laneData.size() != 2)
    return false;
  return laneData.asArrayRef()[0] != 1;
}
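// Example: a layout with lane_data [2, 1] (as produced for DPAS B operands of
// 16-bit element types) is packed, while lane_data [1, 1] or any 1D lane_data
// is not.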
/// Pattern that moves the body of a gpu.func into a
/// gpu.warp_execute_on_lane_0 region, so the upstream vector distribution
/// patterns can distribute the SIMD code to SIMT.
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip it.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function body already contains a warp_execute_on_lane_0, skip it.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature.
    auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
    // Create a WarpExecuteOnLane0Op with the same arguments and results as
    // the original gpu.func.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getBody().front());
    auto laneId = rewriter.create<gpu::LaneIdOp>(
        newGpuFunc.getLoc(), rewriter.getIndexType(),
        /*upperBound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the gpu.return of the original function with a gpu.yield
    // inside the warp op body.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    rewriter.create<gpu::YieldOp>(origReturnOp.getLoc(),
                                  origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new gpu.return that returns the warp op results.
    rewriter.setInsertionPointAfter(warpOp);
    rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
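// Schematic example of the rewrite (abridged IR, assuming subgroupSize == 16):
//   gpu.func @foo(%arg0: memref<...>) {
//     ...                      // original body
//     gpu.return
//   }
// becomes
//   gpu.func @foo(%arg0: memref<...>) {
//     %laneid = gpu.lane_id
//     gpu.warp_execute_on_lane_0(%laneid)[16] args(%arg0 : memref<...>) {
//       ...                    // original body
//     }
//     gpu.return
//   }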
/// Distribute a create_nd_tdesc feeding into the yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    // Yield the descriptor operands out of the warp op.
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldValues;
    SmallVector<Type> newYieldTypes;
    for (Value operand : descOp->getOperands()) {
      newYieldValues.push_back(operand);
      newYieldTypes.push_back(operand.getType());
    }
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

    SmallVector<Value> newDescOperands;
    for (size_t i : newRetIndices) {
      newDescOperands.push_back(newWarpOp.getResult(i));
    }
    // Recreate the descriptor outside the warp op with the layout dropped.
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not carry layout info.
    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
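// Schematic example (abridged): a create_nd_tdesc producing a warp-op result
//   %r = gpu.warp_execute_on_lane_0(%laneid)[16]
//            -> (!xegpu.tensor_desc<16x16xf32, #layout0>) {
//     %td = xegpu.create_nd_tdesc %src[0, 0]
//             : memref<16x16xf32> -> !xegpu.tensor_desc<16x16xf32, #layout0>
//     gpu.yield %td
//   }
// is rewritten so the descriptor operands are yielded out of the warp op and
// the create_nd_tdesc is recreated after it with the layout dropped:
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
//            -> (..., memref<16x16xf32>) {
//     ...
//     gpu.yield %td, %src
//   }
//   %td_new = xegpu.create_nd_tdesc %r#1[0, 0]
//               : memref<16x16xf32> -> !xegpu.tensor_desc<16x16xf32>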
/// Distribute a store_nd op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region. The stored value and the tensor
/// descriptor are yielded out of the warp op, and the store is recreated
/// outside of it with distributed types.
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Create a new store op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // Resolve the value operand to the distributed type expected by the
    // store (which may differ from the type yielded by the warp op).
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor type drops the layout after distribution.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    rewriter.create<xegpu::StoreNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
    rewriter.eraseOp(storeOp);
    return success();
  }
};
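// Schematic example (abridged; distributed shapes are illustrative, assuming
// lane_layout [1, 16] and subgroupSize == 16):
//   gpu.warp_execute_on_lane_0(%laneid)[16] {
//     ...
//     xegpu.store_nd %val, %td
//       : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #layout0>
//   }
// becomes
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
//            -> (vector<16x1xf32>, !xegpu.tensor_desc<16x16xf32, #layout0>) {
//     ...
//     gpu.yield %val, %td
//   }
//   xegpu.store_nd %r#0, %r#1
//     : vector<16x1xf32>, !xegpu.tensor_desc<16x16xf32>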
/// Distribute a load_nd op feeding into the yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /* new yielded values = */ loadOp.getTensorDesc(),
        /* new yielded types = */ tensorDescTy, newRetIndices);

    // Create a new load op outside the warp op with the distributed types.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // carry layout info.
    auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
        newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter),
        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(hasPackedLayout(layout));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the warp
    // op and the distributed type of the load; resolve it with a shape cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
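// Schematic example (abridged; distributed shapes are illustrative):
//   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf32>) {
//     %ld = xegpu.load_nd %td
//       : !xegpu.tensor_desc<16x16xf32, #layout0> -> vector<16x16xf32>
//     gpu.yield %ld
//   }
// becomes
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
//            -> (..., !xegpu.tensor_desc<16x16xf32, #layout0>) {
//     ...
//     gpu.yield %ld, %td
//   }
//   %ld_new = xegpu.load_nd %r#1
//     : !xegpu.tensor_desc<16x16xf32> -> vector<16x1xf32>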
/// Distribute a dpas op feeding into the yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(subgroupOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));

    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The dpas accumulator operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op that also yields the dpas operands.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;

    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    newDpasOp = resolveDistributedTy(
        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
    return success();
  }
};
/// Distribute an update_nd_offset op feeding into the yield op of the
/// enclosing `gpu.warp_execute_on_lane_0` region.
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // The new update op does not carry a layout attribute.
    xegpu::TensorDescType newTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();

    SmallVector<Value, 3> newYieldValues;
    SmallVector<Type, 3> newYieldTypes;
    for (Value operand : updateOp->getOperands()) {
      newYieldValues.push_back(operand);
      if (isa<xegpu::TensorDescType>(operand.getType())) {
        newYieldTypes.push_back(newTensorDescTy);
      } else {
        newYieldTypes.push_back(operand.getType());
      }
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newUpdateOperands;
    for (size_t i : newRetIndices) {
      // The tensor descriptor operand drops its layout after distribution, so
      // its type needs to be resolved.
      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
        newUpdateOperands.push_back(resolveDistributedTy(
            newWarpOp.getResult(i), newTensorDescTy, rewriter));
      } else {
        newUpdateOperands.push_back(newWarpOp.getResult(i));
      }
    }
    // Create a new update op outside the warp op.
    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
    return success();
  }
};
/// Distribute a prefetch_nd op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0` region. The tensor descriptor is yielded out
/// of the warp op, and the prefetch is recreated outside with the layout
/// dropped.
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op with the distributed
    // tensor descriptor type (i.e. without the layout).
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    rewriter.create<xegpu::PrefetchNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Pass driver: runs the layout propagation analysis, assigns layout
/// attributes, moves function bodies into warp_execute_on_lane_0 regions, and
/// finally applies the distribution patterns.
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};
} // namespace

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               UpdateNdOffsetDistribution>(patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // If requested (print-analysis-only), print the analysis result and exit.
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  auto getPropagatedLayout = [&](Value val) {
    return analysis.getLayoutInfo(val);
  };

  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
  // propagation analysis result.
  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
  if (failed(layoutAssignment.run())) {
    signalPassFailure();
    return;
  }

  // Move all operations of a GPU function inside a
  // gpu.warp_execute_on_lane_0 operation.
  {
    RewritePatternSet patterns(&getContext());
    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Move any scalar uniform code outside of the warp regions before
  // distribution.
  getOperation()->walk([&](Operation *op) {
    if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
      vector::moveScalarUniformCode(warpOp);
    }
  });

  // Finally, do the SIMD to SIMT distribution.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // TODO: distributionFn and shuffleFn are not used at this point.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
  };
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                      Value srcIdx, int64_t warpSz) { return Value(); };
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
}