XeGPUWgToSgDistribute.cpp
1 //===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
24 #include <optional>
25 
26 namespace mlir {
27 namespace xegpu {
28 #define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30 } // namespace xegpu
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 namespace {
36 
37 // Retrieves the sg_id_range attribute (RangeAttr) from the nearest enclosing
38 // scf.if op, if specified.
38 static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39  Operation *parent = op->getParentOfType<scf::IfOp>();
40  while (parent) {
41  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42  parent->getAttr("sg_id_range")))
43  return attr;
44  parent = parent->getParentOfType<scf::IfOp>();
45  }
46  return {};
47 }
48 
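// A worked example of the helper below, using the 24x24 case documented with
// WgToSgCreateNdOp further down: for shape = [24, 24], sg_layout = [4, 4] and
// sg_data = [2, 2], the distribution unit is [4*2, 4*2] = [8, 8], so
// count = (24*24) / (8*8) = 9, i.e. each subgroup handles 9 round-robin copies
// of its 2x2 sub-shape.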
49 static std::pair<SmallVector<int64_t>, int>
50 getSgShapeAndCount(ArrayRef<int64_t> shape,
51  xegpu::DistributeLayoutAttr layout) {
52  int count = 1;
53  SmallVector<int64_t> sgShape(shape);
54  if (layout && layout.isForWorkgroup()) {
55  SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
56  if (!layout.getSgDataAsInt().empty())
57  sgShape = layout.getSgDataAsInt();
58  else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
59  sgShape = *maybeDerivedSgData;
60  SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
61  // Clamp distUnit to the original shape to handle cases where data is
62  // shared among subgroups, which may cause distUnit to exceed the original
63  // shape.
64  for (size_t i = 0; i < distUnit.size(); ++i)
65  distUnit[i] = std::min(shape[i], distUnit[i]);
66  count = computeProduct(shape) / computeProduct(distUnit);
67  }
68  return std::make_pair(sgShape, count);
69 }
70 
71 /// Utility helper for deriving a list of offsets for each sub-TensorDesc
72 /// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
73 /// the associated distribute layout attribute, the shape, the subgroup id,
74 /// and the original offsets of the op.
75 template <
76  typename OpType,
77  typename = std::enable_if_t<llvm::is_one_of<
78  OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
79  xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
80 static LogicalResult
81 genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
82  SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
83  Location loc = op.getLoc();
84  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
85  // Not applicable to ops without offset operands.
86  if (origOffsets.empty())
87  return failure();
88 
89  // Not applicable to ops without a workgroup-level layout attribute.
90  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
91  if (!layout || !layout.isForWorkgroup())
92  return failure();
93 
94  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
95 
96  // verify and adjust the sgId if the range specifier is present
97  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
98  if (sgIdRange) {
99  int64_t startOfRange = sgIdRange.getStart().getInt();
100  int64_t endOfRange = sgIdRange.getEnd().getInt();
101  // verify the RangeAttr against the layout attribute
102  if (layout.getNumSubgroups() != endOfRange - startOfRange)
103  return rewriter.notifyMatchFailure(
104  op, "sg_layout size must match the sg_id_range");
105  // adjust the sgId if necessary
106  if (startOfRange > 0) {
107  Value startOfRangeVal =
108  rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
109  sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
110  }
111  }
112 
113  // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
114  // descriptors to be accessed, based on the layout information.
115  ArrayRef<int64_t> wgShape = op.getDataShape();
116  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
117  if (failed(maybeDescOffsets))
118  return failure();
119 
120  // Compute the final global offsets for each accessed sub-tensor
121  // or sub-memory descriptor.
122  for (const auto &sgOffsets : *maybeDescOffsets) {
123  SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
124  rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
125  offsetsList.push_back(std::move(newOffsets));
126  }
127 
128  // callback(offsetsList);
129  return success();
130 }
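// Illustrative sketch of the composition above (offset values assumed, not
// taken from the original source): with original offsets [0, 0] and the 24x24
// layout documented below, layout.getOffsets() yields, for each subgroup, a
// list of subgroup-relative offsets such as [2, 2], [2, 10], ..., and
// addWithRightAligned adds each of them to the op's original offsets to
// produce the final global offsets stored in offsetsList.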
131 
132 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
133 /// from a workgroup descriptor. It replaces the offsets and sizes with
134 /// appropriate values for the subgroup.
135 /// It uses round-robin assignment to distribute the work to the subgroups.
136 /// The following create_nd_tdesc operation:
137 /// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
138 /// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
139 /// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
140 /// is converted to 9 subgroup level operations based on the sg_layout &
141 /// sg_data:
142 /// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
143 /// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
144 /// lane_data = [1, 1]>>
145 ///
146 /// The sg_layout and sg_data attributes are dropped after the pass as they are
147 /// no longer needed.
148 ///
149 /// 24x24 matrix distribution example:
150 /// sg_layout = [4, 4], sg_data = [2, 2]
151 /// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
152 /// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
153 ///
154 /// +------------------------+
155 /// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
156 /// |-----+-----+-----|
157 /// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
158 /// |-----+-----+-----|
159 /// | 8x8 | 8x8 | 8x8 |
160 /// +------------------------+
161 ///
162 /// Each 8x8 tile is further subdivided among subgroups:
163 /// +------------------------+
164 /// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
165 /// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
166 /// | 2x2 2x2 2x2 2x2 |
167 /// | 2x2 2x2 2x2 2x2 |
168 /// +------------------------+
169 ///
170 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
171 /// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
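/// Illustrative continuation (assuming row-major subgroup ordering): subgroup
/// id 5 maps to coordinate (1, 1) in sg_layout = [4, 4], so its base offset is
/// (1*2, 1*2) = (2, 2); with dist_unit_shape = [8, 8] its 9 round-robin copies
/// cover offsets (2, 2), (2, 10), (2, 18), (10, 2), ..., (18, 18).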
172 
173 /// The pass currently keeps the entire distribution logic in the
174 /// WgToSgCreateNdOp pattern, and all the other patterns simply follow it.
175 /// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
176 /// ops in the pass.
177 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
178  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
179 
180  LogicalResult
181  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
182  ConversionPatternRewriter &rewriter) const override {
183  SmallVector<SmallVector<OpFoldResult>> offsetsList;
184  if (failed(genOffsetsList(rewriter, op, offsetsList)))
185  return failure();
186 
187  MLIRContext *ctx = op.getContext();
188  xegpu::TensorDescType tdescTy = op.getType();
189  ArrayRef<int64_t> wgShape = tdescTy.getShape();
190  Type elemTy = tdescTy.getElementType();
191  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
192  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
193  auto newTdescTy =
194  xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
195  layout.dropSgLayoutAndData());
196 
197  SmallVector<Value> newOps;
198  for (auto offsets : offsetsList) {
199  auto newOp = xegpu::CreateNdDescOp::create(
200  rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
201  op.getMixedSizes(), op.getMixedStrides());
202 
203  newOps.push_back(newOp);
204  }
205  rewriter.replaceOpWithMultiple(op, {newOps});
206 
207  return success();
208  }
209 };
210 
211 // This pattern transforms the CreateNdDescOp without offsets to create a
212 // subgroup descriptor from a workgroup descriptor
213 struct WgToSgCreateNdOpNoOffset
214  : public OpConversionPattern<xegpu::CreateNdDescOp> {
215  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
216 
217  LogicalResult
218  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
219  ConversionPatternRewriter &rewriter) const override {
220 
221  // Check no offsets are specified.
222  if (!op.getMixedOffsets().empty())
223  return failure();
224 
225  Location loc = op.getLoc();
226  MLIRContext *ctx = op.getContext();
227  xegpu::TensorDescType tdescTy = op.getType();
228  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
229  if (!layout || !layout.isForWorkgroup())
230  return failure();
231 
232  Type elemTy = tdescTy.getElementType();
233  ArrayRef<int64_t> wgShape = tdescTy.getShape();
234 
235  SmallVector<int64_t> sgShape;
236  int count;
237  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
238  xegpu::TensorDescType newTdescTy =
239  xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
240  layout.dropSgLayoutAndData());
241 
242  SmallVector<Value> newCreateNdOps(count);
243  std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
244  return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
245  op.getSource(), op.getMixedSizes(),
246  op.getMixedStrides());
247  });
248 
249  rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
250  return success();
251  }
252 };
253 
254 /// This pattern transforms the LoadNdOp to load subgroup data.
255 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
256  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
257  LogicalResult
258  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
259  ConversionPatternRewriter &rewriter) const override {
260  if (!op.getMixedOffsets().empty())
261  return failure();
262 
263  SmallVector<Value> newLoadOps;
264  for (auto src : adaptor.getTensorDesc()) {
265  xegpu::TensorDescType tdescTy =
266  dyn_cast<xegpu::TensorDescType>(src.getType());
267  ArrayRef<int64_t> srcShape = tdescTy.getShape();
268  VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
269  auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
270  src, op->getAttrs());
271  newLoadOps.push_back(newLoadOp);
272  }
273  rewriter.replaceOpWithMultiple(op, {newLoadOps});
274  return mlir::success();
275  }
276 };
277 
278 /// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
279 /// It creates a StoreNdOp to store the updated values to the new subgroup
280 /// src tensor descriptors.
281 struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
282  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
283  LogicalResult
284  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
285  ConversionPatternRewriter &rewriter) const override {
286  if (!op.getMixedOffsets().empty())
287  return failure();
288 
289  for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
290  xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
291  op.getL2HintAttr(), op.getL3HintAttr());
292 
293  rewriter.eraseOp(op);
294  return success();
295  }
296 };
297 
298 // This pattern transforms the LoadNdOp with explicit offsets to load
299 // subgroup data.
300 struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
301  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
302  LogicalResult
303  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
304  ConversionPatternRewriter &rewriter) const override {
305 
306  SmallVector<SmallVector<OpFoldResult>> offsetsList;
307  if (failed(genOffsetsList(rewriter, op, offsetsList)))
308  return failure();
309 
310  SmallVector<Value> newOps;
311  for (auto [tdesc, offsets] :
312  llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
313  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
314  VectorType newResTy =
315  VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
316  auto newOp = xegpu::LoadNdOp::create(
317  rewriter, op.getLoc(), newResTy, tdesc, offsets,
318  /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
319  op.getL2HintAttr(), op.getL3HintAttr());
320  newOps.push_back(newOp);
321  }
322  rewriter.replaceOpWithMultiple(op, {newOps});
323 
324  return success();
325  }
326 };
327 
328 // This pattern transforms the StoreNdOp with explicit offsets to store
329 // subgroup data.
330 struct WgToSgStoreNdOpWithOffset
331  : public OpConversionPattern<xegpu::StoreNdOp> {
332  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
333  LogicalResult
334  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  SmallVector<SmallVector<OpFoldResult>> offsetsList;
337  if (failed(genOffsetsList(rewriter, op, offsetsList)))
338  return failure();
339 
340  for (auto [v, tdesc, offsets] :
341  llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
342  rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
343  op.getL1HintAttr(), op.getL2HintAttr(),
344  op.getL3HintAttr());
345  }
346  rewriter.eraseOp(op);
347 
348  return success();
349  }
350 };
351 
352 // This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
353 // subgroup data.
354 struct WgToSgPrefetchNdOpWithOffset
355  : public OpConversionPattern<xegpu::PrefetchNdOp> {
356  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
357  LogicalResult
358  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
359  ConversionPatternRewriter &rewriter) const override {
360  SmallVector<SmallVector<OpFoldResult>> offsetsList;
361  if (failed(genOffsetsList(rewriter, op, offsetsList)))
362  return failure();
363 
364  for (auto [tdesc, offsets] :
365  llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
366  rewriter.create<xegpu::PrefetchNdOp>(
367  op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
368  op.getL3HintAttr());
369  }
370  rewriter.eraseOp(op);
371 
372  return success();
373  }
374 };
375 
376 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
377 /// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
378 /// offsets of the new subgroup src tensor descriptors.
379 struct WgToSgUpdateNdOffsetOp
380  : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
381  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
382  LogicalResult
383  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
384  ConversionPatternRewriter &rewriter) const override {
385  llvm::SmallVector<Value> newUpdateTileOffsetOps;
386  for (auto tDesc : adaptor.getTensorDesc()) {
387  auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
388  rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
389  op.getConstOffsets());
390  newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
391  }
392 
393  rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
394  return success();
395  }
396 };
397 
398 /// This pattern transforms the DpasOp to work at subgroup level.
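/// Illustrative note (operand shapes assumed, not from the original source):
/// if the A operand is distributed into vector<32x128xf16> fragments and the
/// B operand into vector<128x32xf16> fragments, each (A, B) pair below yields
/// a subgroup-level dpas producing vector<32x32xf32>, with the optional
/// accumulator fragments consumed in the same order.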
399 struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
400  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
401  LogicalResult
402  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
403  ConversionPatternRewriter &rewriter) const override {
404  Location loc = op.getLoc();
405  VectorType resultTy = op.getResult().getType();
406  if (resultTy.getRank() != 2)
407  return failure();
408 
409  auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
410  if (!originalLayout)
411  return failure();
412 
413  size_t i = 0;
414  SmallVector<Value> newDpasOps;
415  for (auto aVec : adaptor.getLhs()) {
416  for (auto bVec : adaptor.getRhs()) {
417 
418  llvm::SmallVector<Value> operands({aVec, bVec});
419  Value tmpC;
420  if (op.getAcc()) {
421  tmpC = adaptor.getAcc()[i++];
422  operands.push_back(tmpC);
423  }
424 
425  ArrayRef<int64_t> aVecShape =
426  llvm::cast<VectorType>(aVec.getType()).getShape();
427  ArrayRef<int64_t> bVecShape =
428  llvm::cast<VectorType>(bVec.getType()).getShape();
429  VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
430  resultTy.getElementType());
431  tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
432  xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
433  originalLayout.dropSgLayoutAndData());
434 
435  newDpasOps.push_back(tmpC);
436  }
437  }
438  rewriter.replaceOpWithMultiple(op, {newDpasOps});
439  return success();
440  }
441 };
442 
443 /// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
444 struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
445  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
446  LogicalResult
447  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
448  ConversionPatternRewriter &rewriter) const override {
449 
450  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
451  if ((offsetSize != 0) || op.getConstOffsetsAttr())
452  return failure();
453 
454  for (auto src : adaptor.getTensorDesc())
455  xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
456  op->getAttrs());
457  rewriter.eraseOp(op);
458  return success();
459  }
460 };
461 
462 /// This pattern transforms vector.broadcast ops to work at subgroup level.
463 struct WgToSgVectorBroadcastOp
464  : public OpConversionPattern<vector::BroadcastOp> {
465  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
466 
467  LogicalResult
468  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
469  ConversionPatternRewriter &rewriter) const override {
470  VectorType resultType = op.getResult().getType();
471  ArrayRef<int64_t> wgShape = resultType.getShape();
472 
473  xegpu::DistributeLayoutAttr layout =
474  xegpu::getDistributeLayoutAttr(op.getResult());
475  if (!layout || !layout.isForWorkgroup())
476  return failure();
477 
478  // TODO: Currently only supports cases where the source and result ranks
479  // are the same.
480  auto srcType =
481  dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
482  if (!srcType || srcType.getRank() != resultType.getRank())
483  return failure();
484 
485  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
486  VectorType newResultType =
487  VectorType::get(sgShape, resultType.getElementType());
488 
489  // Check if the output layout is distributable
490  SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
491  if (sgLayout.empty())
492  return failure();
493 
494  if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
495  return failure();
496 
497  // Check that the srcShape has a unit dim in the dimensions being
498  // broadcast, and that the other dimensions match the destination type.
499  // TODO: Generalize it.
500  auto srcShape = srcType.getShape();
501  for (size_t i = 0; i < srcShape.size(); ++i) {
502  if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
503  return failure();
504  }
505 
506  SmallVector<Value> newBroadcastOps;
507  for (auto operand : adaptor.getOperands().front()) {
508  auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
509  newResultType, operand);
510  xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
511  layout.dropSgLayoutAndData());
512  newBroadcastOps.push_back(newBroadcast.getResult());
513  }
514 
515  rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
516  return success();
517  }
518 };
519 
520 // This pattern transforms elementwise ops to work at subgroup level.
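// Illustrative example (types assumed, not from the original source): an
// arith.addf on vector<256x128xf32> with sg_layout = [8, 4] and
// sg_data = [32, 32] is rewritten into per-fragment arith.addf ops on
// vector<32x32xf32>, and the "layout_result_0" attribute is rewritten with
// sg_layout/sg_data dropped.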
521 struct WgToSgElementwiseOp : public ConversionPattern {
522  WgToSgElementwiseOp(MLIRContext *ctx)
523  : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
524 
525  LogicalResult
526  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
527  ConversionPatternRewriter &rewriter) const override {
528  // Only match ops with elementwise trait and single result.
529  if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
530  return failure();
531 
532  auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
533  assert(resultType && "Expected result to be a VectorType");
534 
535  ArrayRef<int64_t> wgShape = resultType.getShape();
536 
537  xegpu::DistributeLayoutAttr layout =
538  xegpu::getDistributeLayoutAttr(op->getResult(0));
539  if (!layout || !layout.isForWorkgroup())
540  return failure();
541 
542  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
543 
544  size_t numVariants = operands.empty() ? 0 : operands.front().size();
545 
546  if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
547  return operandVec.size() != numVariants;
548  }))
549  return failure();
550 
551  SmallVector<Value> newResults;
552  VectorType newResultType =
553  VectorType::get(sgShape, resultType.getElementType());
554 
555  for (size_t i = 0; i < numVariants; ++i) {
556  SmallVector<Value> opOperands;
557  for (auto &operandVec : operands)
558  opOperands.push_back(operandVec[i]);
559 
560  OperationState state(op->getLoc(), op->getName());
561  state.addOperands(opOperands);
562  state.addTypes(newResultType);
563  // Copy all attributes, but update "layout_result_0" to drop
564  // sgLayout/sgData
565  for (auto attr : op->getAttrs()) {
566  if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
567  if (auto newLayout = layout.dropSgLayoutAndData())
568  state.addAttribute(attr.getName(), newLayout);
569  } else {
570  state.addAttribute(attr.getName(), attr.getValue());
571  }
572  }
573  Operation *newOp = rewriter.create(state);
574  newResults.push_back(newOp->getResult(0));
575  }
576 
577  rewriter.replaceOpWithMultiple(op, {newResults});
578  return success();
579  }
580 };
581 
582 // clang-format off
583 // Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
584 // If input_layout and target_layout have identical sg_layout and sg_data,
585 // the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
586 // dropped. For example:
587 // #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
588 // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
589 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
590 // becomes:
591 // #a = #xegpu.layout<inst_data = [16, 16]>
592 // #b = #xegpu.layout<inst_data = [8, 16]>
593 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
594 // (vector<16x16xf32> is determined by sg_data = [16, 16])
595 //
596 // If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
597 // For example:
598 // #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
599 // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
600 // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
601 // is lowered to:
602 // #a = #xegpu.layout<inst_data = [16, 16]>
603 // #b = #xegpu.layout<inst_data = [8, 16]>
604 // store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
605 // %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
606 // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
607 // clang-format on
608 struct WgToSgConvertLayoutOp
609  : public OpConversionPattern<xegpu::ConvertLayoutOp> {
610  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
611  LogicalResult
612  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
613  ConversionPatternRewriter &rewriter) const override {
614  // TODO: currently, we only support LayoutAttr
615  auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
616  auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
617 
618  if (!input || !target || !input.isForWorkgroup() ||
619  !target.isForWorkgroup())
620  return rewriter.notifyMatchFailure(
621  op, "Input and target layouts must have subgroup layout");
622 
623  DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
624  DenseI32ArrayAttr inputSgData = input.getSgData();
625  DenseI32ArrayAttr inputOrder = input.getOrder();
626  DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
627  DenseI32ArrayAttr targetSgData = target.getSgData();
628  DenseI32ArrayAttr targetOrder = target.getOrder();
629 
630  // TODO: Currently we only support the optimal case, where the input and
631  // output have the same sg_layout and sg_data, so SLM is not involved.
632  if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
633  inputOrder != targetOrder)
634  return failure();
635 
636  input = input.dropSgLayoutAndData();
637  target = target.dropSgLayoutAndData();
638 
639  SmallVector<Value> newOps(adaptor.getSource());
640  if (input && target) {
641  // Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
642  for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
643  auto newOp = xegpu::ConvertLayoutOp::create(
644  rewriter, op.getLoc(), src.getType(), src, input, target);
645  newOps[i] = newOp;
646  }
647  }
648  rewriter.replaceOpWithMultiple(op, {newOps});
649  return success();
650  }
651 };
652 
653 // Handles UnrealizedConversionCastOp generated during
654 // SCFStructuralTypeConversions (step 1). This op may appear as either a
655 // target or source materialization for Vector values, e.g.:
656 // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
657 // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
658 // It could be either a 1:N or an N:1 cast. In both cases, the pattern
659 // simply forwards the inputs to the outputs via the 1:1 or 1:N replacement
660 // interface. For example, the following scf::forOp
661 // ```
662 // %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
663 // %n = use(%arg1): vector<128x128xf16>
664 // scf.yield %n : vector<128x128xf16>
665 // }
666 // ```
667 // Could be converted to:
668 // ```
669 // %1 = unrealized_conversion_cast %0
670 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
671 // %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
672 // -> (vector<16x16xf16>, vector<16x16xf16>) {
673 // %m = unrealized_conversion_cast %arg1, %arg2
674 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
675 // %n = use(%m): vector<128x128xf16>
676 // %b = unrealized_conversion_cast %n
677 // : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
678 // scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
679 // }
680 // %cast = unrealized_conversion_cast %for:2
681 // : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
682 // ```
683 // TODO: remove it when context-aware type converter is ready.
684 struct UnrealizedConversionCastOpPattern
685  : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
686  using OpConversionPattern<
687  mlir::UnrealizedConversionCastOp>::OpConversionPattern;
688 
689  mlir::LogicalResult
690  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
691  ConversionPatternRewriter &rewriter) const override {
692  SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
693 
694  auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
695  auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
696 
697  if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
698  !llvm::all_equal(ValueRange(inputs).getTypes()))
699  return failure();
700 
701  // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
702  // It is generated by source materialization (e.g., inits to scf forOp).
703  // The input values provided by the adaptor should already be distributed,
704  // and their types should correspond exactly to the result types of the
705  // operation.
706  if (op.getNumOperands() == 1 &&
707  llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
708  rewriter.replaceOp(op, inputs);
709  return success();
710  }
711 
712  // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
713  // It is generated by target materialization (e.g., arguments/results
714  // of scf forOp). All input values must have the same vector type, and
715 // the output vector's shape must be evenly divisible by the input shape
716 // (a consequence of the workgroup-to-subgroup distribution).
717 // TODO: it is not safe to blindly forward the inputs, since such an N:1
718 // cast could also originate elsewhere.
719  if (op.getNumResults() == 1 &&
720  computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
721  rewriter.replaceOpWithMultiple(op, {inputs});
722  return success();
723  }
724 
725  return mlir::failure();
726  }
727 };
728 
729 // This pattern distributes arith.constant op into subgroup-level constants
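// Illustrative example (types assumed, not from the original source): a splat
// arith.constant dense<1.0> : vector<256x128xf32> with sg_layout = [8, 4] and
// sg_data = [16, 16] becomes a single dense<1.0> : vector<16x16xf32> constant,
// replicated count = (256*128)/(128*64) = 4 times to cover the round-robin
// distribution units.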
730 struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
731  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
732 
733  LogicalResult
734  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
735  ConversionPatternRewriter &rewriter) const override {
736  auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
737  auto vecType = dyn_cast<VectorType>(op.getType());
738  if (!vecAttr || !vecAttr.isSplat() || !vecType)
739  return failure();
740 
741  xegpu::DistributeLayoutAttr layout =
742  xegpu::getDistributeLayoutAttr(op.getResult());
743  if (!layout || !layout.isForWorkgroup())
744  return failure();
745 
746  ArrayRef<int64_t> wgShape = vecType.getShape();
747  SmallVector<int64_t> sgShape;
748  int count;
749  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
750 
751  // Current limitation: only splat vector constants (a single repeated
752  // value) are supported. TODO: support vectors with multiple values.
753  Attribute singleVal = vecAttr.getSplatValue<Attribute>();
754 
755  auto newType = VectorType::get(sgShape, vecType.getElementType());
756  auto sgAttr = DenseElementsAttr::get(newType, singleVal);
757  auto cstOp =
758  arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
759  if (auto newLayout = layout.dropSgLayoutAndData())
760  xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
761  SmallVector<Value> newConsts(count, cstOp);
762 
763  rewriter.replaceOpWithMultiple(op, {newConsts});
764  return success();
765  }
766 };
767 
768 // This pattern transforms the LoadGatherOp with explicit offsets to load
769 // subgroup data
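// Illustrative example (types assumed, not from the original source): a
// load_gather producing vector<256xf32> with sg_layout = [8] and
// sg_data = [32] becomes per-subgroup loads of vector<32xf32> that reuse the
// already-distributed offset and mask vectors (which must share the same
// shape) and the original chunk_size (defaulting to 1).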
770 struct WgToSgLoadGatherOpWithOffset
771  : public OpConversionPattern<xegpu::LoadGatherOp> {
772  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
773  LogicalResult
774  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
775  ConversionPatternRewriter &rewriter) const override {
776 
777  if (!op.getOffsets())
778  return failure();
779 
780  Location loc = op.getLoc();
781  VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
782  if (!resultType)
783  return failure();
784  ArrayRef<int64_t> wgShape = resultType.getShape();
785 
786  xegpu::DistributeLayoutAttr layout =
787  xegpu::getDistributeLayoutAttr(op.getResult());
788  if (!layout || !layout.isForWorkgroup())
789  return failure();
790 
791  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
792 
793  // The offsets need to be distributed
794  auto offsetsVecType =
795  dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
796  auto maskVecType =
797  dyn_cast<VectorType>(adaptor.getMask().front().getType());
798  if (!offsetsVecType || !maskVecType ||
799  offsetsVecType.getShape() != maskVecType.getShape()) {
800  return rewriter.notifyMatchFailure(op,
801  "offsets have not been distributed");
802  }
803 
804  SmallVector<Value> newLoadOps;
805  auto chunkSizeAttr =
806  rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
807  VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
808  for (auto [offsets, mask] :
809  llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
810  auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
811  loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
812  op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
813  xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
814  layout.dropSgLayoutAndData());
815  newLoadOps.push_back(newLoadOp);
816  }
817  rewriter.replaceOpWithMultiple(op, {newLoadOps});
818  return success();
819  }
820 };
821 
822 // This pattern transforms the StoreScatterOp with explicit offsets to store
823 // subgroup data
824 struct WgToSgStoreScatterOpWithOffset
825  : public OpConversionPattern<xegpu::StoreScatterOp> {
826  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
827  LogicalResult
828  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
829  ConversionPatternRewriter &rewriter) const override {
830 
831  if (!op.getOffsets())
832  return failure();
833 
834  Location loc = op.getLoc();
835  VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
836  if (!valueType)
837  return failure();
838 
839  xegpu::DistributeLayoutAttr layout =
840  xegpu::getDistributeLayoutAttr(op.getValue());
841  if (!layout || !layout.isForWorkgroup())
842  return failure();
843 
844  // The offsets need to be distributed
845  auto offsetsVecType =
846  dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
847  auto maskVecType =
848  dyn_cast<VectorType>(adaptor.getMask().front().getType());
849  if (!offsetsVecType || !maskVecType ||
850  offsetsVecType.getShape() != maskVecType.getShape()) {
851  return rewriter.notifyMatchFailure(op,
852  "offsets have not been distributed");
853  }
854 
855  auto chunkSizeOpt = op.getChunkSize();
856  int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
857  auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
858  for (auto [val, offs, mask] : llvm::zip(
859  adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
860  rewriter.create<xegpu::StoreScatterOp>(
861  loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
862  op.getL2HintAttr(), op.getL3HintAttr());
863  // Update the layout attribute to drop sg_layout and sg_data.
864  if (auto newLayout = layout.dropSgLayoutAndData())
865  op->setAttr("layout", newLayout);
866  }
867  rewriter.eraseOp(op);
868  return success();
869  }
870 };
871 
872 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
873  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
874  LogicalResult
875  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
876  ConversionPatternRewriter &rewriter) const override {
877 
878  SmallVector<SmallVector<OpFoldResult>> offsetsList;
879  if (failed(genOffsetsList(rewriter, op, offsetsList)))
880  return failure();
881 
882  ArrayRef<int64_t> wgShape = op.getDataShape();
883  VectorType valueTy = op.getRes().getType();
884  Type elemTy = valueTy.getElementType();
885 
886  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
887  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
888  VectorType newResTy = VectorType::get(sgShape, elemTy);
889  SmallVector<Value> newOps;
890  for (auto offsets : offsetsList) {
891  auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
892  op.getLoc(), newResTy, op.getMemDesc(), offsets,
893  layout.dropSgLayoutAndData());
894  newOps.push_back(newOp);
895  }
896  rewriter.replaceOpWithMultiple(op, {newOps});
897 
898  return success();
899  }
900 };
901 
902 struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
903  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
904  LogicalResult
905  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
906  ConversionPatternRewriter &rewriter) const override {
907 
908  SmallVector<SmallVector<OpFoldResult>> offsetsList;
909  if (failed(genOffsetsList(rewriter, op, offsetsList)))
910  return failure();
911 
912  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
913  for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
914  rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
915  offsets,
916  layout.dropSgLayoutAndData());
917  rewriter.eraseOp(op);
918  return success();
919  }
920 };
921 
922 } // namespace
923 
924 namespace mlir {
925 namespace xegpu {
926 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
927  patterns
928  .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
929  WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
930  WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
931  WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
932  WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
933  WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
934  WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
935  WgToSgStoreMatrixOp>(patterns.getContext());
936 }
937 } // namespace xegpu
938 } // namespace mlir
939 
940 namespace {
941 struct XeGPUWgToSgDistributePass
942  : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
943  void runOnOperation() override;
944 };
945 } // namespace
946 
947 void XeGPUWgToSgDistributePass::runOnOperation() {
948  // Track existing UnrealizedConversionCastOps
949  SmallVector<Operation *> existingCastOps;
950  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
951  existingCastOps.push_back(castOp.getOperation());
952  });
953 
954  {
955  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
956  // VectorType operands. This first converts such operands to
957  // RankedTensorType, propagates the layout attribute into the encoding
958  // attribute, and finally converts the RankedTensorType to VectorType based
959  // on the encoding.
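 // Illustrative example (types assumed, not from the original source): a
 // tensor<256x128xf32> whose encoding carries sg_layout = [8, 4] and
 // sg_data = [16, 16] is converted to count = 4 values of vector<16x16xf32>,
 // while a type without a workgroup layout encoding converts 1:1.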
960 
961  TypeConverter converter;
962  converter.addConversion([&](Type type) -> Type { return type; });
963  converter.addConversion(
964  [&](RankedTensorType type,
965  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
966  Type elemTy = type.getElementType();
967  ArrayRef<int64_t> shape = type.getShape();
968 
969  int count;
970  SmallVector<int64_t> subShape;
971  std::tie(subShape, count) = getSgShapeAndCount(
972  shape,
973  dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
974 
975  auto newTy = VectorType::get(subShape, elemTy);
976  result.append(count, newTy);
977  return success();
978  });
979 
980  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
981  converter);
982  }
983 
984  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
985  // as well as XeGPU, Arith, and Vector operations.
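 // Illustrative example (types assumed, not from the original source):
 // !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4],
 // sg_data = [16, 16], inst_data = [8, 16]>> converts to count = 4 values of
 // !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [8, 16]>>, i.e.
 // sg_layout and sg_data are dropped while the remaining fields are kept.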
986  MLIRContext *ctx = &getContext();
987  RewritePatternSet patterns(ctx);
988  ConversionTarget target(*ctx);
989  TypeConverter converter;
990  converter.addConversion([&](Type type) -> Type { return type; });
991  converter.addConversion(
992  [&](xegpu::TensorDescType type,
993  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
994  Type elemTy = type.getElementType();
995  ArrayRef<int64_t> shape = type.getShape();
996 
997  int count;
998  SmallVector<int64_t> subShape;
999  xegpu::LayoutAttr layout = type.getLayoutAttr();
1000  std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
1001 
1002  if (layout)
1003  layout = layout.dropSgLayoutAndData();
1004 
1005  auto newTy = xegpu::TensorDescType::get(
1006  type.getContext(), subShape, elemTy, type.getEncoding(), layout);
1007  result.append(count, newTy);
1008  return success();
1009  });
1010 
1011  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
1012  if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
1013  return createOp.getType();
1014  if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
1015  return loadOp.getTensorDescType();
1016  if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
1017  return storeOp.getTensorDescType();
1018  if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
1019  return updateOp.getType();
1020  if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
1021  return prefetchOp.getTensorDescType();
1022  return xegpu::TensorDescType();
1023  };
1024 
1025  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
1026  return !layout || !layout.isForWorkgroup();
1027  };
1028 
1029  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
1030  xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
1031  xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
1032  auto tdescTy = getTensorDescType(op);
1033  auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
1034  return isLegal(layout);
1035  });
1036 
1037  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
1038  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1039  return isLegal(layout);
1040  });
1041 
1042  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
1043  [=](xegpu::LoadMatrixOp op) -> bool {
1044  return isLegal(op.getLayoutAttr());
1045  });
1046 
1047  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
1048  [=](xegpu::StoreMatrixOp op) -> bool {
1049  return isLegal(op.getLayoutAttr());
1050  });
1051 
1052  target.addDynamicallyLegalOp<arith::ConstantOp>(
1053  [=](arith::ConstantOp op) -> bool {
1054  auto vecType = dyn_cast<VectorType>(op.getType());
1055  if (!vecType)
1056  return true;
1057  return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1058  });
1059 
1060  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
1061  [=](xegpu::LoadGatherOp op) -> bool {
1062  auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1063  return isLegal(layout);
1064  });
1065 
1066  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
1067  [=](xegpu::StoreScatterOp op) -> bool {
1068  // Check if the "layout" attribute is present on the op.
1069  auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
1070  if (!layout)
1071  return true;
1072  return isLegal(layout);
1073  });
1074 
1075  target.addDynamicallyLegalOp<vector::BroadcastOp>(
1076  [=](vector::BroadcastOp op) -> bool {
1077  return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1078  });
1079 
1080  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
1081  [=](xegpu::ConvertLayoutOp op) -> bool {
1082  return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
1083  });
1084 
1085  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
1086  [=](Operation *op) -> std::optional<bool> {
1087  // Only handle elementwise mappable ops
1088  if (!OpTrait::hasElementwiseMappableTraits(op))
1089  return true;
1090 
1091  VectorType resultType =
1092  dyn_cast<VectorType>(op->getResult(0).getType());
1093  if (!resultType)
1094  return true;
1095 
1096  // Check if all operands are vectors of the same shape
1097  // TODO: Support other types.
1098  for (Value operand : op->getOperands()) {
1099  VectorType operandType = dyn_cast<VectorType>(operand.getType());
1100  if (!operandType || operandType.getShape() != resultType.getShape()) {
1101  return true;
1102  }
1103  }
1104 
1105  xegpu::DistributeLayoutAttr layout =
1106  xegpu::getDistributeLayoutAttr(op->getResult(0));
1107  return isLegal(layout);
1108  });
1109 
1110  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
1111  [=](UnrealizedConversionCastOp op) {
1112  return llvm::is_contained(existingCastOps, op.getOperation());
1113  });
1114 
1115  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
1116 
1116  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
1117  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
1118  target);
1119 
1120  if (failed(
1121  applyPartialConversion(getOperation(), target, std::move(patterns))))
1122  return signalPassFailure();
1123 
1124  // Remove the sg_layout and sg_data fields from the Layout
1125  // attribute of each VectorType result of the operation.
1126  // For structured control flow ops, the layout is simply removed,
1127  // since in the 1:N case the layouts for the new results are missing;
1128  // the layout propagation pass is expected to recover them afterwards.
1129  getOperation()->walk([](Operation *op) {
1130  for (OpResult result : op->getOpResults()) {
1131  std::string name = xegpu::getLayoutName(result);
1132  if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
1133  op->removeAttr(name);
1134  if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
1135  if (auto newLayout = layout.dropSgLayoutAndData())
1136  op->setAttr(name, newLayout);
1137  }
1138  }
1139  }
1140  });
1141 }