1 //===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 
24 #include <optional>
25 
26 namespace mlir {
27 namespace xegpu {
28 #define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30 } // namespace xegpu
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 namespace {
36 
37 // Retrieve the RangeAttr if it is specified.
38 static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39  Operation *parent = op->getParentOfType<scf::IfOp>();
40  while (parent) {
41  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42  parent->getAttr("sg_id_range")))
43  return attr;
44  parent = parent->getParentOfType<scf::IfOp>();
45  }
46  return {};
47 }
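// Illustrative note (not from the original source): if an op is nested inside
// an scf.if that carries an xegpu::RangeAttr named "sg_id_range" with, say,
// start = 2 and end = 6, the helper above returns that attribute; the users
// below then check that the layout expects 6 - 2 = 4 subgroups and rebase the
// hardware subgroup id by subtracting 2.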
48 
49 static std::pair<SmallVector<int64_t>, int>
50 getSgShapeAndCount(ArrayRef<int64_t> shape,
51  xegpu::DistributeLayoutAttr layout) {
52  int count = 1;
53  SmallVector<int64_t> sgShape(shape);
54  if (layout && layout.isForWorkgroup()) {
55  SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
56  if (!layout.getEffectiveSgDataAsInt().empty())
57  sgShape = layout.getEffectiveSgDataAsInt();
58  else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59  sgShape = *maybeDerivedSgData;
60  SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61  // Clamp distUnit to the original shape to handle cases where data is
62  // shared among subgroups, which may cause distUnit to exceed the original
63  // shape.
64  for (size_t i = 0; i < distUnit.size(); ++i)
65  distUnit[i] = std::min(shape[i], distUnit[i]);
66  count = computeProduct(shape) / computeProduct(distUnit);
67  }
68  return std::make_pair(sgShape, count);
69 }
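// Worked example (illustrative only): for shape = [24, 24] with
// sg_layout = [4, 4] and sg_data = [2, 2], sgShape is [2, 2], the distribution
// unit is [4*2, 4*2] = [8, 8] (already within the shape), and
// count = (24*24) / (8*8) = 9, i.e. each subgroup owns 9 round-robin tiles.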
70 
71 /// Utility helper for deriving the list of offsets for each sub-TensorDesc
72 /// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
73 /// the associated distribute layout attribute, the shape, the subgroup id,
74 /// and the original offsets of the op.
75 template <
76  typename OpType,
77  typename = std::enable_if_t<llvm::is_one_of<
78  OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79  xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80 static LogicalResult
81 genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82  SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83  Location loc = op.getLoc();
84  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85  // Not applicable to ops without offset operands.
86  if (origOffsets.empty())
87  return failure();
88 
89  // Not applicable to ops without a workgroup layout attribute.
90  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
91  if (!layout || !layout.isForWorkgroup())
92  return failure();
93 
94  Value sgId =
95  gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
96 
97  // verify and adjust the sgId if the range specifier is present
98  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
99  if (sgIdRange) {
100  int64_t startOfRange = sgIdRange.getStart().getInt();
101  int64_t endOfRange = sgIdRange.getEnd().getInt();
102  // verify the RangeAttr against the layout attribute
103  if (layout.getNumSubgroups() != endOfRange - startOfRange)
104  return rewriter.notifyMatchFailure(
105  op, "sg_layout size must match the sg_id_range");
106  // adjust the sgId if necessary
107  if (startOfRange > 0) {
108  Value startOfRangeVal =
109  arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
110  sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
111  }
112  }
113 
114  // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
115  // descriptors to be accessed, based on the layout information.
116  ArrayRef<int64_t> wgShape = op.getDataShape();
117  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
118  if (failed(maybeDescOffsets))
119  return failure();
120 
121  // Compute the final global offsets for each accessed sub-tensor
122  // or sub-memory descriptor.
123  for (const auto &sgOffsets : *maybeDescOffsets) {
124  SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
125  rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
126  offsetsList.push_back(std::move(newOffsets));
127  }
128 
129  // callback(offsetsList);
130  return success();
131 }
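// Illustrative walk-through (assuming the 24x24 example documented below):
// with sg_layout = [4, 4], sg_data = [2, 2], and original offsets [0, 0],
// layout.getOffsets() yields, for the subgroup at position (r, c), the
// round-robin local offsets (2*r + 8*i, 2*c + 8*j) for i, j in [0, 3);
// addWithRightAligned then adds the original offsets, so offsetsList ends up
// with 9 entries per subgroup, one per distribution unit it touches.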
132 
133 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
134 /// from a workgroup descriptor. It replaces the offsets and sizes with
135 /// appropriate values for the subgroup.
136 /// It uses round-robin assignment to distribute the work to the subgroups.
137 /// The following create_nd_tdesc operation:
138 /// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
139 /// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
140 /// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
141 /// is converted to 9 subgroup level operations based on the sg_layout &
142 /// sg_data:
143 /// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
144 /// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
145 /// lane_data = [1, 1]>>
146 ///
147 /// The sg_layout and sg_data attributes are dropped after the pass as they are
148 /// no longer needed.
149 ///
150 /// 24x24 matrix distribution example:
151 /// sg_layout = [4, 4], sg_data = [2, 2]
152 /// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
153 /// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
154 ///
155 /// +------------------------+
156 /// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
157 /// |-----+-----+-----|
158 /// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
159 /// |-----+-----+-----|
160 /// | 8x8 | 8x8 | 8x8 |
161 /// +------------------------+
162 ///
163 /// Each 8x8 tile is further subdivided among subgroups:
164 /// +------------------------+
165 /// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
166 /// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
167 /// | 2x2 2x2 2x2 2x2 |
168 /// | 2x2 2x2 2x2 2x2 |
169 /// +------------------------+
170 ///
171 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
172 /// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
173 
174 /// The pass currently keeps the entire distribution logic in the
175 /// WgToSgCreateNdOp pattern, and all the other ops simply follow.
176 /// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
177 /// ops in the pass.
178 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
179  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
180 
181  LogicalResult
182  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
183  ConversionPatternRewriter &rewriter) const override {
184  SmallVector<SmallVector<OpFoldResult>> offsetsList;
185  if (failed(genOffsetsList(rewriter, op, offsetsList)))
186  return failure();
187 
188  MLIRContext *ctx = op.getContext();
189  xegpu::TensorDescType tdescTy = op.getType();
190  ArrayRef<int64_t> wgShape = tdescTy.getShape();
191  Type elemTy = tdescTy.getElementType();
192  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
193  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
194  auto newTdescTy =
195  xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
196  layout.dropSgLayoutAndData());
197 
198  SmallVector<Value> newOps;
199  for (auto offsets : offsetsList) {
200  auto newOp = xegpu::CreateNdDescOp::create(
201  rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
202  op.getMixedSizes(), op.getMixedStrides());
203 
204  newOps.push_back(newOp);
205  }
206  rewriter.replaceOpWithMultiple(op, {newOps});
207 
208  return success();
209  }
210 };
211 
212 // This pattern transforms the CreateNdDescOp without offsets to create a
213 // subgroup descriptor from a workgroup descriptor
214 struct WgToSgCreateNdOpNoOffset
215  : public OpConversionPattern<xegpu::CreateNdDescOp> {
216  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
217 
218  LogicalResult
219  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
220  ConversionPatternRewriter &rewriter) const override {
221 
222  // Check no offsets are specified.
223  if (!op.getMixedOffsets().empty())
224  return failure();
225 
226  Location loc = op.getLoc();
227  MLIRContext *ctx = op.getContext();
228  xegpu::TensorDescType tdescTy = op.getType();
229  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
230  if (!layout || !layout.isForWorkgroup())
231  return failure();
232 
233  Type elemTy = tdescTy.getElementType();
234  ArrayRef<int64_t> wgShape = tdescTy.getShape();
235 
236  SmallVector<int64_t> sgShape;
237  int count;
238  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
239  xegpu::TensorDescType newTdescTy =
240  xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
241  layout.dropSgLayoutAndData());
242 
243  SmallVector<Value> newCreateNdOps(count);
244  std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
245  return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
246  op.getSource(), op.getMixedSizes(),
247  op.getMixedStrides());
248  });
249 
250  rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
251  return success();
252  }
253 };
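// Sketch of the no-offset case (hypothetical shapes, not taken from the
// source): a create_nd_tdesc producing a workgroup-level
// !xegpu.tensor_desc<256x128xf32> with sg_layout = [2, 4] and
// sg_data = [32, 32] has a distribution unit of [64, 128], so
// count = (256*128) / (64*128) = 4 and the op is replicated into 4 identical
// subgroup descriptors of shape 32x32; the offsets are supplied later by the
// load/store/prefetch ops that consume the descriptor.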
254 
255 /// This pattern transforms the LoadNdOp to load subgroup data.
256 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
257  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
258  LogicalResult
259  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
260  ConversionPatternRewriter &rewriter) const override {
261  if (!op.getMixedOffsets().empty())
262  return failure();
263 
264  SmallVector<Value> newLoadOps;
265  for (auto src : adaptor.getTensorDesc()) {
266  xegpu::TensorDescType tdescTy =
267  dyn_cast<xegpu::TensorDescType>(src.getType());
268  ArrayRef<int64_t> srcShape = tdescTy.getShape();
269  VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
270  auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
271  src, op->getAttrs());
272  newLoadOps.push_back(newLoadOp);
273  }
274  rewriter.replaceOpWithMultiple(op, {newLoadOps});
275  return mlir::success();
276  }
277 };
278 
279 /// This pattern transforms the StoreNdOp to store to a subgroup descriptor
280 /// It creates a StoreNdOp op to store the updated values to the new subgroup
281 /// src tensor descriptors.
282 struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
283  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
284  LogicalResult
285  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
286  ConversionPatternRewriter &rewriter) const override {
287  if (!op.getMixedOffsets().empty())
288  return failure();
289 
290  for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
291  xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
292  op.getL2HintAttr(), op.getL3HintAttr());
293 
294  rewriter.eraseOp(op);
295  return success();
296  }
297 };
298 
299 // This pattern transforms the LoadNdOp with explicit offsets to load
300 // subgroup data.
301 struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
302  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
303  LogicalResult
304  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
305  ConversionPatternRewriter &rewriter) const override {
306 
307  SmallVector<SmallVector<OpFoldResult>> offsetsList;
308  if (failed(genOffsetsList(rewriter, op, offsetsList)))
309  return failure();
310 
311  SmallVector<Value> newOps;
312  for (auto [tdesc, offsets] :
313  llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
314  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
315  VectorType newResTy =
316  VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
317  auto newOp = xegpu::LoadNdOp::create(
318  rewriter, op.getLoc(), newResTy, tdesc, offsets,
319  /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
320  op.getL2HintAttr(), op.getL3HintAttr());
321  newOps.push_back(newOp);
322  }
323  rewriter.replaceOpWithMultiple(op, {newOps});
324 
325  return success();
326  }
327 };
328 
329 // This pattern transforms the StoreNdOp with explicit offsets to store
330 // subgroup data.
331 struct WgToSgStoreNdOpWithOffset
332  : public OpConversionPattern<xegpu::StoreNdOp> {
333  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
334  LogicalResult
335  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
336  ConversionPatternRewriter &rewriter) const override {
337  SmallVector<SmallVector<OpFoldResult>> offsetsList;
338  if (failed(genOffsetsList(rewriter, op, offsetsList)))
339  return failure();
340 
341  for (auto [v, tdesc, offsets] :
342  llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
343  xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
344  op.getL1HintAttr(), op.getL2HintAttr(),
345  op.getL3HintAttr());
346  }
347  rewriter.eraseOp(op);
348 
349  return success();
350  }
351 };
352 
353 // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
354 // subgroup data.
355 struct WgToSgPrefetchNdOpWithOffset
356  : public OpConversionPattern<xegpu::PrefetchNdOp> {
357  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
358  LogicalResult
359  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
360  ConversionPatternRewriter &rewriter) const override {
361  SmallVector<SmallVector<OpFoldResult>> offsetsList;
362  if (failed(genOffsetsList(rewriter, op, offsetsList)))
363  return failure();
364 
365  for (auto [tdesc, offsets] :
366  llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
367  xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
368  op.getL1HintAttr(), op.getL2HintAttr(),
369  op.getL3HintAttr());
370  }
371  rewriter.eraseOp(op);
372 
373  return success();
374  }
375 };
376 
377 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
378 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
379 /// offsets of the new subgroup src tensor descriptors.
380 struct WgToSgUpdateNdOffsetOp
381  : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
382  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
383  LogicalResult
384  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
385  ConversionPatternRewriter &rewriter) const override {
386  llvm::SmallVector<Value> newUpdateTileOffsetOps;
387  for (auto tDesc : adaptor.getTensorDesc()) {
388  auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
389  rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
390  op.getConstOffsets());
391  newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
392  }
393 
394  rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
395  return success();
396  }
397 };
398 
399 /// This pattern transforms the DpasOp to work at subgroup level.
400 struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
401  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
402  LogicalResult
403  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
404  ConversionPatternRewriter &rewriter) const override {
405  Location loc = op.getLoc();
406  VectorType resultTy = op.getResult().getType();
407  if (resultTy.getRank() != 2)
408  return failure();
409 
410  auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
411  if (!originalLayout)
412  return failure();
413 
414  size_t i = 0;
415  SmallVector<Value> newDpasOps;
416  for (auto aVec : adaptor.getLhs()) {
417  for (auto bVec : adaptor.getRhs()) {
418 
419  llvm::SmallVector<Value> operands({aVec, bVec});
420  Value tmpC;
421  if (op.getAcc()) {
422  tmpC = adaptor.getAcc()[i++];
423  operands.push_back(tmpC);
424  }
425 
426  ArrayRef<int64_t> aVecShape =
427  llvm::cast<VectorType>(aVec.getType()).getShape();
428  ArrayRef<int64_t> bVecShape =
429  llvm::cast<VectorType>(bVec.getType()).getShape();
430  VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
431  resultTy.getElementType());
432  tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
433  xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
434  originalLayout.dropSgLayoutAndData());
435 
436  newDpasOps.push_back(tmpC);
437  }
438  }
439  rewriter.replaceOpWithMultiple(op, {newDpasOps});
440  return success();
441  }
442 };
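// Illustrative shapes (hypothetical): if the adaptor delivers the LHS as
// vector<32x128xf16> pieces and the RHS as vector<128x32xf16> pieces, each
// (aVec, bVec) pair above produces one subgroup-level dpas returning
// vector<32x32xf32>, i.e. resTy = [aVecShape[0], bVecShape[1]], with the
// accumulator (when present) consumed in the same iteration order.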
443 
444 /// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
445 struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
446  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
447  LogicalResult
448  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
449  ConversionPatternRewriter &rewriter) const override {
450 
451  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
452  if ((offsetSize != 0) || op.getConstOffsetsAttr())
453  return failure();
454 
455  for (auto src : adaptor.getTensorDesc())
456  xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
457  op->getAttrs());
458  rewriter.eraseOp(op);
459  return success();
460  }
461 };
462 
463 /// This pattern transforms vector.broadcast ops to work at subgroup level.
464 struct WgToSgVectorBroadcastOp
465  : public OpConversionPattern<vector::BroadcastOp> {
466  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
467 
468  LogicalResult
469  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
470  ConversionPatternRewriter &rewriter) const override {
471 
472  VectorType resultType = op.getResult().getType();
473  ArrayRef<int64_t> wgShape = resultType.getShape();
474 
475  xegpu::DistributeLayoutAttr layout =
476  xegpu::getDistributeLayoutAttr(op.getResult());
477  if (!layout || !layout.isForWorkgroup())
478  return failure();
479 
480  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
481  VectorType newResultType =
482  VectorType::get(sgShape, resultType.getElementType());
483 
484  if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
485  return failure();
486 
487  SmallVector<Value> newBroadcastOps;
488  for (auto operand : adaptor.getOperands().front()) {
489  auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
490  newResultType, operand);
491  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
492  !layout.getEffectiveInstDataAsInt().empty())
493  xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
494  layout.dropSgLayoutAndData());
495 
496  newBroadcastOps.push_back(newBroadcast.getResult());
497  }
498  rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
499  return success();
500  }
501 };
502 
503 // This pattern transforms elementwise ops to work at subgroup level.
504 struct WgToSgElementwiseOp : public ConversionPattern {
505  WgToSgElementwiseOp(MLIRContext *ctx)
506  : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
507 
508  LogicalResult
509  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
510  ConversionPatternRewriter &rewriter) const override {
511  // Only match ops with elementwise trait and single result.
512  if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
513  return failure();
514 
515  auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
516  assert(resultType && "Expected result to be a VectorType");
517 
518  ArrayRef<int64_t> wgShape = resultType.getShape();
519 
520  xegpu::DistributeLayoutAttr layout =
521  xegpu::getDistributeLayoutAttr(op->getResult(0));
522  if (!layout || !layout.isForWorkgroup())
523  return failure();
524 
525  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
526 
527  size_t numVariants = operands.empty() ? 0 : operands.front().size();
528 
529  if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
530  return operandVec.size() != numVariants;
531  }))
532  return failure();
533 
534  SmallVector<Value> newResults;
535  VectorType newResultType =
536  VectorType::get(sgShape, resultType.getElementType());
537 
538  for (size_t i = 0; i < numVariants; ++i) {
539  SmallVector<Value> opOperands;
540  for (auto &operandVec : operands)
541  opOperands.push_back(operandVec[i]);
542 
543  OperationState state(op->getLoc(), op->getName());
544  state.addOperands(opOperands);
545  state.addTypes(newResultType);
546  // Copy all attributes, but update "layout_result_0" to drop
547  // sgLayout/sgData
548  for (auto attr : op->getAttrs()) {
549  if (auto layout =
550  dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
551  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
552  !layout.getEffectiveInstDataAsInt().empty())
553  state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
554  } else {
555  state.addAttribute(attr.getName(), attr.getValue());
556  }
557  }
558  Operation *newOp = rewriter.create(state);
559  newResults.push_back(newOp->getResult(0));
560  }
561 
562  rewriter.replaceOpWithMultiple(op, {newResults});
563  return success();
564  }
565 };
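// Example of the elementwise rewrite (hypothetical types): an arith.addf on
// vector<256x128xf32> whose "layout_result_0" has sg_layout = [8, 4] and
// sg_data = [32, 32] is replaced by one arith.addf per distributed operand
// tuple, each producing vector<32x32xf32>, with the result layout rewritten
// to drop sg_layout/sg_data (keeping e.g. inst_data when present).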
566 
567 // clang-format off
568 // Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
569 // If input_layout and target_layout have identical sg_layout and sg_data,
570 // the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
571 // dropped. For example:
572 // #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
573 // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
574 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
575 // becomes:
576 // #a = #xegpu.layout<inst_data = [16, 16]>
577 // #b = #xegpu.layout<inst_data = [8, 16]>
578 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
579 // (vector<16x16xf32> is determined by sg_data = [16, 16])
580 //
581 // If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
582 // For example:
583 // #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
584 // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
585 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
586 // is lowered to:
587 // #a = #xegpu.layout<inst_data = [16, 16]>
588 // #b = #xegpu.layout<inst_data = [8, 16]>
589 // store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
590 // %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
591 // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
592 // clang-format on
593 struct WgToSgConvertLayoutOp
594  : public OpConversionPattern<xegpu::ConvertLayoutOp> {
595  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
596  LogicalResult
597  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
598  ConversionPatternRewriter &rewriter) const override {
599  // TODO: currently, we only support LayoutAttr
600  auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
601  auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
602 
603  if (!input || !target || !input.isForWorkgroup() ||
604  !target.isForWorkgroup())
605  return rewriter.notifyMatchFailure(
606  op, "Input and target layouts must have subgroup layout");
607 
608  DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
609  DenseI32ArrayAttr inputSgData = input.getSgData();
610  DenseI32ArrayAttr inputOrder = input.getOrder();
611  DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
612  DenseI32ArrayAttr targetSgData = target.getSgData();
613  DenseI32ArrayAttr targetOrder = target.getOrder();
614 
615  // TODO: currently we only support the optimal case, where the input and
616  // output have the same sg_layout and sg_data, so SLM is not involved.
617  if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
618  inputOrder != targetOrder)
619  return failure();
620 
621  input = input.dropSgLayoutAndData();
622  target = target.dropSgLayoutAndData();
623 
624  SmallVector<Value> newOps(adaptor.getSource());
625  if (input && target) {
626  // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
627  for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
628  auto newOp = xegpu::ConvertLayoutOp::create(
629  rewriter, op.getLoc(), src.getType(), src, input, target);
630  newOps[i] = newOp;
631  }
632  }
633  rewriter.replaceOpWithMultiple(op, {newOps});
634  return success();
635  }
636 };
637 
638 // Handles UnrealizedConversionCastOp generated during
639 // SCFStructuralTypeConversions (step 1). This op may appear as either a
640 // target or source materialization for Vector values, e.g.:
641 // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
642 // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
643 // It could be either a 1:N or an N:1 cast. In both cases, the pattern
644 // simply forwards the inputs to the outputs using the 1:1 or 1:N interface.
645 // For example, the following scf::ForOp
646 // ```
647 // %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
648 // %n = use(%arg1): vector<128x128xf16>
649 // scf.yield %n : vector<128x128xf16>
650 // }
651 // ```
652 // Could be converted to:
653 // ```
654 // %1 = unrealized_conversion_cast %0
655 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
656 // %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2)
657 // -> (vector<16x16xf16>, vector<16x16xf16) {
658 // %m = unrealized_conversion_cast %arg1, %arg2
659 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
660 // %n = use(%m): vector<128x128xf16>
661 // %b = unrealized_conversion_cast %n
662 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
663 // scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
664 // }
665 // %cast = unrealized_conversion_cast %for:2
666 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
667 // ```
668 // TODO: remove it when context-aware type converter is ready.
669 struct UnrealizedConversionCastOpPattern
670  : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
671  using OpConversionPattern<
672  mlir::UnrealizedConversionCastOp>::OpConversionPattern;
673 
674  mlir::LogicalResult
675  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
676  ConversionPatternRewriter &rewriter) const override {
677  SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
678 
679  auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
680  auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
681 
682  if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
683  !llvm::all_equal(ValueRange(inputs).getTypes()))
684  return failure();
685 
686  // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
687  // It is generated by source materialization (e.g., inits to scf forOp).
688  // The input values provided by the adaptor should already be distributed,
689  // and their types should correspond exactly to the result types of the
690  // operation.
691  if (op.getNumOperands() == 1 &&
692  llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
693  rewriter.replaceOp(op, inputs);
694  return success();
695  }
696 
697  // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
698  // It is generated by target materialization (e.g., arguments/results
699  // of scf forOp). All input values must have the same vector type, and
700  // their shape must be evenly divisible by the output vector's shape
701  // (determined by the nature of the workgroup to subgroup distribution).
702 // TODO: it is not safe to do such forwarding, since such an N:1 cast could
703 // come from other sources.
704  if (op.getNumResults() == 1 &&
705  computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
706  rewriter.replaceOpWithMultiple(op, {inputs});
707  return success();
708  }
709 
710  return mlir::failure();
711  }
712 };
713 
714 // This pattern distributes arith.constant op into subgroup-level constants
715 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
716  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
717 
718  LogicalResult
719  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
720  ConversionPatternRewriter &rewriter) const override {
721  auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
722  auto vecType = dyn_cast<VectorType>(op.getType());
723  if (!vecAttr || !vecType)
724  return failure();
725 
726  xegpu::DistributeLayoutAttr layout =
727  xegpu::getDistributeLayoutAttr(op.getResult());
728  if (!layout || !layout.isForWorkgroup())
729  return failure();
730 
731  ArrayRef<int64_t> wgShape = vecType.getShape();
732  SmallVector<int64_t> sgShape;
733  int count;
734  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
735 
736  auto newType = VectorType::get(sgShape, vecType.getElementType());
737  Location loc = op.getLoc();
738  auto eltType = vecType.getElementType();
739 
740  auto setLayoutIfNeeded = [&](Value val) {
741  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
742  !layout.getEffectiveInstDataAsInt().empty()) {
743  xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
744  layout.dropSgLayoutAndData());
745  }
746  };
747 
748  if (vecAttr.isSplat()) {
749  // Splat: single value for all subgroups
750  Attribute singleVal = vecAttr.getSplatValue<Attribute>();
751  auto sgAttr = DenseElementsAttr::get(newType, singleVal);
752  auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
753  setLayoutIfNeeded(cstOp->getResult(0));
754  rewriter.replaceOp(op, cstOp);
755  return success();
756  } else if (sgShape == wgShape) { // if the entire vector is shared by all
757  // subgroups, don't distribute
758  auto newConstOp =
759  arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
760  setLayoutIfNeeded(newConstOp->getResult(0));
761  rewriter.replaceOp(op, newConstOp);
762  return success();
763  } else {
764  // Non-splat constant
765  // Only supports 1D & 2D
766  // TODO: support other cases that require SLM access
767  if (!eltType.isIndex())
768  return rewriter.notifyMatchFailure(
769  op, "Unsupported element type for non-splat constant op.");
770 
771  if (wgShape.size() > 2)
772  return rewriter.notifyMatchFailure(
773  op, "Only 1D & 2D vector constant supported");
774 
775  SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
776  int64_t rowStride = 0, colStride = 0;
777  int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
778  int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
779 
780  // Compute colStride and rowStride, and check for constant strides.
781  if (cols > 1) {
782  colStride = cast<IntegerAttr>(values[1]).getInt() -
783  cast<IntegerAttr>(values[0]).getInt();
784  }
785  if (rows > 1) {
786  rowStride = cast<IntegerAttr>(values[cols]).getInt() -
787  cast<IntegerAttr>(values[0]).getInt();
788  }
789 
790  for (int64_t r = 0; r < rows; ++r) {
791  for (int64_t c = 0; c < cols; ++c) {
792  int64_t idx = r * cols + c;
793  // Check column stride
794  if (c > 0 && cols > 1) {
795  int64_t prevIdx = r * cols + (c - 1);
796  int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
797  cast<IntegerAttr>(values[prevIdx]).getInt();
798  if (diff != colStride)
799  return rewriter.notifyMatchFailure(
800  op, "Non-constant column stride in constant op.");
801  }
802  // Check row stride
803  if (r > 0 && rows > 1) {
804  int64_t prevIdx = (r - 1) * cols + c;
805  int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
806  cast<IntegerAttr>(values[prevIdx]).getInt();
807  if (diff != rowStride)
808  return rewriter.notifyMatchFailure(
809  op, "Non-constant row stride in constant op.");
810  }
811  }
812  }
813 
814  // Create a constant for the base tile.
815  // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
816  // For 1D case, extract the first sgShape[0] elements.
817  SmallVector<Attribute> baseTileValues;
818  int baseTileCols = sgShape[sgShape.size() - 1];
819  int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
820  for (int64_t r = 0; r < baseTileRows; ++r) {
821  for (int64_t c = 0; c < baseTileCols; ++c) {
822  baseTileValues.push_back(values[r * cols + c]);
823  }
824  }
825 
826  auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
827  baseTileValues);
828  auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
829 
830  // Get subgroup id
831  Value sgId =
832  gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
833 
834  auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
835  if (failed(sgOffsets))
836  return failure();
837 
838  SmallVector<Value, 2> strideConsts;
839  strideConsts.push_back(
840  arith::ConstantIndexOp::create(rewriter, loc, colStride));
841  if (rows > 1)
842  strideConsts.insert(
843  strideConsts.begin(),
844  arith::ConstantIndexOp::create(rewriter, loc, rowStride));
845 
846  SmallVector<Value> newConstOps;
847  for (auto offsets : *sgOffsets) {
848  // Multiply offset with stride, broadcast it and add to baseConstVec
849  Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
850  for (size_t i = 0; i < strideConsts.size(); ++i) {
851  Value mul =
852  arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
853  offsets[i], strideConsts[i]);
854  mulOffset = arith::AddIOp::create(
855  rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
856  }
857  // Broadcast to baseConstVec size
858  auto bcastOffset = vector::BroadcastOp::create(
859  rewriter, loc, baseConstVec.getType(), mulOffset);
860  auto finalConst =
861  arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
862  setLayoutIfNeeded(baseConstVec);
863  setLayoutIfNeeded(bcastOffset);
864  setLayoutIfNeeded(finalConst);
865  newConstOps.push_back(finalConst);
866  }
867  rewriter.replaceOpWithMultiple(op, {newConstOps});
868  return success();
869  }
870  }
871 };
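// Worked example for the non-splat path (hypothetical constant): for
// arith.constant dense<[0, 16, 32, 48]> : vector<4xindex> with
// sg_layout = [2] and sg_data = [2]: rows = 1, cols = 4, colStride = 16, the
// base tile is dense<[0, 16]> : vector<2xindex>, and subgroup s adds
// broadcast(offset * 16) with offset = 2*s, so subgroup 0 materializes
// [0, 16] and subgroup 1 materializes [32, 48].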
872 
873 // This pattern transforms the LoadGatherOp with explicit offsets to load
874 // subgroup data
875 struct WgToSgLoadGatherOpWithOffset
876  : public OpConversionPattern<xegpu::LoadGatherOp> {
877  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
878  LogicalResult
879  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
880  ConversionPatternRewriter &rewriter) const override {
881 
882  if (!op.getOffsets())
883  return failure();
884 
885  Location loc = op.getLoc();
886  VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
887  if (!resultType)
888  return failure();
889  ArrayRef<int64_t> wgShape = resultType.getShape();
890 
891  xegpu::DistributeLayoutAttr layout =
892  xegpu::getDistributeLayoutAttr(op.getResult());
893  if (!layout || !layout.isForWorkgroup())
894  return failure();
895 
896  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
897 
898  // The offsets need to be distributed
899  auto offsetsVecType =
900  dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
901  auto maskVecType =
902  dyn_cast<VectorType>(adaptor.getMask().front().getType());
903  if (!offsetsVecType || !maskVecType ||
904  offsetsVecType.getShape() != maskVecType.getShape()) {
905  return rewriter.notifyMatchFailure(op,
906  "offsets have not been distributed");
907  }
908 
909  SmallVector<Value> newLoadOps;
910  auto chunkSizeAttr =
911  rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
912  VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
913  for (auto [offsets, mask] :
914  llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
915  auto newLoadOp = xegpu::LoadGatherOp::create(
916  rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
917  op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
918  xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
919  layout.dropSgLayoutAndData());
920  newLoadOps.push_back(newLoadOp);
921  }
922  rewriter.replaceOpWithMultiple(op, {newLoadOps});
923  return success();
924  }
925 };
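// Illustrative shapes (hypothetical): a workgroup-level load_gather producing
// vector<256xf32> with sg_layout = [8] and sg_data = [32] becomes one
// subgroup load_gather per distributed offsets/mask pair, each returning
// vector<32xf32>; the offsets and mask themselves are assumed to have been
// distributed already by other patterns in this pass.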
926 
927 // This pattern transforms the StoreScatterOp with explicit offsets to store
928 // subgroup data
929 struct WgToSgStoreScatterOpWithOffset
930  : public OpConversionPattern<xegpu::StoreScatterOp> {
931  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
932  LogicalResult
933  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
934  ConversionPatternRewriter &rewriter) const override {
935 
936  if (!op.getOffsets())
937  return failure();
938 
939  Location loc = op.getLoc();
940  VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
941  if (!valueType)
942  return failure();
943 
944  xegpu::DistributeLayoutAttr layout =
945  xegpu::getDistributeLayoutAttr(op.getOperand(0));
946  if (!layout || !layout.isForWorkgroup())
947  return failure();
948 
949  // The offsets need to be distributed
950  auto offsetsVecType =
951  dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
952  auto maskVecType =
953  dyn_cast<VectorType>(adaptor.getMask().front().getType());
954  if (!offsetsVecType || !maskVecType ||
955  offsetsVecType.getShape() != maskVecType.getShape()) {
956  return rewriter.notifyMatchFailure(op,
957  "offsets have not been distributed");
958  }
959 
960  auto chunkSizeOpt = op.getChunkSize();
961  int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
962  auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
963  for (auto [val, offs, mask] : llvm::zip(
964  adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
965  auto store = xegpu::StoreScatterOp::create(
966  rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
967  op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
968  // Update the layout attribute to drop sg_layout and sg_data.
969  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
970  !layout.getEffectiveInstDataAsInt().empty()) {
971  for (OpOperand &operand : store->getOpOperands()) {
972  // Skip for operand one (memref)
973  if (operand.getOperandNumber() == 1)
974  continue;
975  xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
976  }
977  }
978  }
979  rewriter.eraseOp(op);
980  return success();
981  }
982 };
983 
984 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
985  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
986  LogicalResult
987  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
988  ConversionPatternRewriter &rewriter) const override {
989 
990  SmallVector<SmallVector<OpFoldResult>> offsetsList;
991  if (failed(genOffsetsList(rewriter, op, offsetsList)))
992  return failure();
993 
994  ArrayRef<int64_t> wgShape = op.getDataShape();
995  VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
996  assert(valueTy && "the value type must be vector type!");
997  Type elemTy = valueTy.getElementType();
998 
999  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1000  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1001  VectorType newResTy = VectorType::get(sgShape, elemTy);
1002  SmallVector<Value> newOps;
1003  for (auto offsets : offsetsList) {
1004  auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
1005  op.getMemDesc(), offsets,
1006  layout.dropSgLayoutAndData());
1007  newOps.push_back(newOp);
1008  }
1009  rewriter.replaceOpWithMultiple(op, {newOps});
1010 
1011  return success();
1012  }
1013 };
1014 
1015 struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
1016  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
1017  LogicalResult
1018  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
1019  ConversionPatternRewriter &rewriter) const override {
1020 
1021  SmallVector<SmallVector<OpFoldResult>> offsetsList;
1022  if (failed(genOffsetsList(rewriter, op, offsetsList)))
1023  return failure();
1024 
1025  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
1026  for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
1027  xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
1028  offsets, layout.dropSgLayoutAndData());
1029  rewriter.eraseOp(op);
1030  return success();
1031  }
1032 };
1033 
1034 // This pattern distributes the vector.step ops to work at subgroup level
1035 struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
1036  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
1037  LogicalResult
1038  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
1039  ConversionPatternRewriter &rewriter) const override {
1040  xegpu::DistributeLayoutAttr layout =
1041  xegpu::getDistributeLayoutAttr(op.getResult());
1042  if (!layout || !layout.isForWorkgroup())
1043  return failure();
1044 
1045  Location loc = op.getLoc();
1046  VectorType type = op.getResult().getType();
1047  auto wgShape = type.getShape();
1048  std::optional<SmallVector<int64_t>> sgShape =
1049  getSgShapeAndCount(wgShape, layout).first;
1050  if (!sgShape)
1051  return failure();
1052 
1053  Value sgId =
1054  gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1055  auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
1056  if (failed(sgOffsets))
1057  return failure();
1058 
1059  VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
1060  auto steps = vector::StepOp::create(rewriter, loc, newTy);
1061  SmallVector<Value> newOps;
1062  for (auto offsets : *sgOffsets) {
1063  // Broadcast the offset scalar to a vector & add to the base steps
1064  auto bcastOffset =
1065  vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
1066  auto finalSteps =
1067  arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1068  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1069  !layout.getEffectiveInstDataAsInt().empty()) {
1070  xegpu::setDistributeLayoutAttr(steps->getResult(0),
1071  layout.dropSgLayoutAndData());
1072  xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
1073  layout.dropSgLayoutAndData());
1074  xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
1075  layout.dropSgLayoutAndData());
1076  }
1077  newOps.push_back(finalSteps);
1078  }
1079 
1080  rewriter.replaceOpWithMultiple(op, {newOps});
1081  return success();
1082  }
1083 };
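// Worked example (hypothetical shapes): vector.step : vector<128xindex> with
// sg_layout = [4] and sg_data = [32] is rewritten to a base
// vector.step : vector<32xindex> plus, for subgroup s, a broadcast of its
// start offset 32*s, yielding [32*s, 32*s + 1, ..., 32*s + 31].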
1084 
1085 // This pattern transforms vector.shape_cast ops to work at subgroup level.
1086 struct WgToSgVectorShapeCastOp
1087  : public OpConversionPattern<vector::ShapeCastOp> {
1088  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
1089 
1090  LogicalResult
1091  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
1092  ConversionPatternRewriter &rewriter) const override {
1093 
1094  VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1095  if (!resultType)
1096  return failure();
1097 
1098  ArrayRef<int64_t> wgShape = resultType.getShape();
1099  xegpu::DistributeLayoutAttr layout =
1100  xegpu::getDistributeLayoutAttr(op.getResult());
1101  if (!layout || !layout.isForWorkgroup())
1102  return failure();
1103 
1104  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1105  VectorType newResultType =
1106  VectorType::get(sgShape, resultType.getElementType());
1107 
1108  // TODO: Add check for compatible layouts in layout attr.
1109  auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
1110  if (!srcType)
1111  return failure();
1112 
1113  // Check that shape_cast only adds/removes unit dimensions,
1114  auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
1115  // Remove all 1s from both shapes and compare the rest.
1116  SmallVector<int64_t> srcNonUnit, dstNonUnit;
1117  for (int64_t d : src)
1118  if (d != 1)
1119  srcNonUnit.push_back(d);
1120  for (int64_t d : dst)
1121  if (d != 1)
1122  dstNonUnit.push_back(d);
1123  return srcNonUnit == dstNonUnit;
1124  };
1125 
1126  if (!onlyUnitDims(srcType.getShape(), sgShape))
1127  return failure();
1128 
1129  // For rank reducing or increasing shape_cast ops, the lower rank layout
1130  // must be a slice of higher rank layout.
1131  int64_t sourceRank = srcType.getRank();
1132  int64_t resultRank = sgShape.size();
1133  xegpu::DistributeLayoutAttr sourceLayout =
1134  xegpu::getDistributeLayoutAttr(op.getSource());
1135  if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
1136  return failure();
1137  if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
1138  return failure();
1139 
1140  SmallVector<Value> newShapeCastOps;
1141  for (auto src : adaptor.getSource()) {
1142  auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
1143  newResultType, src);
1144  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1145  !layout.getEffectiveInstDataAsInt().empty())
1146  xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1147  layout.dropSgLayoutAndData());
1148  newShapeCastOps.push_back(newShapeCast.getResult());
1149  }
1150 
1151  rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1152  return success();
1153  }
1154 };
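// Illustrative case (hypothetical types): a workgroup-level shape_cast from
// vector<1x128xf32> to vector<128xf32> becomes a per-subgroup shape_cast from
// vector<1x32xf32> to vector<32xf32>, accepted only because the cast merely
// drops a unit dimension and the rank-1 layout is a slice of the rank-2 one.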
1155 
1156 /// Pattern for lowering vector.multi_reduction op to subgroup level.
1157 /// Current limitation: the sg_layout in the reduced dimension must be 1,
1158 /// so that the reduction stays local to a subgroup and no cross-subgroup
1159 /// communication is needed.
1160 /// TODO: Add cases to handle more general situations which require SLM access.
1161 struct WgToSgMultiDimReductionOp
1162  : public OpConversionPattern<vector::MultiDimReductionOp> {
1163  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
1164 
1165  LogicalResult
1166  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
1167  ConversionPatternRewriter &rewriter) const override {
1168  VectorType srcType = op.getSourceVectorType();
1169  VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1170  if (!dstType)
1171  return failure();
1172 
1173  auto srcShape = srcType.getShape();
1174  xegpu::DistributeLayoutAttr layout =
1175  xegpu::getDistributeLayoutAttr(op.getResult());
1176  if (!layout || !layout.isForWorkgroup())
1177  return failure();
1178 
1179  auto reductionDims = llvm::to_vector(op.getReductionDims());
1180 
1181  SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
1182  .getParent()
1183  .getEffectiveSgLayoutAsInt();
1184  SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
1185  .getParent()
1186  .getEffectiveSgDataAsInt();
1187 
1188  // Check that the sgLayout in the reduced dimension is 1 and
1189  // each sg gets the entire slice to reduce.
1190  for (int64_t dim : reductionDims) {
1191  if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
1192  return rewriter.notifyMatchFailure(
1193  op,
1194  "sgLayout in each reduced dimension must be 1 and sgData in the "
1195  "reduced dim must match srcShape in that dim");
1196  }
1197 
1198  SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
1199 
1200  VectorType newDstType =
1201  VectorType::get({sgShape}, dstType.getElementType());
1202 
1203  SmallVector<Value> newReductions;
1204  for (auto sgSrc : adaptor.getSource()) {
1205  auto newOp = vector::MultiDimReductionOp::create(
1206  rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
1207  adaptor.getAcc()[0], op.getReductionDims());
1208  if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1209  !layout.getEffectiveInstDataAsInt().empty())
1210  xegpu::setDistributeLayoutAttr(newOp->getResult(0),
1211  layout.dropSgLayoutAndData());
1212  newReductions.push_back(newOp.getResult());
1213  }
1214 
1215  rewriter.replaceOpWithMultiple(op, {newReductions});
1216  return success();
1217  }
1218 };
1219 
1220 } // namespace
1221 
1222 namespace mlir {
1223 namespace xegpu {
1224 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
1225  patterns
1226  .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
1227  WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
1228  WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
1229  WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
1230  WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
1231  WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
1232  WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
1233  WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1234  WgToSgMultiDimReductionOp>(patterns.getContext());
1235 }
1236 } // namespace xegpu
1237 } // namespace mlir
1238 
1239 namespace {
1240 struct XeGPUWgToSgDistributePass
1241  : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
1242  void runOnOperation() override;
1243 };
1244 } // namespace
1245 
1246 void XeGPUWgToSgDistributePass::runOnOperation() {
1247  // Track existing UnrealizedConversionCastOps
1248  SmallVector<Operation *> existingCastOps;
1249  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
1250  existingCastOps.push_back(castOp.getOperation());
1251  });
1252 
1253  {
1254  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
1255  // VectorType operands. This first converts such operands to
1256  // RankedTensorType, propagates the layout attribute into the encoding
1257  // attribute, and finally converts the RankedTensorType to VectorType based
1258  // on the encoding.
1259 
1260  TypeConverter converter;
1261  converter.addConversion([&](Type type) -> Type { return type; });
1262  converter.addConversion(
1263  [&](RankedTensorType type,
1264  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1265  Type elemTy = type.getElementType();
1266  ArrayRef<int64_t> shape = type.getShape();
1267 
1268  int count;
1269  SmallVector<int64_t> subShape;
1270  std::tie(subShape, count) = getSgShapeAndCount(
1271  shape,
1272  dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
1273 
1274  auto newTy = VectorType::get(subShape, elemTy);
1275  result.append(count, newTy);
1276  return success();
1277  });
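 // Illustrative conversion (hypothetical types): a RankedTensorType
 // tensor<256x128xf32> whose encoding is an #xegpu.layout with
 // sg_layout = [8, 4] and sg_data = [16, 32] maps to
 // count = (256*128) / (128*128) = 2 values of type vector<16x32xf32>,
 // which tells the 1:N SCF structural conversion how many values to create.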
1278 
1279  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
1280  converter);
1281  }
1282 
1283  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
1284  // as well as XeGPU, Arith, and Vector operations.
1285  MLIRContext *ctx = &getContext();
1286  RewritePatternSet patterns(ctx);
1287  ConversionTarget target(*ctx);
1288  TypeConverter converter;
1289  converter.addConversion([&](Type type) -> Type { return type; });
1290  converter.addConversion(
1291  [&](xegpu::TensorDescType type,
1292  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
1293  Type elemTy = type.getElementType();
1294  ArrayRef<int64_t> shape = type.getShape();
1295 
1296  int count;
1297  SmallVector<int64_t> subShape;
1298  xegpu::LayoutAttr layout = type.getLayoutAttr();
1299  std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1300 
1301  if (layout)
1302  layout = layout.dropSgLayoutAndData();
1303 
1304  auto newTy = xegpu::TensorDescType::get(
1305  type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1306  result.append(count, newTy);
1307  return success();
1308  });
1309 
1310  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1311  if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1312  return createOp.getType();
1313  if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1314  return loadOp.getTensorDescType();
1315  if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1316  return storeOp.getTensorDescType();
1317  if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1318  return updateOp.getType();
1319  if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1320  return prefetchOp.getTensorDescType();
1321  return xegpu::TensorDescType();
1322  };
1323 
1324  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1325  return !layout || !layout.isForWorkgroup();
1326  };
1327 
1328  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1329  xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1330  xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1331  auto tdescTy = getTensorDescType(op);
1332  auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1333  return isLegal(layout);
1334  });
1335 
1336  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1337  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1338  return isLegal(layout);
1339  });
1340 
1341  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1342  [=](xegpu::LoadMatrixOp op) -> bool {
1343  return isLegal(op.getLayoutAttr());
1344  });
1345 
1346  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1347  [=](xegpu::StoreMatrixOp op) -> bool {
1348  return isLegal(op.getLayoutAttr());
1349  });
1350 
1351  target.addDynamicallyLegalOp<arith::ConstantOp>(
1352  [=](arith::ConstantOp op) -> bool {
1353  auto vecType = dyn_cast<VectorType>(op.getType());
1354  if (!vecType)
1355  return true;
1356 
1357  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1358  return isLegal(layout);
1359  });
1360 
1361  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
1362  [=](Operation *op) -> bool {
1363  // Check for either a SliceAttr or LayoutAttr on the result.
1364  auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1365  return isLegal(layout);
1366  });
1367 
1368  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1369  [=](xegpu::LoadGatherOp op) -> bool {
1370  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1371  return isLegal(layout);
1372  });
1373 
1374  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1375  [=](xegpu::StoreScatterOp op) -> bool {
1376  auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
1377  return isLegal(layout);
1378  });
1379 
1380  target.addDynamicallyLegalOp<vector::BroadcastOp>(
1381  [=](vector::BroadcastOp op) -> bool {
1382  return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1383  });
1384 
1385  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
1386  [=](vector::MultiDimReductionOp op) -> bool {
1387  return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1388  });
1389 
1390  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1391  [=](xegpu::ConvertLayoutOp op) -> bool {
1392  return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1393  });
1394 
1395  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1396  [=](Operation *op) -> std::optional<bool> {
1397  // Only handle elementwise mappable ops
1398  if (!OpTrait::hasElementwiseMappableTraits(op))
1399  return true;
1400 
1401  VectorType resultType =
1402  dyn_cast<VectorType>(op->getResult(0).getType());
1403  if (!resultType)
1404  return true;
1405 
1406  // Check if all operands are vectors of the same shape
1407  // TODO: Support other types.
1408  for (Value operand : op->getOperands()) {
1409  VectorType operandType = dyn_cast<VectorType>(operand.getType());
1410  if (!operandType || operandType.getShape() != resultType.getShape()) {
1411  return true;
1412  }
1413  }
1414 
1415  xegpu::DistributeLayoutAttr layout =
1416  xegpu::getDistributeLayoutAttr(op->getResult(0));
1417  return isLegal(layout);
1418  });
1419 
1420  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1421  [=](UnrealizedConversionCastOp op) {
1422  return llvm::is_contained(existingCastOps, op.getOperation());
1423  });
1424 
1425  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1426 
1427  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1428  target);
1429  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1430  if (failed(
1431  applyPartialConversion(getOperation(), target, std::move(patterns))))
1432  return signalPassFailure();
1433 
1434  // Remove the sg_layout and sg_data fields from the layout attribute of
1435  // each VectorType result of the operation.
1436  // For structured control flow ops, the layout is removed entirely, since
1437  // in the 1:N case the layouts for the new results are missing;
1438  // the layout propagation pass will be activated to recover them.
1439  getOperation()->walk([](Operation *op) {
1440  for (OpResult result : op->getOpResults()) {
1441  std::string name = xegpu::getLayoutName(result);
1442  if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
1443  op->removeAttr(name);
1444  if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1445  if (auto newLayout = layout.dropSgLayoutAndData())
1446  op->setAttr(name, newLayout);
1447  }
1448  }
1449  }
1450  });
1451 }