XeGPUWgToSgDistribute.cpp
//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

namespace {

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
/// It uses round-robin assignment to distribute the work to the subgroups.
/// The following create_nd_tdesc operation:
/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
/// is converted to 9 subgroup-level operations based on the sg_layout and
/// sg_data:
/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
/// lane_data = [1, 1]>>
///
/// The sg_layout and sg_data attributes are dropped after the pass as they are
/// no longer needed.
///
/// 24x24 matrix distribution example:
/// sg_layout = [4, 4], sg_data = [2, 2]
/// Each 8x8 block within the 24x24 matrix is called a distribution unit.
/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
///
/// +-----------------+
/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles across
/// |-----+-----+-----|
/// | 8x8 | 8x8 | 8x8 |  <- 3 tiles down
/// |-----+-----+-----|
/// | 8x8 | 8x8 | 8x8 |
/// +-----------------+
///
/// Each 8x8 tile is further subdivided among subgroups:
/// +---------------------+
/// | 2x2  2x2  2x2  2x2 |  <- 4 subgroups across (each handles 2 columns)
/// | 2x2  2x2  2x2  2x2 |  <- 4 subgroups down (each handles 2 rows)
/// | 2x2  2x2  2x2  2x2 |
/// | 2x2  2x2  2x2  2x2 |
/// +---------------------+
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will
/// be 9 distribution units (3x3) in total. Hence the 9 subgroup-level
/// operations.
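///
/// For example, the subgroup with linear id 5 delinearized over
/// sg_layout = [4, 4] has sgIds = [1, 1], so its local offset within a
/// distribution unit is sgIds * sg_data = [2, 2]; the pattern below then
/// emits one 2x2 subgroup descriptor for this subgroup per distribution unit
/// (9 in total), which is the round-robin assignment described above.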

/// The pass currently has the entire distribution logic in the
/// WgToSgCreateNdOp pattern and all the other ops just follow.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  // Calculate offset for each subgroup
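  // For each distributed dimension d, the global offset is computed as
  //   globalOffsets[d] = originalOffsets[d]
  //                      + (localOffset[d] + distUnitBaseAddr[d]) % distUnitShape[d]
  // where the modulo provides the round-robin wrap-around used when
  // sg_layout * sg_data exceeds the workgroup shape.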
  SmallVector<OpFoldResult>
  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                         const SmallVector<OpFoldResult> &originalOffsets,
                         const SmallVector<Value> &localOffset,
                         const SmallVector<int64_t> &distUnitBaseAddr,
                         const SmallVector<int64_t> &distUnitShape) const {
    assert(localOffset.size() == distUnitBaseAddr.size() &&
           "localOffset and distUnitBaseAddr must have the same rank");

    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                            originalOffsets.end());
    size_t rank = localOffset.size();
    for (size_t i = 0; i < rank; ++i) {
      size_t dimIdx = originalOffsets.size() - rank + i;
      Value constOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
      Value offset =
          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
      Value modValue =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
      Value offsetMod =
          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
      Value origOffset = getValueOrCreateConstantIndexOp(
          rewriter, loc, originalOffsets[dimIdx]);
      Value globalOffset =
          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
      globalOffsets[dimIdx] = globalOffset;
    }

    return globalOffsets;
  }

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout)
      return failure();
    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    // sgLayout must be present for workgroup-level distribution.
    SmallVector<int64_t> sgLayout;
    if (auto sgLayoutAttr = layout.getSgLayout())
      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
    else
      return rewriter.notifyMatchFailure(
          op, "sgLayout attribute is required in layout");

    SmallVector<int64_t> sgShape;
    if (auto sgDataAttr = layout.getSgData()) {
      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
    } else {
      assert(wgShape.size() == sgLayout.size() &&
             "sgLayout and wgShape must have the same rank");
      sgShape.reserve(wgShape.size());
      for (size_t i = 0; i < wgShape.size(); ++i) {
        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
        sgShape.push_back(wgShape[i] / sgLayout[i]);
      }
    }

    // TODO: Handle order attribute
    // Get the subgroup ID
    auto linearSgId =
        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);

    // Create constants for layout dimensions
    SmallVector<Value> sgLayoutDim(sgLayout.size());
    SmallVector<Value> sgDataDim(sgShape.size());

    for (size_t i = 0; i < sgLayout.size(); i++) {
      sgLayoutDim[i] =
          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
    }

    auto deLinearizeSgId =
        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
    if (failed(deLinearizeSgId))
      return failure();
    SmallVector<Value> sgIds = *deLinearizeSgId;

    // Calculate distribution unit shape and local offsets for subgroup
    SmallVector<int64_t> distUnitShape(sgLayout.size());
    SmallVector<Value> localOffset(sgLayout.size());
    for (size_t i = 0; i < sgLayout.size(); i++) {
      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
      localOffset[i] =
          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
    }

    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();

    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());
    SmallVector<Value> newCreateNdOps;
    for (SmallVector<int64_t> distUnitBaseAddr :
         StaticTileOffsetRange(wgShape, distUnitShape)) {
      SmallVector<OpFoldResult> globalOffsets =
          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
                                 distUnitBaseAddr, distUnitShape);

      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newCreateNdOps.push_back(newCreateNdOp);
    }

    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};

/// This pattern transforms the LoadNdOp to load subgroup data.
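/// For example, a descriptor of type !xegpu.tensor_desc<2x2xf32,
/// #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> coming from the
/// create_nd pattern above is loaded as a vector<2x2xf32>, since the result
/// type is rebuilt from the subgroup descriptor's shape and element type.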
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
                                                        src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};

/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
/// It creates a StoreNdOp to store the updated values to the new subgroup
/// src tensor descriptors.
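/// One store is emitted per (value, tensor_desc) pair supplied by the
/// adaptor, and the L1/L2/L3 cache hint attributes of the original op are
/// carried over unchanged.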
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the offsets
/// of the new subgroup src tensor descriptors.
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }

    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};

/// This pattern transforms the DpasOp to work at subgroup level.
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout =
        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    if (!originalLayout)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
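    // One subgroup-level dpas is emitted per (A tile, B tile) pair; its result
    // type is vector<aVecShape[0] x bVecShape[1]>, and when an accumulator is
    // present its distributed tiles are consumed in the same order through the
    // running index i.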
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = rewriter.create<xegpu::DpasOp>(
            loc, resTy, operands,
            llvm::ArrayRef<NamedAttribute>(
                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};

/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto src : adaptor.getTensorDesc())
      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                           op->getAttrs());
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
      patterns.getContext());
}
} // namespace xegpu
} // namespace mlir

namespace {
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

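  // An op is treated as legal (i.e., already distributed) when it carries no
  // layout at all or a layout whose sg_layout has been dropped, which is
  // exactly the form produced by the patterns above.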
  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
    return !layout || layout.getSgLayout() == nullptr;
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    return isLegal(layout);
  });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}
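
// A minimal way to exercise this pass (assuming the pass argument registered
// from Passes.td follows the usual kebab-case form of the pass name) would be:
//   mlir-opt --xegpu-wg-to-sg-distribute input.mlir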