#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir
#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51 "resolve_simt_type_mismatch";
/// Compute the per-lane (SIMT) vector type for `originalType` given the lane
/// layout: only the trailing dimensions are divided across lanes.
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();
  auto laneLayout = layout.getLaneLayout().asArrayRef();
  assert(originalType.getShape().size() >= laneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  auto distributedShape = llvm::to_vector(originalType.getShape());
  // Only the trailing `laneLayout.size()` dimensions are distributed.
  unsigned distributionStart = originalType.getRank() - laneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // Each distributed dimension must divide evenly by the lane layout.
    if (dim % laneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / laneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}
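// Worked example (illustrative shapes, not taken from the original listing):
// with lane_layout = [1, 16], a vector<16x16xf16> value distributes to
// vector<16x1xf16> per lane (16/1 = 16 and 16/16 = 1), while a
// vector<16x9xf16> value fails to distribute because 9 % 16 != 0.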
/// Reconcile the type produced by distribution with the type the SIMT-level
/// op expects (e.g. vector<8x1xf32> vs. vector<8xf32>, or a tensor descriptor
/// with and without a layout attribute).
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If the types already match, nothing to do.
  if (orig.getType() == expected)
    return orig;
  // Vector types are reconciled with a vector.shape_cast.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // Tensor descriptor types are reconciled with an unrealized conversion cast
  // that is tagged for later resolution.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                     expected, orig);
    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}
/// A layout is treated as packed when its 2-D lane_data asks each lane to own
/// more than one consecutive element along the first dimension.
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
  if (layout == xegpu::LayoutAttr())
    return false;
  DenseI32ArrayAttr laneData = layout.getLaneData();
  if (!laneData || laneData.size() != 2)
    return false;
  return laneData.asArrayRef()[0] != 1;
}
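// Illustrative example (assumed values): lane_data = [2, 1] is treated as
// packed, i.e. each lane owns two consecutive elements along the first
// dimension, whereas lane_data = [1, 1] is not packed.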
/// Move the body of a gpu.func into a gpu.warp_execute_on_lane_0 region so
/// that the warp-distribution patterns can operate on it.
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // Skip functions that only contain a single void return.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // Skip functions whose body is already wrapped in a warp op.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature and attributes.
    SmallVector<Type> workgroupAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    SmallVector<Type> privateAttributionsTypes =
        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
                            [](BlockArgument arg) { return arg.getType(); });
    auto newGpuFunc = gpu::GPUFuncOp::create(
        rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
        gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
        privateAttributionsTypes);
    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
    // Create a warp op with the same arguments and results as the function.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = gpu::LaneIdOp::create(
        rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
        /*upperBound=*/mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
        rewriter, laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the terminator of the original function with a gpu.yield.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
                         origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body into the warp op region.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Return the warp op results from the new function.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
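// Illustrative before/after sketch (function name, types and the warp width
// of 16 are assumptions for the example, not taken from the original
// listing):
//
//   gpu.func @kernel(%arg0: memref<8x16xf32>) {
//     ...                                   // original body
//     gpu.return
//   }
//
// becomes
//
//   gpu.func @kernel(%arg0: memref<8x16xf32>) {
//     %laneid = gpu.lane_id
//     gpu.warp_execute_on_lane_0(%laneid)[16] args(%arg0 : memref<8x16xf32>) {
//       ...                                 // original body
//     }
//     gpu.return
//   }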
/// Distribute a create_nd_tdesc that feeds the terminator of the enclosing
/// warp op: the descriptor is recreated outside the warp region without its
/// layout attribute.
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, /*newYieldedValues=*/descOp->getOperands(),
        /*newYieldedTypes=*/descOp.getOperandTypes(), newRetIndices);

    SmallVector<Value> newDescOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
    rewriter.setInsertionPointAfter(newWarpOp);
    // The distributed tensor descriptor type drops the layout attribute.
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts();
    Value newDescOp = xegpu::CreateNdDescOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the distributed type with the type the warp op yields.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
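// Illustrative sketch (hypothetical types; #layout stands for a layout
// attribute such as #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>):
// a descriptor created inside the warp region as
//   %td = xegpu.create_nd_tdesc %src : memref<8x16xf32>
//           -> !xegpu.tensor_desc<8x16xf32, #layout>
// is recreated after the warp op from the yielded source operands, with the
// layout attribute dropped:
//   %td = xegpu.create_nd_tdesc %src : memref<8x16xf32>
//           -> !xegpu.tensor_desc<8x16xf32>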
/// Distribute a store_nd that ends the warp-op body. The stored value is
/// distributed to its SIMT type and the descriptor loses its layout.
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();
    // Stores with explicit offsets are not handled by this pattern.
    int64_t offsetSize = static_cast<int64_t>(storeOp.getOffsets().size());
    if ((offsetSize != 0) || storeOp.getConstOffsetsAttr())
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp,
        /*newYieldedValues=*/
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        /*newYieldedTypes=*/
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Recreate the store outside the warp op with distributed operands.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // The vector type distributed by the warp op may differ from the SIMT
    // type the store expects; reconcile with a cast if needed.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // The tensor descriptor loses its layout attribute after distribution.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                             newStoreOperands, storeOp->getAttrs());
    rewriter.eraseOp(storeOp);
    return success();
  }
};
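// Illustrative sketch (hypothetical types, subgroup size 16): a
//   xegpu.store_nd %val, %td : vector<8x16xf32>,
//                              !xegpu.tensor_desc<8x16xf32, #layout>
// ending the warp region is recreated after the warp op. The warp op yields
// the value in its warp-distributed form (vector<8x1xf32>), which is
// shape-cast to the SIMT form expected by store_nd (vector<8xf32>), and the
// descriptor is recreated without the layout attribute.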
/// Distribute a load_nd whose result is yielded from the warp op and that is
/// the last operation in the warp-op body (so it cannot be hoisted above a
/// barrier).
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadNdOp>(op))
        return false;
      // Only match the load if it is the last op before the terminator.
      gpu::YieldOp yield = warpOp.getTerminator();
      return yield->getPrevNode() == op;
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    // Loads with explicit offsets are not handled by this pattern.
    int64_t offsetSize = static_cast<int64_t>(loadOp.getOffsets().size());
    if ((offsetSize != 0) || loadOp.getConstOffsetsAttr())
      return failure();

    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, /*newYieldedValues=*/loadOp.getTensorDesc(),
        /*newYieldedTypes=*/tensorDescTy, newRetIndices);

    // Recreate the load outside the warp op with the SIMT result type.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    // The distributed tensor descriptor type drops the layout attribute.
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts();
    auto newLoadOp = xegpu::LoadNdOp::create(
        rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter),
        loadOp->getAttrs());
    // Honor packed layouts by setting the packed attribute on the new load.
    newLoadOp.setPacked(hasPackedLayout(layout));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the warp-distributed type with the SIMT result type.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
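// Illustrative sketch (hypothetical types): a
//   %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf32, #layout>
//          -> vector<8x16xf32>
// yielded from the warp region turns into a per-lane load of vector<8xf32>
// from the layout-free descriptor, shape-cast back to the warp-distributed
// vector<8x1xf32> so that existing uses continue to type-check.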
/// Distribute a dpas op whose result is yielded from the warp op. The A, B
/// (and optional accumulator) operands are distributed according to their
/// layout attributes and the op is recreated outside the warp region.
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(warpOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));

    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // The accumulator operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");

    // Recreate the dpas outside the warp op, reconciling each operand type.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
                                           distributedResultTy, newDpasOperands,
                                           dpasOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the warp-distributed result type with the SIMT result type.
    Value typeResolved =
        resolveDistributedTy(newDpasOp.getResult(),
                             distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
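// Illustrative sketch (hypothetical shapes, subgroup size 16): a
//   %c = xegpu.dpas %a, %b : vector<8x16xf16>, vector<16x16xf16>
//          -> vector<8x16xf32>
// computed inside the warp region becomes a per-lane
//   %c = xegpu.dpas %a, %b : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// after the warp op, with casts reconciling the warp-distributed 2-D types
// (e.g. vector<8x1xf16>) against the 1-D SIMT operand types.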
/// Distribute an update_nd_offset whose result is yielded from the warp op.
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // The distributed tensor descriptor type drops the layout attribute.
    xegpu::TensorDescType distributedTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newUpdateOperands =
        llvm::map_to_vector(newRetIndices, [&](size_t i) -> Value {
          // Tensor descriptor operands need their type reconciled with the
          // layout-free distributed descriptor type.
          if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
            return resolveDistributedTy(newWarpOp.getResult(i),
                                        distributedTensorDescTy, rewriter);
          }
          return newWarpOp.getResult(i);
        });
    // Recreate the update op outside the warp op.
    auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
        rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
        newUpdateOperands, updateOp->getAttrs());
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Reconcile the new result type with the type the warp op yields.
    Value typeResolved = resolveDistributedTy(
        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
    return success();
  }
};
/// Distribute a prefetch_nd that ends the warp-op body.
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    // Prefetches with explicit offsets are not handled by this pattern.
    int64_t offsetSize = static_cast<int64_t>(prefetchOp.getOffsets().size());
    if ((offsetSize != 0) || prefetchOp.getConstOffsetsAttr())
      return failure();

    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
    // Recreate the prefetch outside the warp op with the layout-free
    // descriptor type.
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                newPrefetchOperands, prefetchOp->getAttrs());
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};
/// Sink a gpu.barrier that ends the warp-op body out of the warp region.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    // The last op before the terminator must be a gpu.barrier.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
    if (!barrierOp)
      return failure();
    // Recreate the barrier after the warp op and erase the original.
    rewriter.setInsertionPointAfter(warpOp);
    gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
                           barrierOp->getResultTypes(),
                           barrierOp->getOperands(), barrierOp->getAttrs());
    rewriter.eraseOp(barrierOp);
    return success();
  }
};
/// The pass that drives XeGPU SIMT (subgroup-to-workitem) distribution.
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
      patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
  // Step 1: every vector operand must carry a layout attribute, otherwise
  // distribution cannot proceed.
  auto walkResult = getOperation()->walk([&](Operation *op) -> WalkResult {
    for (OpOperand &operand : op->getOpOperands()) {
      // Layouts are only required for vector operands.
      if (!isa<VectorType>(operand.get().getType()))
        continue;
      auto layout =
          xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
      if (!layout) {
        op->emitError("Could not find layout attribute for operand ")
            << operand.getOperandNumber() << " of operation " << op->getName();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
  // Step 2: move each gpu.func body into a gpu.warp_execute_on_lane_0 region
  // (MoveFuncBodyToWarpExecuteOnLane0 above), then hoist scalar uniform code
  // (index computations, constants, etc.) out of the warp region.
  // ...
  getOperation()->walk([&](Operation *op) {
    if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
      vector::moveScalarUniformCode(warpOp);
  });
  // Step 3: apply the XeGPU distribution patterns together with the generic
  // vector warp-distribution patterns.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // The distribution map tells the upstream patterns which dimensions of a
  // vector value are distributed across lanes: those whose lane_layout entry
  // is greater than 1.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
    // Without a layout, assume only the innermost dimension is distributed.
    if (!layout)
      return AffineMap::getMultiDimMapWithTargets(
          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
    SmallVector<unsigned int> distributedDims;
    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
    for (unsigned i = 0; i < laneLayout.size(); ++i) {
      if (laneLayout[i] > 1)
        distributedDims.push_back(i);
    }
    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
                                                val.getContext());
  };
  // Lane-to-lane shuffles are not needed by these patterns.
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                      Value srcIdx, int64_t warpSz) { return Value(); };
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
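  // Illustrative example (hypothetical value): for a vector<16x16xf32> whose
  // layout has lane_layout = [1, 16], only dimension 1 is distributed, so the
  // distribution map above evaluates to (d0, d1) -> (d1).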
  // Step 4: clean up the tagged unrealized_conversion_cast ops that were
  // inserted to resolve SIMT type mismatches (e.g. around loop region
  // arguments and yields).
  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
    // Only casts tagged with resolveSIMTTypeMismatch are of interest here.
    // ...
    Value input = op.getOperand(0);
    Value output = op.getResult(0);
    // Both sides of the cast must be tensor descriptor types.
    xegpu::TensorDescType inputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
    xegpu::TensorDescType outputDescType =
        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
    assert(inputDescType && outputDescType &&
           "Unrealized conversion cast must have tensor descriptor types");
    // tensor_desc<..., layout> -> tensor_desc<...> casts: rewrite the block
    // argument (e.g. a loop iter_arg) to the SIMT type and propagate the new
    // type to the tied loop result.
    if (inputDescType.getLayout()) {
      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
      if (argument) {
        argument.setType(output.getType());
        output.replaceAllUsesWith(argument);
        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
                argument.getOwner()->getParentOp())) {
          auto result = loopOp.getTiedLoopResult(argument);
          result.setType(output.getType());
        }
      }
    }
    // tensor_desc<...> -> tensor_desc<..., layout> casts (at a region yield):
    // simply forward the SIMT-typed value.
    if (outputDescType.getLayout())
      output.replaceAllUsesWith(input);
    // ...
  });
}