1 //===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
9 
10 #include "mlir/Dialect/Affine/Utils.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Utils/Utils.h"
13 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
14 #include "mlir/Dialect/Index/IR/IndexDialect.h"
15 #include "mlir/Dialect/Index/IR/IndexOps.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
21 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
22 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include <optional>
25 
26 namespace mlir {
27 namespace xegpu {
28 #define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30 } // namespace xegpu
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 namespace {
36 
37 // Retrieve the sg_id_range attribute (RangeAttr) from the nearest enclosing
38 // scf.if, if specified.
38 static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39  Operation *parent = op->getParentOfType<scf::IfOp>();
40  while (parent) {
41  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42  parent->getAttr("sg_id_range")))
43  return attr;
44  parent = parent->getParentOfType<scf::IfOp>();
45  }
46  return {};
47 }
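// Illustrative sketch (hypothetical values): for an op nested inside an
// scf.if carrying sg_id_range covering subgroup ids [4, 8), the layout must
// declare exactly 4 subgroups, and genOffsetsList below rebases the subgroup
// id as sgId - 4 before any offsets are computed.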
48 
49 static std::pair<SmallVector<int64_t>, int>
50 getSgShapeAndCount(ArrayRef<int64_t> shape,
51  xegpu::DistributeLayoutAttr layout) {
52  int count = 1;
53  SmallVector<int64_t> sgShape(shape);
54  if (layout && layout.isForWorkgroup()) {
55  SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
56  if (!layout.getEffectiveSgDataAsInt().empty())
57  sgShape = layout.getEffectiveSgDataAsInt();
58  else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59  sgShape = *maybeDerivedSgData;
60  SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61  // Clamp distUnit to the original shape to handle cases where data is
62  // shared among subgroups, which may cause distUnit to exceed the original
63  // shape.
64  for (size_t i = 0; i < distUnit.size(); ++i)
65  distUnit[i] = std::min(shape[i], distUnit[i]);
66  count = computeProduct(shape) / computeProduct(distUnit);
67  }
68  return std::make_pair(sgShape, count);
69 }
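// Worked example (illustrative, mirroring the 24x24 case documented further
// below): shape = [24, 24], sg_layout = [4, 4], sg_data = [2, 2] yields
// sgShape = [2, 2] and distUnit = [8, 8] (clamped to the shape), so
// count = (24 * 24) / (8 * 8) = 9, i.e. each subgroup visits 9 sub-tiles in
// round-robin order.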
70 
71 /// Utility helper for deriving a list of offsets for each sub-TensorDesc
72 /// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
73 /// the associated distribute layout attribute, the data shape, and the
74 /// original offsets of the op.
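/// A sketch of the flow, with hypothetical values: given origOffsets = [%x, %y],
/// a data shape of [256, 128] and sg_layout = [8, 4], sg_data = [32, 32],
/// layout.getOffsets() yields the element offsets of the sub-tile(s) assigned
/// to this sgId, and each entry appended to offsetsList is roughly
/// [%x + 32 * row, %y + 32 * col], combined via xegpu::addWithRightAligned.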
75 template <
76  typename OpType,
77  typename = std::enable_if_t<llvm::is_one_of<
78  OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79  xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80 static LogicalResult
81 genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82  SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83  Location loc = op.getLoc();
84  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85  // Not applicable to ops without offset operands.
86  if (origOffsets.empty())
87  return failure();
88 
89  // not applicable to ops without workgroup layout attributes
90  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
91  if (!layout || !layout.isForWorkgroup())
92  return failure();
93 
94  Value sgId =
95  gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
96 
97  // verify and adjust the sgId if the range specifier is present
98  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
99  if (sgIdRange) {
100  int64_t startOfRange = sgIdRange.getStart().getInt();
101  int64_t endOfRange = sgIdRange.getEnd().getInt();
102  // verify the RangeAttr against the layout attribute
103  if (layout.getNumSubgroups() != endOfRange - startOfRange)
104  return rewriter.notifyMatchFailure(
105  op, "sg_layout size must match the sg_id_range");
106  // adjust the sgId if necessary
107  if (startOfRange > 0) {
108  Value startOfRangeVal =
109  arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
110  sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
111  }
112  }
113 
114  // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
115  // descriptors to be accessed, based on the layout information.
116  ArrayRef<int64_t> wgShape = op.getDataShape();
117  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
118  if (failed(maybeDescOffsets))
119  return failure();
120 
121  // Compute the final global offsets for each accessed sub-tensor
122  // or sub-memory descriptor.
123  for (const auto &sgOffsets : *maybeDescOffsets) {
124  SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
125  rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
126  offsetsList.push_back(std::move(newOffsets));
127  }
128 
129  // callback(offsetsList);
130  return success();
131 }
132 
133 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
134 /// from a workgroup descriptor. It replaces the offsets and sizes with
135 /// appropriate values for the subgroup.
136 /// It uses round-robin assignment to distribute the work to the subgroups.
137 /// The following create_nd_desc operation:
138 /// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
139 /// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
140 /// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
141 /// is converted to 9 subgroup level operations based on the sg_layout &
142 /// sg_data:
143 /// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
144 /// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
145 /// lane_data = [1, 1]>>
146 ///
147 /// The sg_layout and sg_data attributes are dropped after the pass as they are
148 /// no longer needed.
149 ///
150 /// 24x24 matrix distribution example:
151 /// sg_layout = [4, 4], sg_data = [2, 2]
152 /// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
153 /// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
154 ///
155 /// +-----+-----+-----+
156 /// | 8x8 | 8x8 | 8x8 |   <- 3 tiles across
157 /// +-----+-----+-----+
158 /// | 8x8 | 8x8 | 8x8 |   <- 3 tiles down
159 /// +-----+-----+-----+
160 /// | 8x8 | 8x8 | 8x8 |
161 /// +-----+-----+-----+
162 ///
163 /// Each 8x8 tile is further subdivided among subgroups:
164 /// +-----+-----+-----+-----+
165 /// | 2x2 | 2x2 | 2x2 | 2x2 |   <- 4 subgroups across (each handles 2 columns)
166 /// | 2x2 | 2x2 | 2x2 | 2x2 |   <- 4 subgroups down (each handles 2 rows)
167 /// | 2x2 | 2x2 | 2x2 | 2x2 |
168 /// | 2x2 | 2x2 | 2x2 | 2x2 |
169 /// +-----+-----+-----+-----+
170 ///
171 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
172 /// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
173 
174 /// The pass currently keeps the entire distribution logic in the
175 /// WgToSgCreateNdOp pattern; all the other ops simply follow.
176 /// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
177 /// ops in the pass.
178 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
179  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
180 
181  LogicalResult
182  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
183  ConversionPatternRewriter &rewriter) const override {
184  SmallVector<SmallVector<OpFoldResult>> offsetsList;
185  if (failed(genOffsetsList(rewriter, op, offsetsList)))
186  return failure();
187 
188  MLIRContext *ctx = op.getContext();
189  xegpu::TensorDescType tdescTy = op.getType();
190  ArrayRef<int64_t> wgShape = tdescTy.getShape();
191  Type elemTy = tdescTy.getElementType();
192  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
193  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
194  auto newTdescTy =
195  xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
196  layout.dropSgLayoutAndData());
197 
198  SmallVector<Value> newOps;
199  for (auto offsets : offsetsList) {
200  auto newOp = xegpu::CreateNdDescOp::create(
201  rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
202  op.getMixedSizes(), op.getMixedStrides());
203 
204  newOps.push_back(newOp);
205  }
206  rewriter.replaceOpWithMultiple(op, {newOps});
207 
208  return success();
209  }
210 };
211 
212 // This pattern transforms the CreateNdDescOp without offsets to create a
213 // subgroup descriptor from a workgroup descriptor
214 struct WgToSgCreateNdOpNoOffset
215  : public OpConversionPattern<xegpu::CreateNdDescOp> {
216  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
217 
218  LogicalResult
219  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
220  ConversionPatternRewriter &rewriter) const override {
221 
222  // Check no offsets are specified.
223  if (!op.getMixedOffsets().empty())
224  return failure();
225 
226  Location loc = op.getLoc();
227  MLIRContext *ctx = op.getContext();
228  xegpu::TensorDescType tdescTy = op.getType();
229  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
230  if (!layout || !layout.isForWorkgroup())
231  return failure();
232 
233  Type elemTy = tdescTy.getElementType();
234  ArrayRef<int64_t> wgShape = tdescTy.getShape();
235 
236  SmallVector<int64_t> sgShape;
237  int count;
238  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
239  xegpu::TensorDescType newTdescTy =
240  xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
241  layout.dropSgLayoutAndData());
242 
243  SmallVector<Value> newCreateNdOps(count);
244  std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
245  return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
246  op.getSource(), op.getMixedSizes(),
247  op.getMixedStrides());
248  });
249 
250  rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
251  return success();
252  }
253 };
254 
255 /// This pattern transforms the LoadNdOp to load subgroup data.
256 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
257  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
258  LogicalResult
259  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
260  ConversionPatternRewriter &rewriter) const override {
261  if (!op.getMixedOffsets().empty())
262  return failure();
263 
264  SmallVector<Value> newLoadOps;
265  for (auto src : adaptor.getTensorDesc()) {
266  xegpu::TensorDescType tdescTy =
267  dyn_cast<xegpu::TensorDescType>(src.getType());
268  ArrayRef<int64_t> srcShape = tdescTy.getShape();
269  VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
270  auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
271  src, op->getAttrs());
272  newLoadOps.push_back(newLoadOp);
273  }
274  rewriter.replaceOpWithMultiple(op, {newLoadOps});
275  return mlir::success();
276  }
277 };
278 
279 /// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
280 /// It creates a StoreNdOp to store the updated values to the new subgroup
281 /// src tensor descriptors.
282 struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
283  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
284  LogicalResult
285  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
286  ConversionPatternRewriter &rewriter) const override {
287  if (!op.getMixedOffsets().empty())
288  return failure();
289 
290  for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
291  xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
292  op.getL2HintAttr(), op.getL3HintAttr());
293 
294  rewriter.eraseOp(op);
295  return success();
296  }
297 };
298 
299 // This pattern transforms the LoadNdOp with explicit offsets to load
300 // subgroup data.
301 struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
302  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
303  LogicalResult
304  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
305  ConversionPatternRewriter &rewriter) const override {
306 
307  SmallVector<SmallVector<OpFoldResult>> offsetsList;
308  if (failed(genOffsetsList(rewriter, op, offsetsList)))
309  return failure();
310 
311  SmallVector<Value> newOps;
312  for (auto [tdesc, offsets] :
313  llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
314  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
315  VectorType newResTy =
316  VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
317  auto newOp = xegpu::LoadNdOp::create(
318  rewriter, op.getLoc(), newResTy, tdesc, offsets,
319  /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
320  op.getL2HintAttr(), op.getL3HintAttr());
321  newOps.push_back(newOp);
322  }
323  rewriter.replaceOpWithMultiple(op, {newOps});
324 
325  return success();
326  }
327 };
328 
329 // This pattern transforms the StoreNdOp with explicit offsets to store
330 // subgroup data.
331 struct WgToSgStoreNdOpWithOffset
332  : public OpConversionPattern<xegpu::StoreNdOp> {
333  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
334  LogicalResult
335  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
336  ConversionPatternRewriter &rewriter) const override {
337  SmallVector<SmallVector<OpFoldResult>> offsetsList;
338  if (failed(genOffsetsList(rewriter, op, offsetsList)))
339  return failure();
340 
341  for (auto [v, tdesc, offsets] :
342  llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
343  xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
344  op.getL1HintAttr(), op.getL2HintAttr(),
345  op.getL3HintAttr());
346  }
347  rewriter.eraseOp(op);
348 
349  return success();
350  }
351 };
352 
353 // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
354 // subgroup data.
355 struct WgToSgPrefetchNdOpWithOffset
356  : public OpConversionPattern<xegpu::PrefetchNdOp> {
357  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
358  LogicalResult
359  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
360  ConversionPatternRewriter &rewriter) const override {
361  SmallVector<SmallVector<OpFoldResult>> offsetsList;
362  if (failed(genOffsetsList(rewriter, op, offsetsList)))
363  return failure();
364 
365  for (auto [tdesc, offsets] :
366  llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
367  xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
368  op.getL1HintAttr(), op.getL2HintAttr(),
369  op.getL3HintAttr());
370  }
371  rewriter.eraseOp(op);
372 
373  return success();
374  }
375 };
376 
377 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
378 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
379 /// offsets of the new subgroup src tensor descriptors.
380 struct WgToSgUpdateNdOffsetOp
381  : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
382  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
383  LogicalResult
384  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
385  ConversionPatternRewriter &rewriter) const override {
386  llvm::SmallVector<Value> newUpdateTileOffsetOps;
387  for (auto tDesc : adaptor.getTensorDesc()) {
388  auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
389  rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
390  op.getConstOffsets());
391  newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
392  }
393 
394  rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
395  return success();
396  }
397 };
398 
399 /// This pattern transforms the DpasOp to work at subgroup level.
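/// For example (illustrative shapes): a workgroup-level dpas on
/// vector<256x128xf16> x vector<128x256xf16> (plus a vector<256x256xf32>
/// accumulator) whose operands are already distributed into per-subgroup tiles
/// is rewritten into a nest of subgroup-level dpas ops over the LHS/RHS tile
/// pairs, each producing a tile such as vector<32x32xf32>, with the
/// accumulator tiles consumed in distribution order.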
400 struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
401  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
402  LogicalResult
403  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
404  ConversionPatternRewriter &rewriter) const override {
405  Location loc = op.getLoc();
406  VectorType resultTy = op.getResult().getType();
407  if (resultTy.getRank() != 2)
408  return failure();
409 
410  auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
411  if (!originalLayout)
412  return failure();
413 
414  size_t i = 0;
415  SmallVector<Value> newDpasOps;
416  for (auto aVec : adaptor.getLhs()) {
417  for (auto bVec : adaptor.getRhs()) {
418 
419  llvm::SmallVector<Value> operands({aVec, bVec});
420  Value tmpC;
421  if (op.getAcc()) {
422  tmpC = adaptor.getAcc()[i++];
423  operands.push_back(tmpC);
424  }
425 
426  ArrayRef<int64_t> aVecShape =
427  llvm::cast<VectorType>(aVec.getType()).getShape();
428  ArrayRef<int64_t> bVecShape =
429  llvm::cast<VectorType>(bVec.getType()).getShape();
430  VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
431  resultTy.getElementType());
432  tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
433  xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
434  originalLayout.dropSgLayoutAndData());
435 
436  newDpasOps.push_back(tmpC);
437  }
438  }
439  rewriter.replaceOpWithMultiple(op, {newDpasOps});
440  return success();
441  }
442 };
443 
444 /// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
445 struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
446  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
447  LogicalResult
448  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
449  ConversionPatternRewriter &rewriter) const override {
450 
451  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
452  if ((offsetSize != 0) || op.getConstOffsetsAttr())
453  return failure();
454 
455  for (auto src : adaptor.getTensorDesc())
456  xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
457  op->getAttrs());
458  rewriter.eraseOp(op);
459  return success();
460  }
461 };
462 
463 /// This pattern transforms vector.broadcast ops to work at subgroup level.
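/// For example (illustrative): a broadcast producing vector<256x128xf32> under
/// a layout with sg_layout = [8, 4], sg_data = [32, 32] is rewritten into one
/// vector.broadcast per distributed operand, each producing vector<32x32xf32>.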
464 struct WgToSgVectorBroadcastOp
465  : public OpConversionPattern<vector::BroadcastOp> {
466  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
467 
468  LogicalResult
469  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
470  ConversionPatternRewriter &rewriter) const override {
471 
472  VectorType resultType = op.getResult().getType();
473  ArrayRef<int64_t> wgShape = resultType.getShape();
474 
475  xegpu::DistributeLayoutAttr layout =
476  xegpu::getDistributeLayoutAttr(op.getResult());
477  if (!layout || !layout.isForWorkgroup())
478  return failure();
479 
480  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
481  VectorType newResultType =
482  VectorType::get(sgShape, resultType.getElementType());
483 
484  if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
485  return failure();
486 
487  SmallVector<Value> newBroadcastOps;
488  for (auto operand : adaptor.getOperands().front()) {
489  auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
490  newResultType, operand);
491  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
492  !layout.getEffectiveInstDataAsInt().empty())
493  xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
494  layout.dropSgLayoutAndData());
495 
496  newBroadcastOps.push_back(newBroadcast.getResult());
497  }
498  rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
499  return success();
500  }
501 };
502 
503 // This pattern transforms elementwise ops to work at subgroup level.
504 struct WgToSgElementwiseOp : public ConversionPattern {
505  WgToSgElementwiseOp(MLIRContext *ctx)
506  : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
507 
508  LogicalResult
509  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
510  ConversionPatternRewriter &rewriter) const override {
511  // Only match ops with elementwise trait and single result.
512  if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
513  return failure();
514 
515  auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
516  assert(resultType && "Expected result to be a VectorType");
517 
518  ArrayRef<int64_t> wgShape = resultType.getShape();
519 
520  xegpu::DistributeLayoutAttr layout =
521  xegpu::getDistributeLayoutAttr(op->getResult(0));
522  if (!layout || !layout.isForWorkgroup())
523  return failure();
524 
525  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
526 
527  size_t numVariants = operands.empty() ? 0 : operands.front().size();
528 
529  if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
530  return operandVec.size() != numVariants;
531  }))
532  return failure();
533 
534  SmallVector<Value> newResults;
535  VectorType newResultType =
536  VectorType::get(sgShape, resultType.getElementType());
537 
538  for (size_t i = 0; i < numVariants; ++i) {
539  SmallVector<Value> opOperands;
540  for (auto &operandVec : operands)
541  opOperands.push_back(operandVec[i]);
542 
543  OperationState state(op->getLoc(), op->getName());
544  state.addOperands(opOperands);
545  state.addTypes(newResultType);
546  // Copy all attributes, but update "layout_result_0" to drop
547  // sgLayout/sgData
548  for (auto attr : op->getAttrs()) {
549  if (auto layout =
550  dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
551  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
552  !layout.getEffectiveInstDataAsInt().empty())
553  state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
554  } else {
555  state.addAttribute(attr.getName(), attr.getValue());
556  }
557  }
558  Operation *newOp = rewriter.create(state);
559  newResults.push_back(newOp->getResult(0));
560  }
561 
562  rewriter.replaceOpWithMultiple(op, {newResults});
563  return success();
564  }
565 };
566 
567 // clang-format off
568 // Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
569 // If input_layout and target_layout have identical sg_layout and sg_data,
570 // the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
571 // dropped. For example:
572 // #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
573 // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
574 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
575 // becomes:
576 // #a = #xegpu.layout<inst_data = [16, 16]>
577 // #b = #xegpu.layout<inst_data = [8, 16]>
578 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
579 // (vector<16x16xf32> is determined by sg_data = [16, 16])
580 //
581 // If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
582 // For example:
583 // #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
584 // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
585 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
586 // is lowered to:
587 // #a = #xegpu.layout<inst_data = [16, 16]>
588 // #b = #xegpu.layout<inst_data = [8, 16]>
589 // store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
590 // %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
591 // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
592 // clang-format on
593 struct WgToSgConvertLayoutOp
594  : public OpConversionPattern<xegpu::ConvertLayoutOp> {
595  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
596  LogicalResult
597  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
598  ConversionPatternRewriter &rewriter) const override {
599  // TODO: currently, we only support LayoutAttr
600  auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
601  auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
602 
603  if (!input || !target || !input.isForWorkgroup() ||
604  !target.isForWorkgroup())
605  return rewriter.notifyMatchFailure(
606  op, "Input and target layouts must have subgroup layout");
607 
608  DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
609  DenseI32ArrayAttr inputSgData = input.getSgData();
610  DenseI32ArrayAttr inputOrder = input.getOrder();
611  DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
612  DenseI32ArrayAttr targetSgData = target.getSgData();
613  DenseI32ArrayAttr targetOrder = target.getOrder();
614 
615  // TODO: currently we only support the optimal case, where the input and
616  // output have the same sg_layout and sg_data, so SLM is not involved.
617  if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
618  inputOrder != targetOrder)
619  return failure();
620 
621  input = input.dropSgLayoutAndData();
622  target = target.dropSgLayoutAndData();
623 
624  SmallVector<Value> newOps(adaptor.getSource());
625  if (input && target) {
626  // Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
627  for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
628  auto newOp = xegpu::ConvertLayoutOp::create(
629  rewriter, op.getLoc(), src.getType(), src, input, target);
630  newOps[i] = newOp;
631  }
632  }
633  rewriter.replaceOpWithMultiple(op, {newOps});
634  return success();
635  }
636 };
637 
638 // Handles UnrealizedConversionCastOp generated during
639 // SCFStructuralTypeConversions (step 1). This op may appear as either a
640 // target or source materialization for Vector values, e.g.:
641 // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
642 // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
643 // It could be either a 1:N or an N:1 cast. In both cases, the pattern
644 // simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
645 // For example, the following scf::ForOp
646 // ```
647 // %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
648 // %n = use(%arg1): vector<128x128xf16>
649 // scf.yield %n : vector<128x128xf16>
650 // }
651 // ```
652 // Could be converted to:
653 // ```
654 // %1 = unrealized_conversion_cast %0
655 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
656 // %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2)
657 // -> (vector<16x16xf16>, vector<16x16xf16>) {
658 // %m = unrealized_conversion_cast %arg1, %arg2
659 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
660 // %n = use(%m): vector<128x128xf16>
661 // %b = unrealized_conversion_cast %n
662 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
663 // scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
664 // }
665 // %cast = unrealized_conversion_cast %for:2
666 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
667 // ```
668 // TODO: remove it when context-aware type converter is ready.
669 struct UnrealizedConversionCastOpPattern
670  : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
671  using OpConversionPattern<
672  mlir::UnrealizedConversionCastOp>::OpConversionPattern;
673 
674  mlir::LogicalResult
675  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
676  ConversionPatternRewriter &rewriter) const override {
677  SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
678 
679  auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
680  auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
681 
682  if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
683  !llvm::all_equal(ValueRange(inputs).getTypes()))
684  return failure();
685 
686  // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
687  // It is generated by source materialization (e.g., inits to scf forOp).
688  // The input values provided by the adaptor should already be distributed,
689  // and their types should correspond exactly to the result types of the
690  // operation.
691  if (op.getNumOperands() == 1 &&
692  llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
693  rewriter.replaceOp(op, inputs);
694  return success();
695  }
696 
697  // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
698  // It is generated by target materialization (e.g., arguments/results
699  // of scf forOp). All input values must have the same vector type, and
700  // their shape must be evenly divisible by the output vector's shape
701  // (determined by the nature of the workgroup to subgroup distribution).
702  // TODO: it is not safe to do such forwarding, since such an N:1 cast could
703  // come from other sources.
704  if (op.getNumResults() == 1 &&
705  computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
706  rewriter.replaceOpWithMultiple(op, {inputs});
707  return success();
708  }
709 
710  return mlir::failure();
711  }
712 };
713 
714 // This pattern distributes arith.constant op into subgroup-level constants
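// For example (illustrative): a splat
//   arith.constant dense<1.0> : vector<256x128xf32>
// with sg_layout = [8, 4], sg_data = [32, 32] becomes
//   arith.constant dense<1.0> : vector<32x32xf32>
// replicated `count` times (count = 1 here, since one distribution unit covers
// the whole workgroup shape).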
715 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
716  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
717 
718  LogicalResult
719  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
720  ConversionPatternRewriter &rewriter) const override {
721  auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
722  auto vecType = dyn_cast<VectorType>(op.getType());
723  if (!vecAttr || !vecAttr.isSplat() || !vecType)
724  return failure();
725 
726  xegpu::DistributeLayoutAttr layout =
727  xegpu::getDistributeLayoutAttr(op.getResult());
728  if (!layout || !layout.isForWorkgroup())
729  return failure();
730 
731  ArrayRef<int64_t> wgShape = vecType.getShape();
732  SmallVector<int64_t> sgShape;
733  int count;
734  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
735 
736  // Current limitation: only splat vector constants (a single repeated
737  // value) are supported. TODO: support vectors with multiple distinct values.
738  Attribute singleVal = vecAttr.getSplatValue<Attribute>();
739 
740  auto newType = VectorType::get(sgShape, vecType.getElementType());
741  auto sgAttr = DenseElementsAttr::get(newType, singleVal);
742  auto cstOp =
743  arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
744  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
745  !layout.getEffectiveInstDataAsInt().empty())
746  xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
747  layout.dropSgLayoutAndData());
748  SmallVector<Value> newConsts(count, cstOp);
749 
750  rewriter.replaceOpWithMultiple(op, {newConsts});
751  return success();
752  }
753 };
754 
755 // This pattern transforms the LoadGatherOp with explicit offsets to load
756 // subgroup data
757 struct WgToSgLoadGatherOpWithOffset
758  : public OpConversionPattern<xegpu::LoadGatherOp> {
759  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
760  LogicalResult
761  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
762  ConversionPatternRewriter &rewriter) const override {
763 
764  if (!op.getOffsets())
765  return failure();
766 
767  Location loc = op.getLoc();
768  VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
769  if (!resultType)
770  return failure();
771  ArrayRef<int64_t> wgShape = resultType.getShape();
772 
773  xegpu::DistributeLayoutAttr layout =
774  xegpu::getDistributeLayoutAttr(op.getResult());
775  if (!layout || !layout.isForWorkgroup())
776  return failure();
777 
778  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
779 
780  // The offsets need to be distributed
781  auto offsetsVecType =
782  dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
783  auto maskVecType =
784  dyn_cast<VectorType>(adaptor.getMask().front().getType());
785  if (!offsetsVecType || !maskVecType ||
786  offsetsVecType.getShape() != maskVecType.getShape()) {
787  return rewriter.notifyMatchFailure(op,
788  "offsets have not been distributed");
789  }
790 
791  SmallVector<Value> newLoadOps;
792  auto chunkSizeAttr =
793  rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
794  VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
795  for (auto [offsets, mask] :
796  llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
797  auto newLoadOp = xegpu::LoadGatherOp::create(
798  rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
799  op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
800  xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
801  layout.dropSgLayoutAndData());
802  newLoadOps.push_back(newLoadOp);
803  }
804  rewriter.replaceOpWithMultiple(op, {newLoadOps});
805  return success();
806  }
807 };
808 
809 // This pattern transforms the StoreScatterOp with explicit offsets to store
810 // subgroup data
811 struct WgToSgStoreScatterOpWithOffset
812  : public OpConversionPattern<xegpu::StoreScatterOp> {
813  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
814  LogicalResult
815  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
816  ConversionPatternRewriter &rewriter) const override {
817 
818  if (!op.getOffsets())
819  return failure();
820 
821  Location loc = op.getLoc();
822  VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
823  if (!valueType)
824  return failure();
825 
826  xegpu::DistributeLayoutAttr layout =
827  xegpu::getDistributeLayoutAttr(op.getValue());
828  if (!layout || !layout.isForWorkgroup())
829  return failure();
830 
831  // The offsets need to be distributed
832  auto offsetsVecType =
833  dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
834  auto maskVecType =
835  dyn_cast<VectorType>(adaptor.getMask().front().getType());
836  if (!offsetsVecType || !maskVecType ||
837  offsetsVecType.getShape() != maskVecType.getShape()) {
838  return rewriter.notifyMatchFailure(op,
839  "offsets have not been distributed");
840  }
841 
842  auto chunkSizeOpt = op.getChunkSize();
843  int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
844  auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
845  for (auto [val, offs, mask] : llvm::zip(
846  adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
847  xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
848  mask, chunkSizeAttr, op.getL1HintAttr(),
849  op.getL2HintAttr(), op.getL3HintAttr());
850  // Update the layout attribute to drop sg_layout and sg_data.
851  if (auto newLayout = layout.dropSgLayoutAndData())
852  op->setAttr("layout", newLayout);
853  }
854  rewriter.eraseOp(op);
855  return success();
856  }
857 };
858 
859 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
860  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
861  LogicalResult
862  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
863  ConversionPatternRewriter &rewriter) const override {
864 
865  SmallVector<SmallVector<OpFoldResult>> offsetsList;
866  if (failed(genOffsetsList(rewriter, op, offsetsList)))
867  return failure();
868 
869  ArrayRef<int64_t> wgShape = op.getDataShape();
870  VectorType valueTy = op.getRes().getType();
871  Type elemTy = valueTy.getElementType();
872 
873  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
874  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
875  VectorType newResTy = VectorType::get(sgShape, elemTy);
876  SmallVector<Value> newOps;
877  for (auto offsets : offsetsList) {
878  auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
879  op.getMemDesc(), offsets,
880  layout.dropSgLayoutAndData());
881  newOps.push_back(newOp);
882  }
883  rewriter.replaceOpWithMultiple(op, {newOps});
884 
885  return success();
886  }
887 };
888 
889 struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
890  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
891  LogicalResult
892  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
893  ConversionPatternRewriter &rewriter) const override {
894 
895  SmallVector<SmallVector<OpFoldResult>> offsetsList;
896  if (failed(genOffsetsList(rewriter, op, offsetsList)))
897  return failure();
898 
899  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
900  for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
901  xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
902  offsets, layout.dropSgLayoutAndData());
903  rewriter.eraseOp(op);
904  return success();
905  }
906 };
907 
908 // This pattern distributes the vector.step ops to work at subgroup level
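// For example (illustrative): vector.step : vector<256xindex> with
// sg_layout = [8], sg_data = [32] becomes, per subgroup, a
// vector.step : vector<32xindex> plus an arith.addi with a broadcast of that
// subgroup's starting element offset.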
909 struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
910  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
911  LogicalResult
912  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
913  ConversionPatternRewriter &rewriter) const override {
914  xegpu::DistributeLayoutAttr layout =
915  xegpu::getDistributeLayoutAttr(op.getResult());
916  if (!layout || !layout.isForWorkgroup())
917  return failure();
918 
919  Location loc = op.getLoc();
920  VectorType type = op.getResult().getType();
921  auto wgShape = type.getShape();
922  std::optional<SmallVector<int64_t>> sgShape =
923  getSgShapeAndCount(wgShape, layout).first;
924  if (!sgShape)
925  return failure();
926 
927  Value sgId =
928  gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
929  auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
930  if (failed(sgOffsets))
931  return failure();
932 
933  VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
934  auto steps = vector::StepOp::create(rewriter, loc, newTy);
935  SmallVector<Value> newOps;
936  for (auto offsets : *sgOffsets) {
937  // Broadcast the offset scalar to a vector & add to the base steps
938  auto bcastOffset =
939  vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
940  auto finalSteps =
941  arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
942  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
943  !layout.getEffectiveInstDataAsInt().empty()) {
944  xegpu::setDistributeLayoutAttr(steps->getResult(0),
945  layout.dropSgLayoutAndData());
946  xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
947  layout.dropSgLayoutAndData());
948  xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
949  layout.dropSgLayoutAndData());
950  }
951  newOps.push_back(finalSteps);
952  }
953 
954  rewriter.replaceOpWithMultiple(op, {newOps});
955  return success();
956  }
957 };
958 
959 // This pattern transforms vector.shape_cast ops to work at subgroup level.
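// For example (illustrative): a cast vector<256x1x128xf32> -> vector<256x128xf32>
// only drops a unit dimension, so each distributed piece vector<32x1x32xf32>
// is re-cast to vector<32x32xf32>, provided the rank-2 result layout is a
// slice of the rank-3 source layout.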
960 struct WgToSgVectorShapeCastOp
961  : public OpConversionPattern<vector::ShapeCastOp> {
962  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
963 
964  LogicalResult
965  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
966  ConversionPatternRewriter &rewriter) const override {
967 
968  VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
969  if (!resultType)
970  return failure();
971 
972  ArrayRef<int64_t> wgShape = resultType.getShape();
973  xegpu::DistributeLayoutAttr layout =
974  xegpu::getDistributeLayoutAttr(op.getResult());
975  if (!layout || !layout.isForWorkgroup())
976  return failure();
977 
978  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
979  VectorType newResultType =
980  VectorType::get(sgShape, resultType.getElementType());
981 
982  // TODO: Add check for compatible layouts in layout attr.
983  auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
984  if (!srcType)
985  return failure();
986 
987  // Check that the shape_cast only adds or removes unit dimensions.
988  auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
989  // Remove all 1s from both shapes and compare the rest.
990  SmallVector<int64_t> srcNonUnit, dstNonUnit;
991  for (int64_t d : src)
992  if (d != 1)
993  srcNonUnit.push_back(d);
994  for (int64_t d : dst)
995  if (d != 1)
996  dstNonUnit.push_back(d);
997  return srcNonUnit == dstNonUnit;
998  };
999 
1000  if (!onlyUnitDims(srcType.getShape(), sgShape))
1001  return failure();
1002 
1003  // For rank-reducing or rank-increasing shape_cast ops, the lower-rank
1004  // layout must be a slice of the higher-rank layout.
1005  int64_t sourceRank = srcType.getRank();
1006  int64_t resultRank = sgShape.size();
1007  xegpu::DistributeLayoutAttr sourceLayout =
1008  xegpu::getDistributeLayoutAttr(op.getSource());
1009  if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
1010  return failure();
1011  if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
1012  return failure();
1013 
1014  SmallVector<Value> newShapeCastOps;
1015  for (auto src : adaptor.getSource()) {
1016  auto newShapeCast =
1017  rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
1018  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1019  !layout.getEffectiveInstDataAsInt().empty())
1020  xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1021  layout.dropSgLayoutAndData());
1022  newShapeCastOps.push_back(newShapeCast.getResult());
1023  }
1024 
1025  rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1026  return success();
1027  }
1028 };
1029 
1030 /// Pattern for lowering vector.multi_reduction op to subgroup level.
1031 /// Current limitation: the sg_layout in each reduced dimension must be 1,
1032 /// so that the reduction is local to a subgroup and no cross-subgroup
1033 /// communication is needed.
1034 /// TODO: Add cases to handle more general situations which require SLM access.
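/// For example (illustrative): reducing dim 1 of vector<32x64xf32> under a
/// parent layout with sg_layout = [8, 1], sg_data = [4, 64] lets every
/// subgroup reduce its local vector<4x64xf32> slice down to vector<4xf32>
/// with no cross-subgroup exchange.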
1035 struct WgToSgMultiDimReductionOp
1036  : public OpConversionPattern<vector::MultiDimReductionOp> {
1037  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1038 
1039  LogicalResult
1040  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1041  ConversionPatternRewriter &rewriter) const override {
1042  VectorType srcType = op.getSourceVectorType();
1043  VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1044  if (!dstType)
1045  return failure();
1046 
1047  auto srcShape = srcType.getShape();
1048  xegpu::DistributeLayoutAttr layout =
1049  xegpu::getDistributeLayoutAttr(op.getResult());
1050  if (!layout || !layout.isForWorkgroup())
1051  return failure();
1052 
1053  auto reductionDims = llvm::to_vector(op.getReductionDims());
1054 
1055  SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
1056  .getParent()
1057  .getEffectiveSgLayoutAsInt();
1058  SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
1059  .getParent()
1060  .getEffectiveSgDataAsInt();
1061 
1062  // Check that the sgLayout in the reduced dimension is 1 and
1063  // each sg gets the entire slice to reduce.
1064  for (int64_t dim : reductionDims) {
1065  if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
1066  return rewriter.notifyMatchFailure(
1067  op,
1068  "sgLayout in each reduced dimension must be 1 and sgData in the "
1069  "reduced dim must match srcShape in that dim");
1070  }
1071 
1072  SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
1073 
1074  VectorType newDstType =
1075  VectorType::get({sgShape}, dstType.getElementType());
1076 
1077  SmallVector<Value> newReductions;
1078  for (auto sgSrc : adaptor.getSource()) {
1079  auto newOp = rewriter.create<vector::MultiDimReductionOp>(
1080  op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
1081  op.getReductionDims());
1082  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1083  !layout.getEffectiveInstDataAsInt().empty())
1084  xegpu::setDistributeLayoutAttr(newOp->getResult(0),
1085  layout.dropSgLayoutAndData());
1086  newReductions.push_back(newOp.getResult());
1087  }
1088 
1089  rewriter.replaceOpWithMultiple(op, {newReductions});
1090  return success();
1091  }
1092 };
1093 
1094 } // namespace
1095 
1096 namespace mlir {
1097 namespace xegpu {
1098 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1099  patterns
1100  .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1101  WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1102  WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1103  WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1104  WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1105  WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1106  WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1107  WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1108  WgToSgMultiDimReductionOp>(patterns.getContext());
1109 }
1110 } // namespace xegpu
1111 } // namespace mlir
1112 
1113 namespace {
1114 struct XeGPUWgToSgDistributePass
1115  : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1116  void runOnOperation() override;
1117 };
1118 } // namespace
1119 
1120 void XeGPUWgToSgDistributePass::runOnOperation() {
1121  // Track existing UnrealizedConversionCastOps
1122  SmallVector<Operation *> existingCastOps;
1123  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1124  existingCastOps.push_back(castOp.getOperation());
1125  });
1126 
1127  {
1128  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1129  // VectorType operands. This first converts such operands to
1130  // RankedTensorType, propagates the layout attribute into the encoding
1131  // attribute, and finally converts the RankedTensorType to VectorType based
1132  // on the encoding.
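  // For example (illustrative): a vector<256x128xf32> value whose layout
  // carries sg_layout = [8, 4], sg_data = [32, 32] is routed through a
  // tensor<256x128xf32, #layout> carrier and materialized back as `count`
  // vector<32x32xf32> values (count = 1 for these numbers).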
1133 
1134  TypeConverter converter;
1135  converter.addConversion([&](Type type) -> Type { return type; });
1136  converter.addConversion(
1137  [&](RankedTensorType type,
1138  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1139  Type elemTy = type.getElementType();
1140  ArrayRef<int64_t> shape = type.getShape();
1141 
1142  int count;
1143  SmallVector<int64_t> subShape;
1144  std::tie(subShape, count) = getSgShapeAndCount(
1145  shape,
1146  dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1147 
1148  auto newTy = VectorType::get(subShape, elemTy);
1149  result.append(count, newTy);
1150  return success();
1151  });
1152 
1153  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1154  converter);
1155  }
1156 
1157  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1158  // as well as XeGPU, Arith, and Vector operations.
1159  MLIRContext *ctx = &getContext();
1160  RewritePatternSet patterns(ctx);
1161  ConversionTarget target(*ctx);
1162  TypeConverter converter;
1163  converter.addConversion([&](Type type) -> Type { return type; });
1164  converter.addConversion(
1165  [&](xegpu::TensorDescType type,
1166  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1167  Type elemTy = type.getElementType();
1168  ArrayRef<int64_t> shape = type.getShape();
1169 
1170  int count;
1171  SmallVector<int64_t> subShape;
1172  xegpu::LayoutAttr layout = type.getLayoutAttr();
1173  std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1174 
1175  if (layout)
1176  layout = layout.dropSgLayoutAndData();
1177 
1178  auto newTy = xegpu::TensorDescType::get(
1179  type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1180  result.append(count, newTy);
1181  return success();
1182  });
1183 
1184  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1185  if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1186  return createOp.getType();
1187  if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1188  return loadOp.getTensorDescType();
1189  if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1190  return storeOp.getTensorDescType();
1191  if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1192  return updateOp.getType();
1193  if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1194  return prefetchOp.getTensorDescType();
1195  return xegpu::TensorDescType();
1196  };
1197 
1198  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1199  return !layout || !layout.isForWorkgroup();
1200  };
1201 
1202  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1203  xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1204  xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1205  auto tdescTy = getTensorDescType(op);
1206  auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1207  return isLegal(layout);
1208  });
1209 
1210  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1211  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1212  return isLegal(layout);
1213  });
1214 
1215  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1216  [=](xegpu::LoadMatrixOp op) -> bool {
1217  return isLegal(op.getLayoutAttr());
1218  });
1219 
1220  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1221  [=](xegpu::StoreMatrixOp op) -> bool {
1222  return isLegal(op.getLayoutAttr());
1223  });
1224 
1225  target.addDynamicallyLegalOp<arith::ConstantOp>(
1226  [=](arith::ConstantOp op) -> bool {
1227  auto vecType = dyn_cast<VectorType>(op.getType());
1228  if (!vecType)
1229  return true;
1230 
1231  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1232  return isLegal(layout);
1233  });
1234 
1235  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
1236  [=](Operation *op) -> bool {
1237  // Check for either a SliceAttr or LayoutAttr on the result.
1238  auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1239  return isLegal(layout);
1240  });
1241 
1242  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1243  [=](xegpu::LoadGatherOp op) -> bool {
1244  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1245  return isLegal(layout);
1246  });
1247 
1248  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1249  [=](xegpu::StoreScatterOp op) -> bool {
1250  // Check if the layout attribute is present on the op.
1251  auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
1252  if (!layout)
1253  return true;
1254  return isLegal(layout);
1255  });
1256 
1257  target.addDynamicallyLegalOp<vector::BroadcastOp>(
1258  [=](vector::BroadcastOp op) -> bool {
1259  return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1260  });
1261 
1262  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1263  [=](vector::MultiDimReductionOp op) -> bool {
1264  return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1265  });
1266 
1267  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1268  [=](xegpu::ConvertLayoutOp op) -> bool {
1269  return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1270  });
1271 
1272  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1273  [=](Operation *op) -> std::optional<bool> {
1274  // Only handle elementwise mappable ops
1275  if (!OpTrait::hasElementwiseMappableTraits(op))
1276  return true;
1277 
1278  VectorType resultType =
1279  dyn_cast<VectorType>(op->getResult(0).getType());
1280  if (!resultType)
1281  return true;
1282 
1283  // Check if all operands are vectors of the same shape
1284  // TODO: Support other types.
1285  for (Value operand : op->getOperands()) {
1286  VectorType operandType = dyn_cast<VectorType>(operand.getType());
1287  if (!operandType || operandType.getShape() != resultType.getShape()) {
1288  return true;
1289  }
1290  }
1291 
1292  xegpu::DistributeLayoutAttr layout =
1293  xegpu::getDistributeLayoutAttr(op->getResult(0));
1294  return isLegal(layout);
1295  });
1296 
1297  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1298  [=](UnrealizedConversionCastOp op) {
1299  return llvm::is_contained(existingCastOps, op.getOperation());
1300  });
1301 
1302  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1303 
1304  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1305  target);
1306  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1307  if (failed(
1308  applyPartialConversion(getOperation(), target, std::move(patterns))))
1309  return signalPassFailure();
1310 
1311  // Remove sg_layout and sg_data attributes from the Layout
1312  // attribute for each VectorType result of the operation.
1313  // For structured control flow ops, the layout is simply removed,
1314  // since in the 1:N case the layouts for the new results are missing;
1315  // the layout propagation pass will be activated to recover them.
1316  getOperation()->walk([](Operation *op) {
1317  for (OpResult result : op->getOpResults()) {
1318  std::string name = xegpu::getLayoutName(result);
1319  if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
1320  op->removeAttr(name);
1321  if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1322  if (auto newLayout = layout.dropSgLayoutAndData())
1323  op->setAttr(name, newLayout);
1324  }
1325  }
1326  }
1327  });
1328 }