XeGPUSubgroupDistribute.cpp
1 //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeRange.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/IR/Visitors.h"
30 #include "mlir/Support/LLVM.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 
38 namespace mlir {
39 namespace xegpu {
40 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
41 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
42 } // namespace xegpu
43 } // namespace mlir
44 
45 #define DEBUG_TYPE "xegpu-subgroup-distribute"
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
47 
48 using namespace mlir;
49 
50 static const char *const resolveSIMTTypeMismatch =
51  "resolve_simt_type_mismatch"; // Attribute name for identifying
52  // UnrealizedConversionCastOp added to resolve
53  // SIMT type mismatches.
54 
55 namespace {
56 
57 //===----------------------------------------------------------------------===//
58 // SIMT Distribution Patterns
59 //===----------------------------------------------------------------------===//
60 
61 /// In certain cases, we may need to favor XeGPU specific distribution patterns
62 /// over generic vector distribution patterns. In such cases, we can assign
63 /// priorities to patterns.
64 static constexpr unsigned regularPatternBenefit = 1;
65 static constexpr unsigned highPatternBenefit = 2;
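// Illustrative sketch only (the actual registration happens in this pass's
// populate functions; `SomeXeGPUPattern` is a hypothetical name): a pattern
// that must take precedence over the generic vector distribution patterns
// would be added with the higher benefit, e.g.
//   patterns.add<SomeXeGPUPattern>(patterns.getContext(),
//                                  /*benefit=*/highPatternBenefit);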
66 
67 /// Helper function to get distributed vector type for a source vector type
68 /// according to the lane_layout. We simply divide each dimension of the
69 /// source vector shape by the corresponding lane_layout dimension. If
70 /// array_length > 1, it is appended to the front of the distributed shape.
71 /// NOTE: This is the vector type that will be returned by the
72 /// gpu.warp_execute_on_lane0 op.
73 ///
74 /// Examples:
75 /// | original vector shape | lane_layout | distributed vector shape |
76 /// |-----------------------|-------------|--------------------------|
77 /// | 32x16 | [1, 16] | 32x1 |
78 /// | 32x16 | [2, 8] | 16x2 |
79 /// | 2x32x16 | [1, 16] | 2x32x1 |
80 static FailureOr<VectorType>
81 getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
82  VectorType originalType) {
83  if (!layout)
84  return failure();
85  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
86  "Expecting a valid layout.");
87  SmallVector<int64_t> effectiveLaneLayout =
88  layout.getEffectiveLaneLayoutAsInt();
89  assert(static_cast<size_t>(originalType.getRank()) >=
90  effectiveLaneLayout.size() &&
91  "Rank of the original vector type should be greater or equal to the "
92  "size of the lane layout to distribute the vector type.");
93  SmallVector<int64_t> distributedShape(originalType.getShape());
94  // Only distribute the last `laneLayout.size()` dimensions. The remaining
95  // dimensions are not distributed.
96  unsigned distributionStart =
97  originalType.getRank() - effectiveLaneLayout.size();
98  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
99  if (i < distributionStart)
100  continue;
101 
102  // Check if the dimension can be distributed evenly.
103  if (dim % effectiveLaneLayout[i - distributionStart] != 0)
104  return failure();
105  distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
106  }
107  return VectorType::get(distributedShape, originalType.getElementType());
108 }
109 
110 /// Helper function to resolve types if the distributed type out of
111 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
112 /// Example 1:
113 /// distributed type: vector<8x1xf32>
114 /// expected type: vector<8xf32>
115 /// resolved using,
116 /// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
117 /// Example 2:
118 /// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
119 /// expected type: xegpu.tensor_desc<8x16xf32>
120 /// resolved using,
121 /// %0 = unrealized_conversion_cast %1 :
122 /// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
123 /// xegpu.tensor_desc<8x16xf32>
124 template <typename T>
125 static Value resolveDistributedTy(Value orig, T expected,
126  PatternRewriter &rewriter) {
127  // If orig and expected types are the same, return orig.
128  if (orig.getType() == expected)
129  return orig;
130  // If orig is a vector type, create a shape cast op to reconcile the types.
131  if (isa<VectorType>(orig.getType())) {
132  auto castOp =
133  vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
134  return castOp.getResult();
135  }
136  // If orig is a tensor descriptor type, create an unrealized conversion cast
137  // op to reconcile the types.
138  if (isa<xegpu::TensorDescType>(orig.getType())) {
139  auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
140  expected, orig);
141  castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
142  return castOp.getResult(0);
143  }
144  llvm_unreachable("Unsupported type for reconciliation");
145  return orig;
146 }
147 
148 /// Helper function to check if the layout is packed. Layout is packed if it is
149 /// 2D and lane_data[0] != 1 (data packed from col dimension).
150 /// TODO: Move to target info.
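/// Example (derived directly from the rule above): lane_data = [2, 1]
/// requires a packed load; lane_data = [1, 1] (or any non-2D lane_data) does
/// not.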
151 static bool requirePacked(const xegpu::LayoutAttr layout) {
152  if (!layout)
153  return false;
154  auto laneData = layout.getEffectiveLaneDataAsInt();
155  if (laneData.size() != 2)
156  return false;
157  return laneData[0] != 1;
158 }
159 
160 /// Helper function to check if the layout requires a transpose effect.
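/// Example (illustrative, assuming xegpu::targetinfo::subgroupSize == 16): on
/// "pvc" or "bmg", lane_layout = [16, 1] requires a transpose, while
/// lane_layout = [1, 16] does not.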
161 static bool requireTranspose(const xegpu::LayoutAttr layout,
162  const std::string &chipStr) {
163  // Return false for unsupported targets.
164  // TODO: Add more support or move to target info.
165  if (chipStr != "pvc" && chipStr != "bmg")
166  return false;
167  if (!layout)
168  return false;
169  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
170  if (laneLayout.size() != 2)
171  return false;
172  return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1;
173 }
174 
175 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
176 /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
177 /// contained within a WarpExecuteOnLane0Op.
178 /// Example:
179 ///
180 /// ```
181 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
182 /// ...
183 /// ...
184 /// gpu.return %result: vector<8x16xf32>
185 /// }
186 /// ```
187 /// To
188 /// ```
189 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
190 /// %laneid = gpu.lane_id : index
191 /// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
192 /// ...
193 /// ...
194 /// gpu.yield %result: vector<8x16xf32>
195 /// }
196 /// gpu.return %0
197 /// }
/// ```
198 struct MoveFuncBodyToWarpExecuteOnLane0
199  : public OpRewritePattern<gpu::GPUFuncOp> {
200  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
201  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
202  PatternRewriter &rewriter) const override {
203  // If the function only contains a single void return, skip.
204  if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
205  return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
206  }))
207  return failure();
208  // If the function already moved inside a warp_execute_on_lane0, skip.
209  if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
210  return isa<gpu::WarpExecuteOnLane0Op>(op);
211  }))
212  return failure();
213  // Create a new function with the same signature and same attributes.
214  SmallVector<Type> workgroupAttributionsTypes =
215  llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
216  [](BlockArgument arg) { return arg.getType(); });
217  SmallVector<Type> privateAttributionsTypes =
218  llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
219  [](BlockArgument arg) { return arg.getType(); });
220  auto newGpuFunc = gpu::GPUFuncOp::create(
221  rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
222  gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
223  privateAttributionsTypes);
224  newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
225  // Create a WarpExecuteOnLane0Op with same arguments and results as the
226  // original gpuFuncOp.
227  rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
228  auto laneId = gpu::LaneIdOp::create(
229  rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
230  /** upperBound = **/ mlir::IntegerAttr());
231  ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
232  auto warpOp = gpu::WarpExecuteOnLane0Op::create(
233  rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
234  xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
235  newGpuFunc.getArgumentTypes());
236  Block &warpBodyBlock = warpOp.getBodyRegion().front();
237  // Replace the ReturnOp of the original gpu function with a YieldOp.
238  auto origReturnOp =
239  cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
240  rewriter.setInsertionPointAfter(origReturnOp);
241  gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
242  origReturnOp.getOperands());
243  rewriter.eraseOp(origReturnOp);
244  // Move the original function body to the WarpExecuteOnLane0Op body.
245  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
246  warpOp.getBodyRegion().begin());
247  rewriter.eraseBlock(&warpBodyBlock);
248  // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
249  rewriter.setInsertionPointAfter(warpOp);
250  gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
251  rewriter.replaceOp(gpuFuncOp, newGpuFunc);
252  return success();
253  }
254 };
255 
256 /// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
257 /// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
258 /// still contain the original op that will not be used by the yield op (and
259 /// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
260 /// arguments. Tensor descriptor shape is not distributed because it is a
261 /// uniform value across all work items within the subgroup. However, the
262 /// layout information is dropped in the new tensor descriptor type.
263 ///
264 /// Example:
265 ///
266 /// ```
267 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
268 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
269 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
270 /// ...
271 /// %td = xegpu.create_nd_tdesc %arg0[0, 0]
272 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
273 /// vector.yield %td
274 /// }
275 /// ```
276 /// To
277 /// ```
278 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
279 /// ...
280 /// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
281 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
282 /// vector.yield %arg0, %dead
283 /// }
284 /// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
285 /// -> !xegpu.tensor_desc<4x8xf32>
286 ///
287 /// ```
288 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
289  using gpu::WarpDistributionPattern::WarpDistributionPattern;
290  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
291  PatternRewriter &rewriter) const override {
292  OpOperand *operand =
293  getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
294  if (!operand)
295  return rewriter.notifyMatchFailure(
296  warpOp, "warp result is not a xegpu::CreateNdDesc op");
297  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
298  unsigned operandIdx = operand->getOperandNumber();
299 
300  xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
301  if (!layout)
302  return rewriter.notifyMatchFailure(
303  descOp, "the tensor descriptor lacks layout attribute");
304 
305  SmallVector<size_t> newRetIndices;
306  rewriter.setInsertionPoint(warpOp);
307  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
308  rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
309  /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
310 
311  SmallVector<Value> newDescOperands = llvm::map_to_vector(
312  newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
313  rewriter.setInsertionPointAfter(newWarpOp);
314  xegpu::TensorDescType distributedTensorDescTy =
315  descOp.getType().dropLayouts(); // Distributed tensor descriptor type
316  // does not contain layout info.
317  Value newDescOp = xegpu::CreateNdDescOp::create(
318  rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
319  descOp->getAttrs());
320 
321  Value distributedVal = newWarpOp.getResult(operandIdx);
322  // Resolve the distributed type to the expected type.
323  newDescOp =
324  resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
325  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
326  return success();
327  }
328 };
329 
330 /// Distribute a store_nd op at the end of enclosing
331 /// `gpu.warp_execute_on_lane_0`. If arguments for the store are passed
332 /// through the warp op interface, they are propagated as returned values.
333 /// The source vector is distributed based on the lane layout. Appropriate cast
334 /// ops are inserted if the distributed types do not match expected xegpu SIMT types.
335 ///
336 /// Example:
337 ///
338 /// ```
339 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
340 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
341 /// ...
342 /// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
343 /// !xegpu.tensor_desc<4x8xf32, #layout0>
344 /// }
345 /// ```
346 /// To
347 /// ```
348 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
349 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
350 /// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
351 /// #layout0>
352 /// }
353 /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
354 /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
355 /// #layout0>
356 /// -> !xegpu.tensor_desc<4x8xf32>
357 /// xegpu.store_nd %0, %1: vector<4xf32>,
358 /// !xegpu.tensor_desc<4x8xf32>
359 ///
360 /// ```
361 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
362  using gpu::WarpDistributionPattern::WarpDistributionPattern;
363  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
364  PatternRewriter &rewriter) const override {
365  gpu::YieldOp yield = warpOp.getTerminator();
366  Operation *lastNode = yield->getPrevNode();
367  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
368  if (!storeOp)
369  return failure();
370 
371  int64_t offsetSize = static_cast<int64_t>(storeOp.getOffsets().size());
372  if ((offsetSize != 0) || storeOp.getConstOffsetsAttr())
373  return failure();
374 
375  xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
376  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
377  if (!layout)
378  return rewriter.notifyMatchFailure(
379  storeOp, "the source tensor descriptor lacks layout attribute");
380 
381  FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
382  getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
383  if (failed(distributedTypeByWarpOpOrFailure))
384  return rewriter.notifyMatchFailure(storeOp,
385  "Failed to distribute the type");
386  VectorType distributedTypeByWarpOp =
387  distributedTypeByWarpOpOrFailure.value();
388 
389  SmallVector<size_t> newRetIndices;
390  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
391  rewriter, warpOp,
392  /* new yielded values = */
393  ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
394  /* new yielded types = */
395  TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
396  newRetIndices);
397  // Create a new store op outside the warp op with the distributed vector
398  // type. Tensor descriptor is not distributed.
399  rewriter.setInsertionPointAfter(newWarpOp);
400  SmallVector<Value> newStoreOperands;
401 
402  // For the value operand, there can be a mismatch between the vector type
403  // distributed by the warp op and (xegpu-specific) distributed type
404  // supported by the store op. Type mismatch must be resolved using
405  // appropriate cast op.
406  FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
407  xegpu::getDistributedVectorType(storeOp.getTensorDescType());
408  if (failed(storeNdDistributedValueTyOrFailure))
409  return rewriter.notifyMatchFailure(
410  storeOp, "Failed to get distributed vector type for the store op");
411  newStoreOperands.push_back(resolveDistributedTy(
412  newWarpOp.getResult(newRetIndices[0]),
413  storeNdDistributedValueTyOrFailure.value(), rewriter));
414  // For the tensor descriptor operand, the layout attribute is dropped after
415  // distribution. Types need to be resolved in this case as well.
416  xegpu::TensorDescType distributedTensorDescTy =
417  storeOp.getTensorDescType().dropLayouts();
418  newStoreOperands.push_back(
419  resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
420  distributedTensorDescTy, rewriter));
421 
422  auto newStoreOp =
423  xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
424  newStoreOperands, storeOp->getAttrs());
425  xegpu::removeLayoutAttrs(newStoreOp);
426  rewriter.eraseOp(storeOp);
427  return success();
428  }
429 };
430 
431 /// Distribute a load_nd op feeding into vector.yield op for the enclosing
432 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
433 /// The warp op will still contain the original op that will not be used by
434 /// the yield op (and should be cleaned up later). The yield op will
435 /// bypass the load's arguments. Only the loaded vector is distributed
436 /// according to the lane layout; the tensor descriptor type is not
437 /// distributed. Appropriate cast ops are inserted if the distributed types do
438 /// not match expected xegpu SIMT types.
439 ///
440 /// Example:
441 ///
442 /// ```
443 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
444 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
445 /// (vector<4x1xf32>) {
446 /// ...
447 /// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
448 /// ->
449 /// vector<4x8xf32>
450 /// gpu.yield %ld
451 /// }
452 /// ```
453 /// To
454 /// ```
455 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
456 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
457 /// ...
458 /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
459 /// vector<4x8xf32>
/// gpu.yield %dead, %arg0
460 /// }
461 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
462 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
463 /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
464 /// %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
465 ///
466 /// ```
467 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
468  using gpu::WarpDistributionPattern::WarpDistributionPattern;
469  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
470  PatternRewriter &rewriter) const override {
471  OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
472  if (!isa<xegpu::LoadNdOp>(op))
473  return false;
474  // Make sure the same load op is the last operation in the warp op body.
475  // This ensures that the load op is not sunk earlier, which would violate
476  // barrier synchronization.
477  gpu::YieldOp yield = warpOp.getTerminator();
478  return yield->getPrevNode() == op;
479  });
480 
481  if (!operand)
482  return rewriter.notifyMatchFailure(
483  warpOp, "warp result is not a xegpu::LoadNd op");
484 
485  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
486  // Chip information is required to decide if the layout requires transpose
487  // effect.
488  auto chipStr = xegpu::getChipStr(loadOp);
489  if (!chipStr)
490  return rewriter.notifyMatchFailure(
491  loadOp,
492  "xegpu::LoadNdOp requires chip information to determine transpose "
493  "requirement");
494  int64_t offsetSize = static_cast<int64_t>(loadOp.getOffsets().size());
495  if ((offsetSize != 0) || loadOp.getConstOffsetsAttr())
496  return failure();
497 
498  xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
499  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
500  if (!layout)
501  return rewriter.notifyMatchFailure(
502  loadOp, "the source tensor descriptor lacks layout attribute");
503 
504  unsigned operandIdx = operand->getOperandNumber();
505  VectorType distributedTypeByWarpOp =
506  cast<VectorType>(warpOp.getResult(operandIdx).getType());
507 
508  SmallVector<size_t> newRetIndices;
509  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
510  rewriter, warpOp,
511  /* new yielded values = */ loadOp.getTensorDesc(),
512  /* new yielded types = */ tensorDescTy, newRetIndices);
513 
514  // Create a new load op outside the warp op with the distributed vector
515  // type.
516  rewriter.setInsertionPointAfter(newWarpOp);
517  FailureOr<VectorType> loadNdDistValueTyOrFailure =
518  xegpu::getDistributedVectorType(loadOp.getTensorDescType());
519  if (failed(loadNdDistValueTyOrFailure))
520  return rewriter.notifyMatchFailure(
521  loadOp, "Failed to get distributed vector type for the load op");
522  xegpu::TensorDescType distributedTensorDescTy =
523  loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
524  // descriptor type does not
525  // contain layout info.
526  auto newLoadOp = xegpu::LoadNdOp::create(
527  rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
528  resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
529  distributedTensorDescTy, rewriter),
530  loadOp->getAttrs());
531  xegpu::removeLayoutAttrs(newLoadOp);
532  // Set the packed attribute if the layout requires it.
533  newLoadOp.setPacked(requirePacked(layout));
534  // Set the transpose attribute if the layout requires it.
535  if (requireTranspose(layout, chipStr.value()))
536  newLoadOp.setTranspose(
537  DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
538  Value distributedVal = newWarpOp.getResult(operandIdx);
539  // There can be a conflict between the vector type distributed by the
540  // warp op and (xegpu-specific) distributed type supported by the load
541  // op. Resolve these mismatches by inserting a cast.
542  Value tyResolvedVal = resolveDistributedTy(
543  newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
544  rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
545  return success();
546  }
547 };
548 
549 /// Distribute a dpas op feeding into vector.yield op for the enclosing
550 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
551 /// The warp op will still contain the original op that will not be used by
552 /// the yield op (and should be cleaned up later). The yield op will
553 /// bypass the dpas's arguments. Appropriate cast ops are inserted if the
554 /// distributed types do not match expected xegpu SIMT types.
555 /// Example:
556 /// ```
557 /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
558 /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
559 /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
560 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
561 /// (vector<8x1xf32>) {
562 /// ...
563 /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
564 /// vector<8x16xf32>
565 /// gpu.yield %dpas
566 /// }
567 /// ```
568 /// To
569 /// ```
570 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
571 /// vector<8x1xf16>, vector<16x1xf16>) {
572 /// ...
573 /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
574 /// -> vector<8x16xf32>
575 /// gpu.yield %dead, %arg0, %arg1
576 /// }
577 /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
578 /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
579 /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
580 /// vector<8xf32>
581 /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
582 /// ```
583 struct DpasDistribution final : public gpu::WarpDistributionPattern {
584  using gpu::WarpDistributionPattern::WarpDistributionPattern;
585  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
586  PatternRewriter &rewriter) const override {
587  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
588  if (!operand)
589  return rewriter.notifyMatchFailure(warpOp,
590  "warp result is not a xegpu::Dpas op");
591 
592  auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
593  unsigned operandIdx = operand->getOperandNumber();
594  std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
595  std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
596  std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
597 
598  xegpu::LayoutAttr layoutA =
599  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
600  xegpu::LayoutAttr layoutB =
601  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
602  xegpu::LayoutAttr layoutOut =
603  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
604  if (!layoutA || !layoutB || !layoutOut)
605  return rewriter.notifyMatchFailure(
606  dpasOp,
607  "the xegpu::Dpas op lacks layout attribute for A, B or output");
608 
609  FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
610  getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
611  FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
612  getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
613  FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
614  getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
615  if (failed(distLhsTypeByWarpOpOrFailure) ||
616  failed(distRhsTypeByWarpOpOrFailure) ||
617  failed(distResultTypeByWarpOpOrFailure))
618  return rewriter.notifyMatchFailure(
619  dpasOp,
620  "Failed to distribute the A, B or output types in xegpu::Dpas op");
621 
622  llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
623  dpasOp.getRhs()};
624  llvm::SmallVector<Type, 3> newYieldTypes{
625  distLhsTypeByWarpOpOrFailure.value(),
626  distRhsTypeByWarpOpOrFailure.value()};
627  // Dpas acc operand is optional.
628  if (dpasOp.getAcc()) {
629  newYieldValues.push_back(dpasOp.getAcc());
630  newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
631  }
632  // Create a new warp op without the dpas.
633  SmallVector<size_t> newRetIndices;
634  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
635  rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
636 
637  FailureOr<VectorType> expectedDistLhsTyOrFailure =
638  xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
639  FailureOr<VectorType> expectedDistRhsTyOrFailure =
640  xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
641  FailureOr<VectorType> expectedDistResultTyOrFailure =
642  xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
643  if (failed(expectedDistLhsTyOrFailure) ||
644  failed(expectedDistRhsTyOrFailure) ||
645  failed(expectedDistResultTyOrFailure))
646  return rewriter.notifyMatchFailure(
647  dpasOp,
648  "Failed to get distributed vector type for the dpas operands.");
649  // Create a new dpas op outside the warp op.
650  rewriter.setInsertionPointAfter(newWarpOp);
651  SmallVector<Value> newDpasOperands;
652  SmallVector<VectorType> newDpasOperandExpectedTypes;
653 
654  // Resolve the distributed types with the original types.
655  newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
656  newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
657  VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
658  if (dpasOp.getAcc())
659  newDpasOperandExpectedTypes.push_back(distributedResultTy);
660 
661  for (unsigned i = 0; i < newRetIndices.size(); i++) {
662  newDpasOperands.push_back(
663  resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
664  newDpasOperandExpectedTypes[i], rewriter));
665  }
666  auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
667  distributedResultTy, newDpasOperands,
668  dpasOp->getAttrs());
669  xegpu::removeLayoutAttrs(newDpasOp);
670  Value distributedVal = newWarpOp.getResult(operandIdx);
671  // Resolve the output type.
672  Value typeResolved =
673  resolveDistributedTy(newDpasOp.getResult(),
674  distResultTypeByWarpOpOrFailure.value(), rewriter);
675  rewriter.replaceAllUsesWith(distributedVal, typeResolved);
676  return success();
677  }
678 };
679 
680 /// Sink an update_nd_offset op feeding into yield op of an enclosing
681 /// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
682 /// original op that will not be used by the yield op (and should be cleaned
683 /// up later). The yield op will bypass the updateOp's arguments. The tensor
684 /// descriptor type is not distributed. Appropriate cast ops are inserted if
685 /// the distributed types do not match expected xegpu SIMT types.
686 /// Example:
687 /// ```
688 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
689 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
690 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
691 /// ...
692 /// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
693 /// !xegpu.tensor_desc<4x8xf32, #layout0>
694 /// gpu.yield %update
695 /// }
696 /// ...
697 /// ```
698 /// To
699 /// ```
700 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
701 /// !xegpu.tensor_desc<4x8xf32, #layout0>,
702 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
703 /// ...
704 /// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
705 /// !xegpu.tensor_desc<4x8xf32, #layout0>
706 /// gpu.yield %dead, %arg0, %c32, %c16
707 /// }
708 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
709 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
710 /// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
711 /// !xegpu.tensor_desc<4x8xf32>
712 /// ...
713 /// ```
714 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
715  using gpu::WarpDistributionPattern::WarpDistributionPattern;
716  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
717  PatternRewriter &rewriter) const override {
718  OpOperand *operand =
719  getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
720  if (!operand)
721  return rewriter.notifyMatchFailure(
722  warpOp, "warp result is not a xegpu::UpdateNdOffset op");
723  auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
724  unsigned operandIdx = operand->getOperandNumber();
725 
726  SmallVector<size_t> newRetIndices;
727  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
728  rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
729  newRetIndices);
730  rewriter.setInsertionPointAfter(newWarpOp);
731  // new update op does not have layout attribute.
732  xegpu::TensorDescType distributedTensorDescTy =
733  updateOp.getTensorDescType().dropLayouts();
734  SmallVector<Value> newUpdateOperands =
735  llvm::map_to_vector(newRetIndices, [&](size_t i) {
736  // For the tensor descriptor operand, the layout attribute is
737  // dropped after distribution. Types need to be resolved in this
738  // case.
739  if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
740  return resolveDistributedTy(newWarpOp.getResult(i),
741  distributedTensorDescTy, rewriter);
742  }
743  return newWarpOp.getResult(i);
744  });
745  // Create a new update op outside the warp op.
746  auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
747  rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
748  newUpdateOperands, updateOp->getAttrs());
749  xegpu::removeLayoutAttrs(newUpdateOp);
750  Value distributedVal = newWarpOp.getResult(operandIdx);
751  // Resolve the distributed type with the original type.
752  Value typeResolved = resolveDistributedTy(
753  newUpdateOp.getResult(), distributedVal.getType(), rewriter);
754  rewriter.replaceAllUsesWith(distributedVal, typeResolved);
755  return success();
756  }
757 };
758 
759 /// Distribute a prefetch_nd op at the end of enclosing
760 /// `gpu.warp_execute_on_lane_0`. If arguments for the prefetch are passed
761 /// through the warp op interface, they are propagated as returned values.
762 /// Tensor descriptor shape is not distributed because it is a uniform value
763 /// across all work items within the subgroup. Appropriate cast ops are inserted
764 /// if the distributed types do not match expected xegpu SIMT types.
765 ///
766 /// Example:
767 ///
768 /// ```
769 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
770 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
771 /// ...
772 /// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
773 /// }
774 /// ```
775 /// To
776 /// ```
777 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
778 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
779 /// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
780 /// }
781 /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
782 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
783 /// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
784 ///
785 /// ```
786 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
787  using gpu::WarpDistributionPattern::WarpDistributionPattern;
788  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
789  PatternRewriter &rewriter) const override {
790  gpu::YieldOp yield = warpOp.getTerminator();
791  Operation *lastNode = yield->getPrevNode();
792  auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
793  if (!prefetchOp)
794  return failure();
795 
796  int64_t offsetSize = static_cast<int64_t>(prefetchOp.getOffsets().size());
797  if ((offsetSize != 0) || prefetchOp.getConstOffsetsAttr())
798  return failure();
799 
800  xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
801  if (!layout)
802  return rewriter.notifyMatchFailure(
803  prefetchOp, "the source tensor descriptor lacks layout attribute");
804 
805  SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
806  SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
807  SmallVector<size_t> newRetIndices;
808  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
809  rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
810  // Create a new prefetch op outside the warp op with updated tensor
811  // descriptor type. The source tensor descriptor requires type resolution.
812  xegpu::TensorDescType newTensorDescTy =
813  prefetchOp.getTensorDescType().dropLayouts();
814  rewriter.setInsertionPointAfter(newWarpOp);
815  SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
816  newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
817  xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
818  newPrefetchOperands, prefetchOp->getAttrs());
819  xegpu::removeLayoutAttrs(prefetchOp);
820  rewriter.eraseOp(prefetchOp);
821  return success();
822  }
823 };
824 
825 /// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
826 /// region. This will simply move the barrier op outside of the warp op.
827 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
828  using gpu::WarpDistributionPattern::WarpDistributionPattern;
829  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
830  PatternRewriter &rewriter) const override {
831  gpu::YieldOp yield = warpOp.getTerminator();
832  Operation *lastNode = yield->getPrevNode();
833  // The last node must be a gpu::BarrierOp.
834  auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
835  if (!barrierOp)
836  return failure();
837  // Move the barrier op outside of the warp op.
838  rewriter.setInsertionPointAfter(warpOp);
839  gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
840  barrierOp->getResultTypes(),
841  barrierOp->getOperands(), barrierOp->getAttrs());
842  rewriter.eraseOp(barrierOp);
843  return success();
844  }
845 };
846 
847 /// Distribute a scattered store op. The offsets argument is required.
848 /// Both offset and mask vectors must be 1D and have #subgroup_size elements.
849 /// The layouts are fixed and implicit: one offset/mask per lane.
850 /// The pass changes the offset/mask vector shapes to a
851 /// single-element vector, **it is assumed that their producer will also be
852 /// distributed**. The payload vector also has a fixed distribution:
853 /// no chunk size -> vector of one element.
854 /// chunk size -> vector of the innermost dimension of the SG-payload.
855 /// Example 1 (no chunk size):
856 /// %mask = producer_op : vector<16xi1>
857 /// %offset = producer_op : vector<16xindex>
858 /// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
859 /// memref<256xf16>, vector<16xindex>, vector<16xi1>
860 /// To
861 /// %mask = producer_op : vector<1xi1>
862 /// %offset = producer_op : vector<1xindex>
863 /// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
864 /// memref<256xf16>, vector<1xindex>, vector<1xi1>
865 /// Example 2 (chunk size, same mask and offsets):
866 /// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
867 /// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
868 /// To
869 /// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
870 /// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
871 struct StoreDistribution final : public gpu::WarpDistributionPattern {
872  using gpu::WarpDistributionPattern::WarpDistributionPattern;
873  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
874  PatternRewriter &rewriter) const override {
875  Operation *lastNode = warpOp.getTerminator()->getPrevNode();
876  auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
877  if (!storeScatterOp)
878  return failure();
879  auto offsets = storeScatterOp.getOffsets();
880  if (!offsets || !isa<VectorType>(offsets.getType()))
881  return rewriter.notifyMatchFailure(
882  storeScatterOp, "Store op must have a vector of offsets argument");
883  VectorType offsetsTy = cast<VectorType>(offsets.getType());
884  VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
885  if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
886  return rewriter.notifyMatchFailure(storeScatterOp,
887  "Expected 1D offsets and mask vector");
888  VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
889  if (storeVecTy.getRank() > 2)
890  return rewriter.notifyMatchFailure(
891  storeScatterOp, "Expected at most 2D result at SG level");
892 
893  std::string layoutPayloadName =
894  xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
895  std::string layoutOffsetsName =
896  xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
897  std::string layoutMaskName =
898  xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
899 
900  xegpu::LayoutAttr layoutPayload =
901  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
902  xegpu::LayoutAttr layoutOffsets =
903  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
904  xegpu::LayoutAttr layoutMask =
905  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
906 
907  FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
908  getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
909  FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
910  getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
911  FailureOr<VectorType> distMaskByWarpOpOrFailure =
912  getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
913  if (failed(distStoreVecByWarpOpOrFailure) ||
914  failed(distOffsetsByWarpOpOrFailure) ||
915  failed(distMaskByWarpOpOrFailure)) {
916  return rewriter.notifyMatchFailure(
917  storeScatterOp,
918  "Some vector operands have no layouts, using defaults instead.");
919  }
920  VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
921  VectorType expectedPayloadTy = VectorType::get(
922  {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
923 
924  SmallVector<size_t> newRetIndices;
925  SmallVector<Value> operands = storeScatterOp->getOperands();
926  SmallVector<Type> operandTypesToYield = {
927  expectedPayloadTy, operands[1].getType(),
928  distOffsetsByWarpOpOrFailure.value(),
929  distMaskByWarpOpOrFailure.value()};
930 
931  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
932  rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
933  SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
934  newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
935 
936  rewriter.setInsertionPointAfter(newWarpOp);
937  xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
938  rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
939  storeScatterOp->getAttrs());
940  xegpu::removeLayoutAttrs(newOp);
941  rewriter.eraseOp(storeScatterOp);
942  return success();
943  }
944 };
945 
946 /// Distribute a scattered load op. The logic and requirements are the same as
947 /// for the scattered store distribution. The warpOp's payload vector is
948 /// expected to be distributed by the load's result consumer.
949 /// Example 1 (no chunk size):
950 /// %mask = producer_op : vector<16xi1>
951 /// %offset = producer_op : vector<16xindex>
952 /// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
953 /// vector<16xindex>, vector<16xi1> -> vector<16xf16>
954 /// To
955 /// %mask = producer_op : vector<1xi1>
956 /// %offset = producer_op : vector<1xindex>
957 /// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
958 /// vector<1xindex>, vector<1xi1> -> vector<1xf16>
959 /// Example 2 (chunk size, same mask and offsets):
960 /// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
961 /// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
962 /// To
963 /// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
964 /// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
965 struct LoadDistribution final : public gpu::WarpDistributionPattern {
966  using gpu::WarpDistributionPattern::WarpDistributionPattern;
967  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
968  PatternRewriter &rewriter) const override {
969  OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
970  // Check that the yield operand was produced by the *last* scattered
971  // load op, to avoid sinking it before barriers (maintain memory order).
972  return isa<xegpu::LoadGatherOp>(op) &&
973  warpOp.getTerminator()->getPrevNode() == op;
974  });
975  if (!producedByLastLoad)
976  return rewriter.notifyMatchFailure(
977  warpOp, "The last op is not xegpu::LoadGatherOp");
978 
979  auto loadGatherOp =
980  producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
981  auto offsets = loadGatherOp.getOffsets();
982  if (!offsets || !isa<VectorType>(offsets.getType()) ||
983  !isa<VectorType>(loadGatherOp.getMask().getType()))
984  return rewriter.notifyMatchFailure(
985  loadGatherOp,
986  "Load op must have vector arguments for offsets and mask");
987  VectorType offsetsTy = cast<VectorType>(offsets.getType());
988  VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
989  if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
990  return rewriter.notifyMatchFailure(loadGatherOp,
991  "Expected 1D offsets and mask vector");
992  // Assume offset and mask producers will be distributed as well.
993  std::string layoutOffsetsName =
994  xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
995  std::string layoutMaskName =
996  xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
997 
998  xegpu::LayoutAttr layoutOffsets =
999  loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
1000  xegpu::LayoutAttr layoutMask =
1001  loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
1002 
1003  FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
1004  getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
1005  FailureOr<VectorType> distMaskByWarpOpOrFailure =
1006  getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
1007  if (failed(distOffsetsByWarpOpOrFailure) ||
1008  failed(distMaskByWarpOpOrFailure)) {
1009  return rewriter.notifyMatchFailure(
1010  loadGatherOp,
1011  "Some vector operands have no layouts, using defaults instead.");
1012  }
1013 
1014  SmallVector<size_t> newRetIndices;
1015  SmallVector<Value> operands = loadGatherOp->getOperands();
1016  SmallVector<Type> operandTypesToYield = {
1017  operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
1018  distMaskByWarpOpOrFailure.value()};
1019 
1020  const unsigned operandIdx = producedByLastLoad->getOperandNumber();
1021  VectorType loadVecTy =
1022  cast<VectorType>(warpOp.getResult(operandIdx).getType());
1023 
1024  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1025  rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
1026 
1027  SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
1028  newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1029 
1030  rewriter.setInsertionPointAfter(newWarpOp);
1031  xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1032  rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
1033  loadGatherOp->getAttrs());
1034  xegpu::removeLayoutAttrs(newOp);
1035  Value distributedVal = newWarpOp.getResult(operandIdx);
1036  rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
1037  return success();
1038  }
1039 };
1040 
1041 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1042 /// VectorReductionOps.
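/// For example (following the logic below), reducing a 16x2 source along
/// dim 0 with a 2-element accumulator extracts 2 column slices of 16 elements
/// each, reduces every slice with a vector.reduction against the matching
/// accumulator element, and inserts the scalar results into a 2-element
/// result vector.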
1043 static Value lowerToVectorReductions(TypedValue<VectorType> src,
1044  TypedValue<VectorType> acc,
1045  vector::CombiningKind kind,
1046  int64_t reductionDim, Location loc,
1047  PatternRewriter &rewriter) {
1048  // Expecting a 2D source vector.
1049  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
1050  VectorType sourceType = src.getType();
1051  int64_t sourceH = sourceType.getShape()[0];
1052  int64_t sourceW = sourceType.getShape()[1];
1053  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
1054  // Create a constant vector to hold the result of the reduction.
1055  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
1056  Value reductionResult = arith::ConstantOp::create(
1057  rewriter, loc, acc.getType(),
1058  DenseElementsAttr::get(acc.getType(), zeroAttr));
1059  // For each slice of the source, extract the slice vector, do a reduction
1060  // and insert the reduced value back into the result vector.
1061  for (int i = 0; i < nSlices; ++i) {
1062  SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
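 // When reducing along dim 1, each slice is a full row (1 x sourceW);
 // when reducing along dim 0, each slice is a full column (sourceH x 1).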
1063  if (reductionDim == 1) {
1064  sliceOffsets = {i, 0};
1065  sliceSizes = {1, sourceW};
1066  } else {
1067  sliceOffsets = {0, i};
1068  sliceSizes = {sourceH, 1};
1069  }
1070  vector::ExtractStridedSliceOp extractOp =
1071  vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
1072  sliceSizes, {1, 1});
1073  int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1074  Value slice = vector::ShapeCastOp::create(
1075  rewriter, loc,
1076  VectorType::get({nSliceElements}, sourceType.getElementType()),
1077  extractOp.getResult());
1078  Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1079  Value reduction =
1080  vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
1081  reductionResult =
1082  vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
1083  }
1084  return reductionResult;
1085 }
1086 
1087 /// This pattern distributes the `vector.multi_reduction` operation across
1088 /// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1089 /// layouts for the source and accumulator vectors,
1090 /// * If the reduction dimension is distributed across lanes, the reduction is
1091 /// non-lane-local and is done using warp shuffles. Here we
1092 /// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1093 /// the warp op body.
1094 /// * If the reduction dimension is not distributed across lanes, the reduction
1095 /// is lane-local. In this case, we yield the source and accumulator vectors
1096 /// from the warp op and perform the lane-local reduction outside the warp op
1097 /// using a sequence of ReductionOps.
1098 /// Example 1 (Reduction is lane-local):
1099 /// ```
1100 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1101 /// %0 = "some_def"() : () -> (vector<16x32xf32>)
1102 /// %acc = "some_def"() : () -> (vector<32xf32>)
1103 /// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1104 /// vector<32xf32> gpu.yield %1 : vector<32xf32>
1105 /// }
1106 /// ```
1107 /// is lowered to:
1108 /// ```
1109 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1110 /// vector<1xf32>) {
1111 /// %0 = "some_def"() : () -> (vector<16x32xf32>)
1112 /// %acc = "some_def"() : () -> (vector<32xf32>)
1113 /// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1114 /// }
1115 /// %c = arith.constant dense<0.0> : vector<1xf32>
1116 /// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1117 /// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1118 /// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1119 /// ```
1120 /// Example 2 (Reduction is non-lane-local):
1121 /// ```
1122 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1123 /// %0 = "some_def"() : () -> (vector<2x32xf32>)
1124 /// %acc = "some_def"() : () -> (vector<2xf32>)
1125 /// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1126 /// vector<2xf32>
1127 /// gpu.yield %1 : vector<2xf32>
1128 /// }
1129 /// ```
1130 /// is lowered to:
1131 /// ```
1132 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1133 /// %0 = "some_def"() : () -> (vector<2x32xf32>)
1134 /// %acc = "some_def"() : () -> (vector<2xf32>)
1135 /// %1 = arith.constant dense<0.0> : vector<2xf32>
1136 /// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1137 /// %3 = ("warp.reduction %2") : f32
1138 /// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1139 /// ... repeat for row 1
1140 /// gpu.yield %1 : vector<2xf32>
1141 /// }
/// ```
1142 struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1143  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1144  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1145  PatternRewriter &rewriter) const override {
1146  OpOperand *yieldOperand =
1147  getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1148  if (!yieldOperand)
1149  return failure();
1150  auto reductionOp =
1151  cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1152  unsigned operandNumber = yieldOperand->getOperandNumber();
1153  VectorType sourceType = reductionOp.getSourceVectorType();
1154  // Only 2D vectors are supported.
1155  if (sourceType.getRank() != 2)
1156  return rewriter.notifyMatchFailure(warpOp,
1157  "Only 2D reductions are supported.");
1158  ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1159  // Only 1 reduction dimension supported. This also ensures that the result
1160  // is vector type.
1161  if (reductionDims.size() != 1)
1162  return rewriter.notifyMatchFailure(
1163  warpOp, "Only 1 reduction dimension is supported.");
1164  int64_t reductionDim = reductionDims[0];
1165  VectorType distributedResultType =
1166  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1167  VectorType resultType = cast<VectorType>(reductionOp.getType());
1168  xegpu::DistributeLayoutAttr sourceLayout =
1169  xegpu::getDistributeLayoutAttr(reductionOp.getSource());
1170 
1171  FailureOr<VectorType> sourceDistTypeOrFailure =
1172  getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1173  if (failed(sourceDistTypeOrFailure))
1174  return rewriter.notifyMatchFailure(
1175  warpOp, "Failed to distribute the source vector type.");
1176  VectorType sourceDistType = sourceDistTypeOrFailure.value();
1177  // Only single dimension distribution is supported.
1178  bool dim0Distributed =
1179  sourceDistType.getShape()[0] != sourceType.getShape()[0];
1180  bool dim1Distributed =
1181  sourceDistType.getShape()[1] != sourceType.getShape()[1];
1182  if (dim0Distributed && dim1Distributed)
1183  return rewriter.notifyMatchFailure(
1184  warpOp, "Expecting source to be distributed in a single dimension.");
1185  int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1186  if (sourceDistDim == -1)
1187  return rewriter.notifyMatchFailure(
1188  warpOp, "Expecting a distributed source vector.");
1189  bool resultDistributed =
1190  distributedResultType.getNumElements() < resultType.getNumElements();
1191  // If the lane owns all the data required for reduction (i.e. reduction is
1192  // fully parallel across lanes), then each lane owns part of the result
1193  // (i.e. result is distributed). If the reduction requires cross-lane
1194  // shuffling, then the result is shared among all lanes (broadcasted).
1195  // Therefore we expect following cases:
1196  //
1197  // | Source vector | Reduction dim | Result vector |
1198  // |----------------------|----------------|----------------|
1199  // | dim-0 distributed | 0 | broadcasted |
1200  // | dim-0 distributed | 1 | distributed |
1201  // | dim-1 distributed | 0 | distributed |
1202  // | dim-1 distributed | 1 | broadcasted |
1203 
1204  bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1205  (sourceDistDim == 1 && reductionDim == 0);
1206  if (isReductionLaneLocal && !resultDistributed)
1207  return rewriter.notifyMatchFailure(
1208  warpOp, "Expecting a distributed result for lane-local reduction.");
1209 
1210  if (!isReductionLaneLocal && resultDistributed)
1211  return rewriter.notifyMatchFailure(
1212  warpOp,
1213  "Expecting a broadcasted result for non-lane-local reduction.");
1214 
1215  // Handle lane-local reduction case. In this case we fully distribute the
1216  // reduction result.
1217  if (isReductionLaneLocal) {
1218  // Yield the source and acc vectors from the WarpOp.
1219  SmallVector<size_t> newRetIndices;
1220  auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1221  rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1222  {sourceDistType, distributedResultType}, newRetIndices);
1223  rewriter.setInsertionPointAfter(newWarpOp);
1224  Value result = lowerToVectorReductions(
1225  cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1226  cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1227  reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1228  // Replace the warp op result with the final result.
1229  rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1230  return success();
1231  }
1232  // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1233  // of multiple ReductionOps. Actual distribution is done by the
1234  // WarpOpReduction pattern.
1235  rewriter.setInsertionPointAfter(reductionOp);
1236  Value result = lowerToVectorReductions(
1237  cast<TypedValue<VectorType>>(reductionOp.getSource()),
1238  cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1239  reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1240  // Replace the warp op result with the final result.
1241  rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1242  return success();
1243  }
1244 };
1245 
1246 /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1247 /// `gpu.warp_execute_on_lane_0` region.
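/// Example (an illustrative sketch; layout attributes are omitted and the
/// shapes assume a subgroup size of 16 with the source layout being a valid
/// slice of the result layout):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf32>) {
/// %0 = "some_def"() : () -> (vector<256xf32>)
/// %1 = vector.shape_cast %0 : vector<256xf32> to vector<16x16xf32>
/// gpu.yield %1 : vector<16x16xf32>
/// }
/// ```
/// is lowered to:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16xf32>) {
/// %0 = "some_def"() : () -> (vector<256xf32>)
/// gpu.yield %0 : vector<256xf32>
/// }
/// %1 = vector.shape_cast %r : vector<16xf32> to vector<16x1xf32>
/// ```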
1248 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1249  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1250  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1251  PatternRewriter &rewriter) const override {
1252  OpOperand *yieldOperand =
1253  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1254  if (!yieldOperand)
1255  return failure();
1256  auto shapeCastOp =
1257  cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1258  unsigned operandNumber = yieldOperand->getOperandNumber();
1259  auto resultDistTy =
1260  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1261  xegpu::DistributeLayoutAttr sourceLayout =
1262  xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
1263  xegpu::DistributeLayoutAttr resultLayout =
1264  xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
1265  if (!sourceLayout || !resultLayout)
1266  return rewriter.notifyMatchFailure(
1267  warpOp,
1268  "the source or result of shape_cast op lacks distribution layout");
1269 
1270  // For rank reducing or increasing shape_cast ops, the lower rank layout
1271  // must be a slice of higher rank layout.
1272  int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
1273  int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
1274  if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
1275  return rewriter.notifyMatchFailure(
1276  warpOp, "shape_cast is rank reducing but source layout is not a "
1277  "slice of result layout");
1278  if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
1279  return rewriter.notifyMatchFailure(
1280  warpOp, "shape_cast is rank increasing but result layout is not a "
1281  "slice of source layout");
1282 
1283  FailureOr<VectorType> sourceDistTypeOrFailure =
1284  getDistVecTypeBasedOnLaneLayout(sourceLayout,
1285  shapeCastOp.getSourceVectorType());
1286  if (failed(sourceDistTypeOrFailure))
1287  return rewriter.notifyMatchFailure(
1288  warpOp, "failed to get distributed vector type for source");
1289  VectorType sourceDistType = sourceDistTypeOrFailure.value();
1290  // Create a new warp op that yields the source of the shape_cast op.
1291  SmallVector<size_t> newRetIndices;
1292  auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1293  rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1294  newRetIndices);
1295  rewriter.setInsertionPointAfter(newWarpOp);
1296  Value source = newWarpOp.getResult(newRetIndices[0]);
1297  // Create a new shape_cast op outside the warp op.
1298  Value newShapeCast = vector::ShapeCastOp::create(
1299  rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1300  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1301  newShapeCast);
1302  return success();
1303  }
1304 };
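// Editorial illustration (not part of the original source): a schematic
// rank-increasing shape_cast, assuming a subgroup size of 16. The 1-D source
// carries a slice of the 2-D result layout (the rank-1 projection of
// lane_layout [1, 16]), so the source distributes to vector<1xf32> and the
// result to vector<1x1xf32>; the shape_cast is recreated outside the warp op
// on the distributed types:
//
//   %cast = vector.shape_cast %src : vector<16xf32> to vector<1x16xf32>
//   // after distribution (outside gpu.warp_execute_on_lane_0):
//   %cast_dist = vector.shape_cast %src_dist : vector<1xf32> to vector<1x1xf32>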
1305 
1306 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1307 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1308 /// outside of the warp op.
1309 struct MemrefExtractAlignedPointerAsIndexDistribution final
1310  : public gpu::WarpDistributionPattern {
1311  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1312  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1313  PatternRewriter &rewriter) const override {
1314  OpOperand *operand = getWarpResult(
1315  warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1316  if (!operand)
1317  return rewriter.notifyMatchFailure(
1318  warpOp,
1319  "warp result is not a memref::MemrefExtractAlignedPointerAsIndex op");
1320  auto extractOp =
1321  operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1322  unsigned operandIdx = operand->getOperandNumber();
1323  SmallVector<size_t> newRetIndices;
1324  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1325  rewriter, warpOp, extractOp.getSource(),
1326  TypeRange{extractOp.getSource().getType()}, newRetIndices);
1327  rewriter.setInsertionPointAfter(newWarpOp);
1328  auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1329  rewriter, newWarpOp.getLoc(), extractOp.getType(),
1330  newWarpOp.getResult(newRetIndices[0]));
1331  Value distributedVal = newWarpOp.getResult(operandIdx);
1332  rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
1333  return success();
1334  }
1335 };
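// Editorial illustration (not part of the original source): the extracted
// pointer is uniform across lanes, so the op is simply recreated after the
// warp op on the yielded source memref, e.g.
//
//   %ptr = memref.extract_aligned_pointer_as_index %src
//       : memref<256x256xf16> -> index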
1336 
1337 /// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1338 /// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1339 /// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1340 /// created outside of the warp op with distributed source vector type (computed
1341 /// using assigned layout).
1342 struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1343  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1344  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1345  PatternRewriter &rewriter) const override {
1346  OpOperand *operand =
1347  getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1348  if (!operand)
1349  return rewriter.notifyMatchFailure(
1350  warpOp, "warp result is not a vector::BitCast op");
1351  auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1352  unsigned operandIdx = operand->getOperandNumber();
1353  VectorType distributedSourceType =
1354  getDistVecTypeBasedOnLaneLayout(
1355  xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
1356  bitcastOp.getSourceVectorType())
1357  .value_or(VectorType());
1358  if (!distributedSourceType)
1359  return rewriter.notifyMatchFailure(
1360  bitcastOp, "Failed to distribute the source vector type in "
1361  "vector::BitCast op");
1362  VectorType distributedResultType =
1363  cast<VectorType>(warpOp.getResult(operandIdx).getType());
1364  SmallVector<size_t> newRetIndices;
1365  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1366  rewriter, warpOp, bitcastOp.getSource(),
1367  TypeRange{distributedSourceType}, newRetIndices);
1368  rewriter.setInsertionPointAfter(newWarpOp);
1369  auto newBitcastOp = vector::BitCastOp::create(
1370  rewriter, newWarpOp.getLoc(), distributedResultType,
1371  newWarpOp.getResult(newRetIndices[0]));
1372  Value distributedVal = newWarpOp.getResult(operandIdx);
1373  rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1374  return success();
1375  }
1376 };
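// Editorial illustration (not part of the original source): a schematic
// bitcast example assuming a subgroup size of 16 and lane_layout [1, 16] on
// both source and result. Each lane then bitcasts its own innermost slice
// (2 x i8 -> 1 x i16):
//
//   %cast = vector.bitcast %src : vector<4x32xi8> to vector<4x16xi16>
//   // after distribution (outside gpu.warp_execute_on_lane_0):
//   %cast_dist = vector.bitcast %src_dist : vector<4x2xi8> to vector<4x1xi16>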
1377 
1378 /// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1379 /// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1380 /// supported. In most cases, the transpose is a no-op because it is entirely
1381 /// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1382 /// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1383 /// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1384 /// vector::TransposeOp outside of the warp op with distributed source vector
1385 /// type (computed using assigned layout).
1386 struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1387  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1388  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1389  PatternRewriter &rewriter) const override {
1390  OpOperand *operand =
1391  getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1392  if (!operand)
1393  return rewriter.notifyMatchFailure(
1394  warpOp, "warp result is not a vector::Transpose op");
1395  auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1396  unsigned operandIdx = operand->getOperandNumber();
1397  xegpu::DistributeLayoutAttr sourceLayout =
1398  xegpu::getDistributeLayoutAttr(transposeOp.getVector());
1399  xegpu::DistributeLayoutAttr resultLayout =
1400  xegpu::getDistributeLayoutAttr(transposeOp.getResult());
1401  if (!sourceLayout || !resultLayout)
1402  return rewriter.notifyMatchFailure(
1403  transposeOp,
1404  "the source or result vector of the transpose op lacks layout "
1405  "attribute");
1406  int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1407  int64_t resultRank = transposeOp.getResultVectorType().getRank();
1408  // Only 2D transposes are supported for now.
1409  // TODO: Support nD transposes.
1410  if (sourceRank != 2 || resultRank != 2)
1411  return rewriter.notifyMatchFailure(
1412  transposeOp, "the source or result vector of the transpose op "
1413  "does not have 2D layout");
1414  ArrayRef<int64_t> perm = transposeOp.getPermutation();
1415  // Result layout must be a transpose of source layout.
1416  if (!resultLayout.isTransposeOf(sourceLayout, perm))
1417  return rewriter.notifyMatchFailure(
1418  transposeOp,
1419  "the source or result vector layouts must be 2D transposes of each "
1420  "other");
1421  FailureOr<VectorType> distributedSourceTypeOrFailure =
1422  getDistVecTypeBasedOnLaneLayout(sourceLayout,
1423  transposeOp.getSourceVectorType());
1424  if (failed(distributedSourceTypeOrFailure))
1425  return rewriter.notifyMatchFailure(
1426  transposeOp, "Failed to distribute the source vector type in "
1427  "vector::Transpose op");
1428  SmallVector<size_t> newRetIndices;
1429  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1430  rewriter, warpOp, transposeOp.getVector(),
1431  TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1432  rewriter.setInsertionPointAfter(newWarpOp);
1433  auto newTransposeOp = vector::TransposeOp::create(
1434  rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1435  perm);
1436  Value distributedVal = newWarpOp.getResult(operandIdx);
1437  rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
1438  return success();
1439  }
1440 };
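// Editorial illustration (not part of the original source): a schematic
// transpose example assuming a subgroup size of 16, a [1, 16] source layout
// and a [16, 1] result layout. Each lane owns a 16x2 slice of the source and
// must produce a 2x16 slice of the result, so a lane-local transpose is
// created outside the warp op:
//
//   %t = vector.transpose %src, [1, 0] : vector<16x32xf32> to vector<32x16xf32>
//   // after distribution (outside gpu.warp_execute_on_lane_0):
//   %t_dist = vector.transpose %src_dist, [1, 0]
//       : vector<16x2xf32> to vector<2x16xf32>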
1441 
1442 } // namespace
1443 
1444 namespace {
1445 struct XeGPUSubgroupDistributePass final
1446  : public xegpu::impl::XeGPUSubgroupDistributeBase<
1447  XeGPUSubgroupDistributePass> {
1448  XeGPUSubgroupDistributePass() = default;
1449  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
1450  default;
1451  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
1452  : XeGPUSubgroupDistributeBase(options) {}
1453  void runOnOperation() override;
1454 };
1455 } // namespace
1456 
1457 void xegpu::populateXeGPUSubgroupDistributePatterns(
1458     RewritePatternSet &patterns) {
1459  patterns
1460  .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
1461  DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
1462  GpuBarrierDistribution, VectorMultiReductionDistribution,
1463  LoadDistribution, StoreDistribution, VectorTransposeDistribution,
1464  VectorBitcastDistribution,
1465  MemrefExtractAlignedPointerAsIndexDistribution>(
1466  patterns.getContext(),
1467  /*pattern benefit=*/regularPatternBenefit);
1468  patterns.add<VectorShapeCastDistribution>(
1469  patterns.getContext(),
1470  /*pattern benefit=*/highPatternBenefit);
1471 }
1472 
1473 void XeGPUSubgroupDistributePass::runOnOperation() {
1474  // Step 1: Attach layouts to op operands.
1475  // TODO: The following assumptions are made:
1476  // 1) It is assumed that there are no layout conflicts.
1477  // 2) Any existing layout attributes attached to the operands are ignored.
1478  Operation *op = getOperation();
1479  op->walk([&](Operation *op) {
1480  for (OpOperand &operand : op->getOpOperands()) {
1481  // Layouts are needed for vector type only.
1482  if (!isa<VectorType>(operand.get().getType()))
1483  continue;
1484 
1485  auto layout = xegpu::getDistributeLayoutAttr(operand.get());
1486  if (!layout) {
1487  op->emitError("Could not find layout attribute for operand ")
1488  << operand.getOperandNumber() << " of operation " << op->getName();
1489  signalPassFailure();
1490  return;
1491  }
1492  xegpu::setDistributeLayoutAttr(operand, layout);
1493  }
1494  });
1495  // Step 2: Move all operations of a GPU function inside
1496  // gpu.warp_execute_on_lane_0 operation.
1497  {
1498  RewritePatternSet patterns(&getContext());
1499  patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
1500 
1501  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1502  signalPassFailure();
1503  return;
1504  }
1505  // At this point, we have moved the entire function body inside the
1506  // warpOp. Now move any scalar uniform code outside of the warpOp (like
1507  // GPU index ops, scalar constants, etc.). This will simplify the
1508  // later lowering and avoid custom patterns for these ops.
1509  getOperation()->walk([&](Operation *op) {
1510  if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
1511  vector::moveScalarUniformCode(warpOp);
1512  });
1513  }
1514  // Step 3: Apply subgroup to workitem distribution patterns.
1515  RewritePatternSet patterns(&getContext());
1516  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
1517  // distributionFn is used by vector distribution patterns to determine the
1518  // distributed vector type for a given vector value. In XeGPU subgroup
1519  // distribution context, we compute this based on lane layout.
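// Editorial note (not part of the original source): for example, a 2-D vector
// with lane_layout [1, 16] has only dim-1 distributed, so distributionFn
// returns affine_map<(d0, d1) -> (d1)>; with no layout attached, the innermost
// dimension is assumed to be distributed, giving the same map for rank 2.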
1520  auto distributionFn = [](Value val) {
1521  VectorType vecType = dyn_cast<VectorType>(val.getType());
1522  int64_t vecRank = vecType ? vecType.getRank() : 0;
1523  if (vecRank == 0)
1524  return AffineMap::get(val.getContext());
1525  // Get the layout of the vector type.
1526  xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
1527  // If no layout is specified, assume the innermost dimension is
1528  // distributed for now.
1529  if (!layout)
1530  return AffineMap::getMultiDimMapWithTargets(
1531  vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
1532  SmallVector<unsigned int> distributedDims;
1533  for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
1534  if (v > 1)
1535  distributedDims.push_back(i);
1536  }
1537  return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
1538  val.getContext());
1539  };
1540  // TODO: shuffleFn is not used.
1541  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
1542  int64_t warpSz) { return Value(); };
1543 
1544  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
1545  vector::CombiningKind kind, uint32_t size) {
1546  // First reduce within each lane to get the per-lane reduction value.
1547  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
1548  // Parallel reduction using butterfly shuffles.
1549  for (uint64_t i = 1; i < size; i <<= 1) {
1550  Value shuffled =
1551  builder
1552  .create<gpu::ShuffleOp>(loc, laneVal, i,
1553  /*width=*/size,
1554  /*mode=*/gpu::ShuffleMode::XOR)
1555  .getShuffleResult();
1556  laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
1557  }
1558  return laneVal;
1559  };
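// Editorial note (not part of the original source): for a subgroup size of 16,
// the butterfly reduction above performs log2(16) = 4 shuffle steps with XOR
// offsets 1, 2, 4 and 8; after the last step every lane holds the full
// reduction value.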
1560 
1561  if (enableSGReductions)
1562  vector::populateDistributeReduction(
1563  patterns, warpReduction,
1564  /*pattern benefit=*/regularPatternBenefit);
1565 
1566  vector::populatePropagateWarpVectorDistributionPatterns(
1567  patterns, distributionFn, shuffleFn,
1568  /*pattern benefit=*/regularPatternBenefit);
1569  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1570  signalPassFailure();
1571  return;
1572  }
1573 
1574  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
1575  // due to tensor desc type mismatches created by using upstream distribution
1576  // patterns (scf.for). This cleanup should only be done if all the ops were
1577  // distributed successfully; if some ops are still not distributed and remain
1578  // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
1579  // breaking the IR.
1580  bool foundWarpOp = false;
1581  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
1582  // Look for WarpOps that are not trivially dead.
1583  if (isOpTriviallyDead(warpOp))
1584  return WalkResult::advance();
1585  foundWarpOp = true;
1586  return WalkResult::interrupt();
1587  });
1588  if (foundWarpOp)
1589  return;
1590 
1591  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
1592  // We are only interested in UnrealizedConversionCastOps that were added
1593  // to resolve SIMT type mismatches.
1594  if (!op->getAttr(resolveSIMTTypeMismatch))
1595  return WalkResult::skip();
1596 
1597  Value input = op.getOperand(0);
1598  Value output = op.getResult(0);
1599 
1600  // Both input and output must have tensor descriptor types.
1601  xegpu::TensorDescType inputDescType =
1602  mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
1603  xegpu::TensorDescType outputDescType =
1604  mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
1605  assert(inputDescType && outputDescType &&
1606  "Unrealized conversion cast must have tensor descriptor types");
1607 
1608  // tensor_desc<shape, layout> -> tensor_desc<shape> type of conversion.
1609  // This occurs inside the scf.for body to resolve the block argument type to
1610  // the SIMT type.
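// Editorial illustration (not part of the original source), with a
// hypothetical #layout attribute alias:
//
//   %desc = builtin.unrealized_conversion_cast %arg1
//       : !xegpu.tensor_desc<16x16xf16, #layout> to !xegpu.tensor_desc<16x16xf16>
//       {resolve_simt_type_mismatch}
//
// Here the scf.for block argument %arg1 is retyped to the layout-less SIMT
// type and the cast result is replaced with it.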
1611  if (inputDescType.getLayout()) {
1612  auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
1613  if (argument) {
1614  argument.setType(output.getType());
1615  output.replaceAllUsesWith(argument);
1616  if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
1617  argument.getOwner()->getParentOp())) {
1618  auto result = loopOp.getTiedLoopResult(argument);
1619  result.setType(output.getType());
1620  }
1621  }
1622  }
1623 
1624  // tensor_desc<shape> -> tensor_desc<shape, layout> type of
1625  // conversion. This occurs at the yield op of the scf.for body to go back
1626  // from the SIMT type to the original type.
1627  if (outputDescType.getLayout())
1628  output.replaceAllUsesWith(input);
1629 
1630  if (op->use_empty())
1631  op->erase();
1632  return WalkResult::advance();
1633  });
1634 }