#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
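
// Body of the calculateGlobalOffsets helper used by WgToSgCreateNdOp below:
// for each distributed dimension i it computes
//   (localOffset[i] + distUnitBaseAddr[i]) mod distUnitShape[i]
// and adds that to the corresponding original (workgroup-level) offset of the
// tensor descriptor, leaving any leading, non-distributed offsets untouched.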
    assert(localOffset.size() == distUnitBaseAddr.size() &&
           "localOffset and distUnitBaseAddr must have the same rank");

    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                            originalOffsets.end());
    size_t rank = localOffset.size();
    for (size_t i = 0; i < rank; ++i) {
      size_t dimIdx = originalOffsets.size() - rank + i;
      Value constOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
      Value offset =
          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
      Value modValue =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
      Value offsetMod =
          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
      Value origOffset = getValueOrCreateConstantIndexOp(
          rewriter, loc, originalOffsets[dimIdx]);
      Value globalOffset =
          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
      globalOffsets[dimIdx] = globalOffset;
    }

    return globalOffsets;
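
// Illustrative example (the concrete numbers are assumptions, not taken from
// the surrounding code): with wgShape = [32, 64], sgLayout = [2, 4] and
// sgData = [16, 16], sgLayout * sgData equals wgShape, so there is a single
// distribution unit with distUnitShape = [32, 64] and
// distUnitBaseAddr = [0, 0]. The subgroup with linear id 7 delinearizes to
// (1, 3), giving localOffset = [16, 48], and its descriptor offsets become
// originalOffsets + [16, 48].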
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    SmallVector<int64_t> sgLayout;
    if (auto sgLayoutAttr = layout.getSgLayout())
      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
    else
      return rewriter.notifyMatchFailure(
          op, "sgLayout attribute is required in layout");

    SmallVector<int64_t> sgShape;
    if (auto sgDataAttr = layout.getSgData()) {
      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
    } else {
      assert(wgShape.size() == sgLayout.size() &&
             "sgLayout and wgShape must have the same rank");
      sgShape.reserve(wgShape.size());
      for (size_t i = 0; i < wgShape.size(); ++i) {
        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
        sgShape.push_back(wgShape[i] / sgLayout[i]);
      }
    }
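
    // Query the linear subgroup id and delinearize it over sgLayout so each
    // subgroup gets a multi-dimensional coordinate within the workgroup.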
    auto linearSgId =
        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);

    SmallVector<Value> sgLayoutDim(sgLayout.size());
    SmallVector<Value> sgDataDim(sgShape.size());
    for (size_t i = 0; i < sgLayout.size(); i++) {
      sgLayoutDim[i] =
          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
    }

    auto deLinearizeSgId =
        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
    if (failed(deLinearizeSgId))
      return failure();
    SmallVector<Value> sgIds = *deLinearizeSgId;

    SmallVector<int64_t> distUnitShape(sgLayout.size());
    SmallVector<Value> localOffset(sgLayout.size());
    for (size_t i = 0; i < sgLayout.size(); i++) {
      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
      localOffset[i] =
          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
    }
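
    // Each subgroup covers one sgShape-sized tile per distribution unit. The
    // loop below walks every distribution unit of the workgroup shape and
    // creates one subgroup-level descriptor for each of them, using a layout
    // with the sgLayout/sgData fields dropped.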
    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();

    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps;
    for (SmallVector<int64_t> distUnitBaseAddr :
         StaticTileOffsetRange(wgShape, distUnitShape)) {
      SmallVector<OpFoldResult> globalOffsets =
          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
                                 distUnitBaseAddr, distUnitShape);

      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newCreateNdOps.push_back(newCreateNdOp);
    }
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});

    return success();
  }
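
// WgToSgLoadNdOp: rewrites a workgroup-level xegpu.load_nd into one
// subgroup-level load per distributed tensor descriptor produced by the 1:N
// conversion; the result type shrinks to the subgroup tile shape.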
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy =
          VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
                                                        src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
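
// WgToSgStoreNdOp: pairs each distributed value with its distributed tensor
// descriptor and emits one subgroup-level xegpu.store_nd per pair, forwarding
// the cache-hint attributes; the workgroup-level store is then erased.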
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
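
// WgToSgUpdateNdOffsetOp: applies the original offset update to every
// distributed tensor descriptor, producing one xegpu.update_nd_offset per
// subgroup-level descriptor.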
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};
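
// WgToSgDpasOp: a workgroup-level xegpu.dpas with a 2-D result is expanded
// into a grid of subgroup-sized DPAS ops. The nested loops pair each
// distributed A tile with each distributed B tile, pull in the matching
// accumulator tile when the original op has one, and tag every new op with
// the layout minus its sgLayout/sgData fields.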
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout =
        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    if (!originalLayout)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = rewriter.create<xegpu::DpasOp>(
            loc, resTy, operands,
            llvm::ArrayRef<NamedAttribute>(
                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
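
// WgToSgPrefetchNdOp: replays the prefetch on every distributed tensor
// descriptor and then erases the workgroup-level op.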
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto src : adaptor.getTensorDesc())
      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                           op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
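
// Entry point that registers all workgroup-to-subgroup distribution patterns.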
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
      patterns.getContext());
}
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
void XeGPUWgToSgDistributePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };
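
  // An op is legal once its tensor-descriptor layout (or the dpas "layout"
  // attribute) no longer carries sgLayout, i.e. once it has already been
  // distributed to subgroups.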
  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
    return !layout || layout.getSgLayout() == nullptr;
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    return isLegal(layout);
  });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}