XeGPUSubgroupDistribute.cpp
1 //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeRange.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/IR/Visitors.h"
30 #include "mlir/Support/LLVM.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 
38 namespace mlir {
39 namespace xegpu {
40 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
41 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
42 } // namespace xegpu
43 } // namespace mlir
44 
45 #define DEBUG_TYPE "xegpu-subgroup-distribute"
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
47 
48 using namespace mlir;
49 
50 static const char *const resolveSIMTTypeMismatch =
51  "resolve_simt_type_mismatch"; // Attribute name for identifying
52  // UnrealizedConversionCastOp added to resolve
53  // SIMT type mismatches.
54 
55 namespace {
56 
57 //===----------------------------------------------------------------------===//
58 // SIMT Distribution Patterns
59 //===----------------------------------------------------------------------===//
60 
61 /// In certain cases, we may need to favor XeGPU specific distribution patterns
62 /// over generic vector distribution patterns. In such cases, we can assign
63 /// priorities to patterns.
64 static constexpr unsigned regularPatternBenefit = 1;
65 static constexpr unsigned highPatternBenefit = 2;
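// Illustrative sketch (not the pass's actual registration code): when the
// rewrite set is populated, XeGPU-specific patterns would be registered with
// the higher benefit so they are tried before generic vector distribution
// patterns that match the same warp op, e.g.:
//   RewritePatternSet patterns(ctx);
//   patterns.add<LoadNdDistribution, StoreNdDistribution>(
//       ctx, /*benefit=*/highPatternBenefit);
//   // ... generic patterns added with regularPatternBenefit ...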
66 
67 /// Helper function to get distributed vector type for a source vector type
68 /// according to the lane_layout. We simply divide each dimension of the
69 /// source vector shape by the corresponding lane_layout dimension. If
70 /// array_length > 1, that is appended to the front of the distributed shape.
71 /// NOTE: This is the vector type that will be returned by the
72 /// gpu.warp_execute_on_lane0 op.
73 ///
74 /// Examples:
75 /// | original vector shape | lane_layout | distributed vector shape |
76 /// |-----------------------|-------------|--------------------------|
77 /// | 32x16 | [1, 16] | 32x1 |
78 /// | 32x16 | [2, 8] | 16x2 |
79 /// | 2x32x16 | [1, 16] | 2x32x1 |
80 static FailureOr<VectorType>
81 getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
82  VectorType originalType) {
83  if (!layout)
84  return failure();
85  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
86  "Expecting a valid layout.");
87  SmallVector<int64_t> effectiveLaneLayout =
88  layout.getEffectiveLaneLayoutAsInt();
89  assert(static_cast<size_t>(originalType.getRank()) >=
90  effectiveLaneLayout.size() &&
91  "Rank of the original vector type should be greater or equal to the "
92  "size of the lane layout to distribute the vector type.");
93  SmallVector<int64_t> distributedShape(originalType.getShape());
94  // Only distribute the last `laneLayout.size()` dimensions. The remaining
95  // dimensions are not distributed.
96  unsigned distributionStart =
97  originalType.getRank() - effectiveLaneLayout.size();
98  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
99  if (i < distributionStart)
100  continue;
101 
102  // Check if the dimension can be distributed evenly.
103  if (dim % effectiveLaneLayout[i - distributionStart] != 0)
104  return failure();
105  distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
106  }
107  return VectorType::get(distributedShape, originalType.getElementType());
108 }
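// For example, vector<32x16xf32> with lane_layout [1, 16] distributes to
// vector<32x1xf32>, while vector<32x10xf32> with the same layout fails the
// divisibility check above (10 % 16 != 0) and the helper returns failure().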
109 
110 /// Helper function to resolve types if the distributed type out of
111 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
112 /// Example 1:
113 /// distributed type: vector<8x1xf32>
114 /// expected type: vector<8xf32>
115 /// resolved using,
116 /// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
117 /// Example 2:
118 /// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
119 /// expected type: xegpu.tensor_desc<8x16xf32>
120 /// resolved using,
121 /// %0 = unrealized_conversion_cast %1 :
122 /// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
123 /// xegpu.tensor_desc<8x16xf32>
124 template <typename T>
125 static Value resolveDistributedTy(Value orig, T expected,
126  PatternRewriter &rewriter) {
127  // If orig and expected types are the same, return orig.
128  if (orig.getType() == expected)
129  return orig;
130  // If orig is a vector type, create a shape cast op to reconcile the types.
131  if (isa<VectorType>(orig.getType())) {
132  auto castOp =
133  vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
134  return castOp.getResult();
135  }
136  // If orig is a tensor descriptor type, create an unrealized conversion cast
137  // op to reconcile the types.
138  if (isa<xegpu::TensorDescType>(orig.getType())) {
139  auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
140  expected, orig);
141  castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
142  return castOp.getResult(0);
143  }
144  llvm_unreachable("Unsupported type for reconciliation");
145  return orig;
146 }
147 
148 /// Helper function to check if the layout is packed. Layout is packed if it is
149 /// 2D and lane_data[0] != 1 (data packed from col dimension).
150 /// TODO: Move to target info.
151 static bool requirePacked(const xegpu::LayoutAttr layout) {
152  if (!layout)
153  return false;
154  auto laneData = layout.getEffectiveLaneDataAsInt();
155  if (laneData.size() != 2)
156  return false;
157  return laneData[0] != 1;
158 }
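// For example, a layout with lane_data = [2, 1] (such as the B-operand layout
// in the DpasDistribution example below) is reported as packed, whereas
// lane_data = [1, 1] or a 1D lane_data is not.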
159 
160 /// Helper function to check if the layout requires a transpose effect.
161 static bool requireTranspose(const xegpu::LayoutAttr layout,
162  const xegpu::uArch::uArch *uArch) {
163  // Return false for unsupported targets.
164  // TODO: Add more support or move to target info.
165  if (!uArch->getName().equals_insensitive("pvc") &&
166  !uArch->getName().equals_insensitive("bmg"))
167  return false;
168  if (!layout)
169  return false;
170  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
171  if (laneLayout.size() != 2)
172  return false;
173  return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
174 }
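// For example, on a pvc-like target with subgroup size 16, a lane_layout of
// [16, 1] distributes the tile along dim-0 across lanes, so the distributed
// xegpu.load_nd must be issued with the transpose attribute (see
// LoadNdDistribution below).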
175 
176 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
177 /// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
178 /// contained within a WarpExecuteOnLane0Op.
179 /// Example:
180 ///
181 /// ```
182 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
183 /// ...
184 /// ...
185 /// gpu.return %result: vector<8x16xf32>
186 /// }
187 /// ```
188 /// To
189 /// ```
190 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
191 /// %laneid = gpu.lane_id : index
192 /// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
193 /// ...
194 /// ...
195 /// gpu.yield %result: vector<8x16xf32>
196 /// }
197 /// return %0
198 /// }
199 struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
200  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
201  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
202  PatternRewriter &rewriter) const override {
203  auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
204  if (!uArch)
205  return rewriter.notifyMatchFailure(
206  gpuFuncOp, "Subgroup distribution requires a target attribute attached "
207  "to set the warp size");
208  // If the function only contains a single void return, skip.
209  if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
210  return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
211  }))
212  return failure();
213  // If the function already moved inside a warp_execute_on_lane0, skip.
214  if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
215  return isa<gpu::WarpExecuteOnLane0Op>(op);
216  }))
217  return failure();
218  // Create a new function with the same signature and same attributes.
219  SmallVector<Type> workgroupAttributionsTypes =
220  llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
221  [](BlockArgument arg) { return arg.getType(); });
222  SmallVector<Type> privateAttributionsTypes =
223  llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
224  [](BlockArgument arg) { return arg.getType(); });
225  auto newGpuFunc = gpu::GPUFuncOp::create(
226  rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
227  gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
228  privateAttributionsTypes);
229  newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
230  // Create a WarpExecuteOnLane0Op with same arguments and results as the
231  // original gpuFuncOp.
232  rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
233  auto laneId = gpu::LaneIdOp::create(
234  rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
235  /** upperBound = **/ mlir::IntegerAttr());
236  ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
237  auto warpOp = gpu::WarpExecuteOnLane0Op::create(
238  rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
239  uArch->getSubgroupSize(), newGpuFunc.getArguments(),
240  newGpuFunc.getArgumentTypes());
241  Block &warpBodyBlock = warpOp.getBodyRegion().front();
242  // Replace the ReturnOp of the original gpu function with a YieldOp.
243  auto origReturnOp =
244  cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
245  rewriter.setInsertionPointAfter(origReturnOp);
246  gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
247  origReturnOp.getOperands());
248  rewriter.eraseOp(origReturnOp);
249  // Move the original function body to the WarpExecuteOnLane0Op body.
250  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
251  warpOp.getBodyRegion().begin());
252  rewriter.eraseBlock(&warpBodyBlock);
253  // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
254  rewriter.setInsertionPointAfter(warpOp);
255  gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
256  rewriter.replaceOp(gpuFuncOp, newGpuFunc);
257  return success();
258  }
259 };
260 
261 /// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
262 /// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
263 /// still contain the original op that will not be used by the yield op (and
264 /// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
265 /// arguments. Tensor descriptor shape is not distributed because it is a
266 /// uniform value across all work items within the subgroup. However, the
267 /// layout information is dropped in the new tensor descriptor type.
268 ///
269 /// Example:
270 ///
271 /// ```
272 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
273 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
274 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
275 /// ...
276 /// %td = xegpu.create_nd_tdesc %arg0
277 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
278 /// vector.yield %td
279 /// }
280 /// ```
281 /// To
282 /// ```
283 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
284 /// ...
285 /// %dead = xegpu.create_nd_tdesc %arg0
286 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
287 /// vector.yield %arg0, %dead
288 /// }
289 /// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
290 /// -> !xegpu.tensor_desc<4x8xf32>
291 ///
292 /// ```
293 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
294  using gpu::WarpDistributionPattern::WarpDistributionPattern;
295  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
296  PatternRewriter &rewriter) const override {
297  OpOperand *operand =
298  getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
299  if (!operand)
300  return rewriter.notifyMatchFailure(
301  warpOp, "warp result is not a xegpu::CreateNdDesc op");
302  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
303  unsigned operandIdx = operand->getOperandNumber();
304 
305  xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
306  if (!layout)
307  return rewriter.notifyMatchFailure(
308  descOp, "the tensor descriptor lacks layout attribute");
309  // CreateNdOp must not have offsets.
310  if (descOp.getMixedOffsets().size())
311  return rewriter.notifyMatchFailure(
312  descOp, "xegpu::CreateNdDescOp must not have offsets");
313 
314  SmallVector<size_t> newRetIndices;
315  rewriter.setInsertionPoint(warpOp);
316  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
317  rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
318  /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
319 
320  SmallVector<Value> newDescOperands = llvm::map_to_vector(
321  newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
322  rewriter.setInsertionPointAfter(newWarpOp);
323  xegpu::TensorDescType distributedTensorDescTy =
324  descOp.getType().dropLayouts(); // Distributed tensor descriptor type
325  // does not contain layout info.
326  Value newDescOp = xegpu::CreateNdDescOp::create(
327  rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
328  descOp->getAttrs());
329 
330  Value distributedVal = newWarpOp.getResult(operandIdx);
331  // Resolve the distributed type to the expected type.
332  newDescOp =
333  resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
334  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
335  return success();
336  }
337 };
338 
339 /// Distribute a store_nd op at the end of enclosing
340 /// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
341 /// through the warp op interface, they would be propagated as returned values.
342 /// Source vector is distributed based on lane layout. Appropriate cast ops are
343 /// inserted if the distributed types do not match the expected xegpu SIMT types.
344 ///
345 /// Example:
346 ///
347 /// ```
348 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
349 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
350 /// ...
351 /// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
352 /// !xegpu.tensor_desc<4x8xf32, #layout0>
353 /// }
354 /// ```
355 /// To
356 /// ```
357 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
358 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
359 /// ...
360 /// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
361 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
362 /// }
363 /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
364 /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
365 /// #layout0>
366 /// -> !xegpu.tensor_desc<4x8xf32>
367 /// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
368 /// !xegpu.tensor_desc<4x8xf32>
369 ///
370 /// ```
371 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
372  using gpu::WarpDistributionPattern::WarpDistributionPattern;
373  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
374  PatternRewriter &rewriter) const override {
375  gpu::YieldOp yield = warpOp.getTerminator();
376  Operation *lastNode = yield->getPrevNode();
377  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
378  if (!storeOp)
379  return failure();
380 
381  SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
382  // Expecting offsets to be present.
383  if (offsets.empty())
384  return rewriter.notifyMatchFailure(storeOp,
385  "the store op must have offsets");
386  SmallVector<Value> offsetsAsValues =
387  vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
388  SmallVector<Type> offsetTypes = llvm::to_vector(
389  llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
390  xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
391  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
392  if (!layout)
393  return rewriter.notifyMatchFailure(
394  storeOp, "the source tensor descriptor lacks layout attribute");
395 
396  FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
397  getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
398  if (failed(distributedTypeByWarpOpOrFailure))
399  return rewriter.notifyMatchFailure(storeOp,
400  "Failed to distribute the type");
401  VectorType distributedTypeByWarpOp =
402  distributedTypeByWarpOpOrFailure.value();
403 
404  SmallVector<size_t> newRetIndices;
405  SmallVector<Value> newYieldedValues = {storeOp.getValue(),
406  storeOp.getTensorDesc()};
407  SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
408  newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
409  newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
410  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
411  rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
412  // Create a new store op outside the warp op with the distributed vector
413  // type. Tensor descriptor is not distributed.
414  rewriter.setInsertionPointAfter(newWarpOp);
415  SmallVector<Value> newStoreOperands;
416 
417  // For the value operand, there can be a mismatch between the vector type
418  // distributed by the warp op and (xegpu-specific) distributed type
419  // supported by the store op. Type mismatch must be resolved using
420  // appropriate cast op.
421  FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
422  xegpu::getDistributedVectorType(storeOp.getTensorDescType());
423  if (failed(storeNdDistributedValueTyOrFailure))
424  return rewriter.notifyMatchFailure(
425  storeOp, "Failed to get distributed vector type for the store op");
426  newStoreOperands.push_back(resolveDistributedTy(
427  newWarpOp.getResult(newRetIndices[0]),
428  storeNdDistributedValueTyOrFailure.value(), rewriter));
429  // For the tensor descriptor operand, the layout attribute is dropped after
430  // distribution. Types need to be resolved in this case as well.
431  xegpu::TensorDescType distributedTensorDescTy =
432  storeOp.getTensorDescType().dropLayouts();
433  newStoreOperands.push_back(
434  resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
435  distributedTensorDescTy, rewriter));
436  // Collect offsets.
437  for (size_t i = 2; i < newRetIndices.size(); ++i)
438  newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
439 
440  auto newStoreOp =
441  xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
442  newStoreOperands, storeOp->getAttrs());
443  xegpu::removeLayoutAttrs(newStoreOp);
444  rewriter.eraseOp(storeOp);
445  return success();
446  }
447 };
448 
449 /// Distribute a load_nd op feeding into vector.yield op for the enclosing
450 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
451 /// The warp op will still contain the original op that will not be used by
452 /// the yield op (and should be cleaned up later). The yield op will
453 /// bypass the load's arguments. Only the loaded vector is distributed
454 /// according to the lane layout; the tensor descriptor type is not
455 /// distributed. Appropriate cast ops are inserted if the distributed types do
456 /// not match the expected xegpu SIMT types.
457 ///
458 /// Example:
459 ///
460 /// ```
461 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
462 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
463 /// (vector<4x1xf32>) {
464 /// ...
465 /// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
466 /// -> vector<4x8xf32>
468 /// gpu.yield %ld
469 /// }
470 /// ```
471 /// To
472 /// ```
473 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
474 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
475 /// ...
476 /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
477 /// vector<4x8xf32>
/// gpu.yield %dead, %arg0
478 /// }
479 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
480 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
481 /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
482 /// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
483 ///
484 /// ```
485 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
486  using gpu::WarpDistributionPattern::WarpDistributionPattern;
487  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
488  PatternRewriter &rewriter) const override {
489  OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
490  if (!isa<xegpu::LoadNdOp>(op))
491  return false;
492  // Make sure the same load op is the last operation in the warp op body.
493  // This ensures that the load op is not sunk earlier, violating any barrier
494  // synchronizations.
495  gpu::YieldOp yield = warpOp.getTerminator();
496  return yield->getPrevNode() == op;
497  });
498 
499  if (!operand)
500  return rewriter.notifyMatchFailure(
501  warpOp, "warp result is not a xegpu::LoadNd op");
502 
503  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
504  auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
505  if (!uArch)
506  return rewriter.notifyMatchFailure(
507  loadOp, "xegpu::LoadNdOp requires a target attribute attached to "
508  "determine the transpose requirement");
510  // Chip information is required to decide if the layout requires transpose
511  // effect.
512  // Expecting offsets to be present.
513  SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
514  if (offsets.empty())
515  return rewriter.notifyMatchFailure(loadOp,
516  "the load op must have offsets");
517  SmallVector<Value> offsetsAsValues =
518  vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
519  SmallVector<Type> offsetTypes = llvm::to_vector(
520  llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
521 
522  xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
523  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
524  if (!layout)
525  return rewriter.notifyMatchFailure(
526  loadOp, "the source tensor descriptor lacks layout attribute");
527 
528  unsigned operandIdx = operand->getOperandNumber();
529  VectorType distributedTypeByWarpOp =
530  cast<VectorType>(warpOp.getResult(operandIdx).getType());
531 
532  SmallVector<size_t> newRetIndices;
533  SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
534  SmallVector<Type> newYieldedTypes = {tensorDescTy};
535  newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
536  newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
537  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
538  rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
539 
540  // Create a new load op outside the warp op with the distributed vector
541  // type.
542  rewriter.setInsertionPointAfter(newWarpOp);
543  FailureOr<VectorType> loadNdDistValueTyOrFailure =
544  xegpu::getDistributedVectorType(loadOp.getTensorDescType());
545  if (failed(loadNdDistValueTyOrFailure))
546  return rewriter.notifyMatchFailure(
547  loadOp, "Failed to get distributed vector type for the load op");
548  xegpu::TensorDescType distributedTensorDescTy =
549  loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
550  // descriptor type does not
551  // contain layout info.
552  SmallVector<Value> newLoadOperands{
553  resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
554  distributedTensorDescTy, rewriter)};
555  // Collect offsets.
556  for (size_t i = 1; i < newRetIndices.size(); ++i)
557  newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
558  auto newLoadOp = xegpu::LoadNdOp::create(
559  rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
560  newLoadOperands, loadOp->getAttrs());
561  xegpu::removeLayoutAttrs(newLoadOp);
562  // Set the packed attribute if the layout requires it.
563  newLoadOp.setPacked(requirePacked(layout));
564  // Set the transpose attribute if the layout requires it.
565  if (requireTranspose(layout, uArch))
566  newLoadOp.setTranspose(
567  DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
568  Value distributedVal = newWarpOp.getResult(operandIdx);
569  // There can be a conflict between the vector type distributed by the
570  // warp op and (xegpu-specific) distributed type supported by the load
571  // op. Resolve these mismatches by inserting a cast.
572  Value tyResolvedVal = resolveDistributedTy(
573  newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
574  rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
575  return success();
576  }
577 };
578 
579 /// Distribute a dpas op feeding into vector.yield op for the enclosing
580 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
581 /// The warp op will still contain the original op that will not be used by
582 /// the yield op (and should be cleaned up later). The yield op will
583 /// bypass the dpas's arguments. Appropriate cast ops are inserted if the
584 /// distributed types do not match the expected xegpu SIMT types.
585 /// Example:
586 /// ```
587 /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
588 /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
589 /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
590 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
591 /// (vector<8x1xf32>) {
592 /// ...
593 /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
594 /// vector<8x16xf32>
595 /// gpu.yield %dpas
596 /// }
597 /// ```
598 /// To
599 /// ```
600 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
601 /// vector<8x1xf16>, vector<16x1xf16>) {
602 /// ...
603 /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
604 /// -> vector<8x16xf32>
605 /// gpu.yield %dead, %arg0, %arg1
606 /// }
607 /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
608 /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
609 /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
610 /// vector<8xf32>
611 /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
612 /// ```
613 struct DpasDistribution final : public gpu::WarpDistributionPattern {
614  using gpu::WarpDistributionPattern::WarpDistributionPattern;
615  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
616  PatternRewriter &rewriter) const override {
617  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
618  if (!operand)
619  return rewriter.notifyMatchFailure(warpOp,
620  "warp result is not a xegpu::Dpas op");
621 
622  auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
623  unsigned operandIdx = operand->getOperandNumber();
624  std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
625  std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
626  std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
627 
628  xegpu::LayoutAttr layoutA =
629  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
630  xegpu::LayoutAttr layoutB =
631  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
632  xegpu::LayoutAttr layoutOut =
633  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
634  if (!layoutA || !layoutB || !layoutOut)
635  return rewriter.notifyMatchFailure(
636  dpasOp,
637  "the xegpu::Dpas op lacks layout attribute for A, B or output");
638 
639  FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
640  getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
641  FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
642  getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
643  FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
644  getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
645  if (failed(distLhsTypeByWarpOpOrFailure) ||
646  failed(distRhsTypeByWarpOpOrFailure) ||
647  failed(distResultTypeByWarpOpOrFailure))
648  return rewriter.notifyMatchFailure(
649  dpasOp,
650  "Failed to distribute the A, B or output types in xegpu::Dpas op");
651 
652  llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
653  dpasOp.getRhs()};
654  llvm::SmallVector<Type, 3> newYieldTypes{
655  distLhsTypeByWarpOpOrFailure.value(),
656  distRhsTypeByWarpOpOrFailure.value()};
657  // Dpas acc operand is optional.
658  if (dpasOp.getAcc()) {
659  newYieldValues.push_back(dpasOp.getAcc());
660  newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
661  }
662  // Create a new warp op without the dpas.
663  SmallVector<size_t> newRetIndices;
664  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
665  rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
666 
667  FailureOr<VectorType> expectedDistLhsTyOrFailure =
668  xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
669  FailureOr<VectorType> expectedDistRhsTyOrFailure =
670  xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
671  FailureOr<VectorType> expectedDistResultTyOrFailure =
672  xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
673  if (failed(expectedDistLhsTyOrFailure) ||
674  failed(expectedDistRhsTyOrFailure) ||
675  failed(expectedDistResultTyOrFailure))
676  return rewriter.notifyMatchFailure(
677  dpasOp,
678  "Failed to get distributed vector type for the dpas operands.");
679  // Create a new dpas op outside the warp op.
680  rewriter.setInsertionPointAfter(newWarpOp);
681  SmallVector<Value> newDpasOperands;
682  SmallVector<VectorType> newDpasOperandExpectedTypes;
683 
684  // Resolve the distributed types with the original types.
685  newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
686  newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
687  VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
688  if (dpasOp.getAcc())
689  newDpasOperandExpectedTypes.push_back(distributedResultTy);
690 
691  for (unsigned i = 0; i < newRetIndices.size(); i++) {
692  newDpasOperands.push_back(
693  resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
694  newDpasOperandExpectedTypes[i], rewriter));
695  }
696  auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
697  distributedResultTy, newDpasOperands,
698  dpasOp->getAttrs());
699  xegpu::removeLayoutAttrs(newDpasOp);
700  Value distributedVal = newWarpOp.getResult(operandIdx);
701  // Resolve the output type.
702  Value typeResolved =
703  resolveDistributedTy(newDpasOp.getResult(),
704  distResultTypeByWarpOpOrFailure.value(), rewriter);
705  rewriter.replaceAllUsesWith(distributedVal, typeResolved);
706  return success();
707  }
708 };
709 
710 /// Distribute a prefetch_nd op at the end of enclosing
711 /// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
712 /// through the warp op interface, they would be propagated as returned values.
713 /// Tensor descriptor shape is not distributed because it is a uniform value
714 /// across all work items within the subgroup. Appropriate cast ops are inserted
715 /// if the distributed types do not match the expected xegpu SIMT types.
716 ///
717 /// Example:
718 ///
719 /// ```
720 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
721 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
722 /// ...
723 /// xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
724 /// }
725 /// ```
726 /// To
727 /// ```
728 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
729 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
730 /// gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
731 /// index
732 /// }
733 /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
734 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
735 /// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
736 ///
737 /// ```
738 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
739  using gpu::WarpDistributionPattern::WarpDistributionPattern;
740  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
741  PatternRewriter &rewriter) const override {
742  gpu::YieldOp yield = warpOp.getTerminator();
743  Operation *lastNode = yield->getPrevNode();
744  auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
745  if (!prefetchOp)
746  return failure();
747 
748  SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
749  // PrefetchNdOp must have offsets.
750  if (offsets.empty())
751  return rewriter.notifyMatchFailure(prefetchOp,
752  "the prefetch op must have offsets");
753  SmallVector<Value> offsetsAsValues =
754  vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
755  SmallVector<Type> offsetTypes = llvm::to_vector(
756  llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
757 
758  xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
759  if (!layout)
760  return rewriter.notifyMatchFailure(
761  prefetchOp, "the source tensor descriptor lacks layout attribute");
762 
763  SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
764  SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
765  newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
766  newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
767  SmallVector<size_t> newRetIndices;
768  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
769  rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
770  // Create a new prefetch op outside the warp op with updated tensor
771  // descriptor type. The source tensor descriptor requires type resolution.
772  xegpu::TensorDescType newTensorDescTy =
773  prefetchOp.getTensorDescType().dropLayouts();
774  rewriter.setInsertionPointAfter(newWarpOp);
775  SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
776  newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
777  // Collect offsets.
778  for (size_t i = 1; i < newRetIndices.size(); ++i)
779  newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
780  xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
781  newPrefetchOperands, prefetchOp->getAttrs());
782  xegpu::removeLayoutAttrs(prefetchOp);
783  rewriter.eraseOp(prefetchOp);
784  return success();
785  }
786 };
787 
788 /// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
789 /// region. This will simply move the barrier op outside of the warp op.
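/// Example (illustrative):
/// ```
/// gpu.warp_execute_on_lane_0(%laneid)[16] {
///   ...
///   gpu.barrier
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%laneid)[16] {
///   ...
/// }
/// gpu.barrier
/// ```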
790 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
791  using gpu::WarpDistributionPattern::WarpDistributionPattern;
792  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
793  PatternRewriter &rewriter) const override {
794  gpu::YieldOp yield = warpOp.getTerminator();
795  Operation *lastNode = yield->getPrevNode();
796  // The last node must be a gpu::BarrierOp.
797  auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
798  if (!barrierOp)
799  return failure();
800  // Move the barrier op outside of the warp op.
801  rewriter.setInsertionPointAfter(warpOp);
802  gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
803  barrierOp->getResultTypes(),
804  barrierOp->getOperands(), barrierOp->getAttrs());
805  rewriter.eraseOp(barrierOp);
806  return success();
807  }
808 };
809 
810 /// Distribute a scattered store op. The offsets argument is required.
811 /// Both offset and mask vectors must be 1D and have #subgroup_size elements.
812 /// The layouts are fixed and implicit: one offset/mask per lane.
813 /// The pass changes the offset/mask vector shapes to a
814 /// single-element vector; **it is assumed that their producers will also be
815 /// distributed**. The payload vector also has a fixed distribution:
816 /// no chunk size -> vector of one element.
817 /// chunk size -> vector of the innermost dimension of the SG-payload.
818 /// Example 1 (no chunk size):
819 /// %mask = producer_op : vector<16xi1>
820 /// %offset = producer_op : vector<16xindex>
821 /// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
822 /// memref<256xf16>, vector<16xindex>, vector<16xi1>
823 /// To
824 /// %mask = producer_op : vector<1xi1>
825 /// %offset = producer_op : vector<1xindex>
826 /// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
827 /// memref<256xf16>, vector<1xindex>, vector<1xi1>
828 /// Example 2 (chunk size, same mask and offsets):
829 /// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
830 /// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
831 /// To
832 /// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
833 /// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
834 struct StoreDistribution final : public gpu::WarpDistributionPattern {
835  using gpu::WarpDistributionPattern::WarpDistributionPattern;
836  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
837  PatternRewriter &rewriter) const override {
838  Operation *lastNode = warpOp.getTerminator()->getPrevNode();
839  auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
840  if (!storeScatterOp)
841  return failure();
842  auto offsets = storeScatterOp.getOffsets();
843  if (!offsets || !isa<VectorType>(offsets.getType()))
844  return rewriter.notifyMatchFailure(
845  storeScatterOp, "Store op must have a vector of offsets argument");
846  VectorType offsetsTy = cast<VectorType>(offsets.getType());
847  VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
848  if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
849  return rewriter.notifyMatchFailure(storeScatterOp,
850  "Expected 1D offsets and mask vector");
851  VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
852  if (storeVecTy.getRank() > 2)
853  return rewriter.notifyMatchFailure(
854  storeScatterOp, "Expected at most 2D result at SG level");
855 
856  std::string layoutPayloadName =
857  xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
858  std::string layoutOffsetsName =
859  xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
860  std::string layoutMaskName =
861  xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
862 
863  xegpu::LayoutAttr layoutPayload =
864  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
865  xegpu::LayoutAttr layoutOffsets =
866  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
867  xegpu::LayoutAttr layoutMask =
868  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
869 
870  FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
871  getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
872  FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
873  getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
874  FailureOr<VectorType> distMaskByWarpOpOrFailure =
875  getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
876  if (failed(distStoreVecByWarpOpOrFailure) ||
877  failed(distOffsetsByWarpOpOrFailure) ||
878  failed(distMaskByWarpOpOrFailure)) {
879  return rewriter.notifyMatchFailure(
880  storeScatterOp,
881  "Some vector operands have no layouts, using defaults instead.");
882  }
883  // Distributed store payload type according to the lane layout.
884  VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
885  // Expected distributed payload type is always 1D.
886  VectorType expectedPayloadTy =
887  VectorType::get({distPayloadTyByWarpOp.getNumElements()},
888  distPayloadTyByWarpOp.getElementType());
889 
890  SmallVector<size_t> newRetIndices;
891  SmallVector<Value> operands = storeScatterOp->getOperands();
892  SmallVector<Type> operandTypesToYield = {
893  distPayloadTyByWarpOp, operands[1].getType(),
894  distOffsetsByWarpOpOrFailure.value(),
895  distMaskByWarpOpOrFailure.value()};
896 
897  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
898  rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
899  SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
900  newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
901  // The payload operand may need a type adjustment due to a mismatch between
902  // the warp-distributed type and the expected SIMT type.
903  rewriter.setInsertionPointAfter(newWarpOp);
904  newStoreScatterOpOperands[0] = resolveDistributedTy(
905  newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
906  xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
907  rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
908  storeScatterOp->getAttrs());
909  xegpu::removeLayoutAttrs(newOp);
910  rewriter.eraseOp(storeScatterOp);
911  return success();
912  }
913 };
914 
915 /// Distribute a scattered load op. The logic and requirements are the same as
916 /// for the scattered store distribution. The warpOp's payload vector is
917 /// expected to be distributed by the load's result consumer.
918 /// Example 1 (no chunk size):
919 /// %mask = producer_op : vector<16xi1>
920 /// %offset = producer_op : vector<16xindex>
921 /// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
922 /// vector<16xindex>, vector<16xi1> -> vector<16xf16>
923 /// To
924 /// %mask = producer_op : vector<1xi1>
925 /// %offset = producer_op : vector<1xindex>
926 /// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
927 /// vector<1xindex>, vector<1xi1> -> vector<1xf16>
928 /// Example 2 (chunk size, same mask and offsets):
929 /// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
930 /// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
931 /// To
932 /// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
933 /// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
934 struct LoadDistribution final : public gpu::WarpDistributionPattern {
935  using gpu::WarpDistributionPattern::WarpDistributionPattern;
936  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
937  PatternRewriter &rewriter) const override {
938  OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
939  // Check that the yield operand was produced by the *last* scattered load
940  // op, to avoid sinking it before barriers (maintain memory order).
941  return isa<xegpu::LoadGatherOp>(op) &&
942  warpOp.getTerminator()->getPrevNode() == op;
943  });
944  if (!producedByLastLoad)
945  return rewriter.notifyMatchFailure(
946  warpOp, "The last op is not xegpu::LoadGatherOp");
947 
948  auto loadGatherOp =
949  producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
950  auto offsets = loadGatherOp.getOffsets();
951  if (!offsets || !isa<VectorType>(offsets.getType()) ||
952  !isa<VectorType>(loadGatherOp.getMask().getType()))
953  return rewriter.notifyMatchFailure(
954  loadGatherOp,
955  "Load op must have vector arguments for offsets and mask");
956  VectorType offsetsTy = cast<VectorType>(offsets.getType());
957  VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
958  if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
959  return rewriter.notifyMatchFailure(loadGatherOp,
960  "Expected 1D offsets and mask vector");
961  // Assume offset and mask producers will be distributed as well.
962  std::string layoutOffsetsName =
963  xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
964  std::string layoutMaskName =
965  xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
966 
967  xegpu::LayoutAttr layoutOffsets =
968  loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
969  xegpu::LayoutAttr layoutMask =
970  loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
971 
972  FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
973  getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
974  FailureOr<VectorType> distMaskByWarpOpOrFailure =
975  getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
976  if (failed(distOffsetsByWarpOpOrFailure) ||
977  failed(distMaskByWarpOpOrFailure)) {
978  return rewriter.notifyMatchFailure(
979  loadGatherOp,
980  "Some vector operands have no layouts, using defaults instead.");
981  }
982 
983  SmallVector<size_t> newRetIndices;
984  SmallVector<Value> operands = loadGatherOp->getOperands();
985  SmallVector<Type> operandTypesToYield = {
986  operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
987  distMaskByWarpOpOrFailure.value()};
988 
989  const unsigned operandIdx = producedByLastLoad->getOperandNumber();
990  VectorType distResultTy =
991  cast<VectorType>(warpOp.getResult(operandIdx).getType());
992  // Distributed load op will always be 1D.
993  VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
994  distResultTy.getElementType());
995 
996  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
997  rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
998 
999  SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
1000  newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
1001 
1002  rewriter.setInsertionPointAfter(newWarpOp);
1003  xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
1004  rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
1005  loadGatherOp->getAttrs());
1006  xegpu::removeLayoutAttrs(newOp);
1007  Value distributedVal = newWarpOp.getResult(operandIdx);
1008  // Resolve the output type and replace all uses.
1009  rewriter.replaceAllUsesWith(
1010  distributedVal,
1011  resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
1012  return success();
1013  }
1014 };
1015 
1016 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1017 /// VectorReductionOps. We also insert layouts for the newly created ops.
1018 static Value lowerToVectorReductions(TypedValue<VectorType> src,
1019  TypedValue<VectorType> acc,
1020  vector::CombiningKind kind,
1021  int64_t reductionDim, Location loc,
1022  PatternRewriter &rewriter) {
1023  // Expecting a 2D source vector.
1024  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
1025  VectorType sourceType = src.getType();
1026  int64_t sourceH = sourceType.getShape()[0];
1027  int64_t sourceW = sourceType.getShape()[1];
1028  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
1029  // Create a constant vector to hold the result of the reduction.
1030  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
1031  Value reductionResult = arith::ConstantOp::create(
1032  rewriter, loc, acc.getType(),
1033  DenseElementsAttr::get(acc.getType(), zeroAttr));
1034  // Reduction result should have the same layout as the accumulator.
1035  xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
1036  xegpu::getDistributeLayoutAttr(acc));
1037  // For each slice of the source, extract the slice vector, do a reduction,
1038  // and insert the reduced value back into the result vector.
1039  for (int i = 0; i < nSlices; ++i) {
1040  SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
1041  if (reductionDim == 1) {
1042  sliceOffsets = {i, 0};
1043  sliceSizes = {1, sourceW};
1044  } else {
1045  sliceOffsets = {0, i};
1046  sliceSizes = {sourceH, 1};
1047  }
1048  vector::ExtractStridedSliceOp extractOp =
1049  vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
1050  sliceSizes, {1, 1});
1051  int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1052  vector::ShapeCastOp slice = vector::ShapeCastOp::create(
1053  rewriter, loc,
1054  VectorType::get({nSliceElements}, sourceType.getElementType()),
1055  extractOp.getResult());
1056  // Shape cast is currently handled on the xegpu side, so layouts must be
1057  // retained during lowering. Shape cast output has the same layout as the
1058  // accumulator. Shape cast source has the same layout as the original
1059  // reduction source.
1060  // TODO: other ops generated here may also need layout attributes.
1061  xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
1062  xegpu::getDistributeLayoutAttr(src));
1063  xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
1064  xegpu::getDistributeLayoutAttr(acc));
1065  // Extract and reduction result in scalars, so no result layout is needed.
1066  Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1067  Value reduction = vector::ReductionOp::create(
1068  rewriter, loc, kind, slice.getResult(), accExtract);
1069  reductionResult =
1070  vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
1071  }
1072  return reductionResult;
1073 }
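// For example, reducing a vector<16x2xf32> source along dim 0 with a
// vector<2xf32> accumulator produces 2 slices: each 16x1 slice is shape_cast
// to vector<16xf32>, reduced into a scalar with vector.reduction (seeded by
// the corresponding accumulator element), and inserted into the 2-element
// result vector.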
1074 
1075 /// This pattern distributes the `vector.multi_reduction` operation across
1076 /// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1077 /// layouts for the source and accumulator vectors,
1078 /// * If the reduction dimension is distributed across lanes, the reduction is
1079 /// non-lane-local and the reduction is done using warp shuffles. Here we
1080 /// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1081 /// the warp op body.
1082 /// * If the reduction dimension is not distributed across lanes, the reduction
1083 /// is lane-local. In this case, we yield the source and accumulator vectors
1084 /// from the warp op and perform the lane-local reduction outside the warp op
1085 /// using a sequence of ReductionOps.
1086 /// Example 1 (Reduction is lane-local):
1087 /// ```
1088 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1089 /// %0 = "some_def"() : () -> (vector<16x32xf32>)
1090 /// %acc = "some_def"() : () -> (vector<32xf32>)
1091 /// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1092 /// vector<32xf32>
/// gpu.yield %1 : vector<32xf32>
1093 /// }
1094 /// ```
1095 /// is lowered to:
1096 /// ```
1097 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1098 /// vector<1xf32>) {
1099 /// %0 = "some_def"() : () -> (vector<16x32xf32>)
1100 /// %acc = "some_def"() : () -> (vector<32xf32>)
1101 /// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1102 /// }
1103 /// %c = arith.constant dense<0.0> : vector<1xf32>
1104 /// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1105 /// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
1106 /// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1107 /// ```
1108 /// Example 2 (Reduction is non-lane-local):
1109 /// ```
1110 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1111 /// %0 = "some_def"() : () -> (vector<2x32xf32>)
1112 /// %acc = "some_def"() : () -> (vector<2xf32>)
1113 /// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1114 /// vector<2xf32>
1115 /// gpu.yield %1 : vector<2xf32>
1116 /// }
1117 /// ```
1118 /// is lowered to:
1119 /// ```
1120 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1121 /// %0 = "some_def"() : () -> (vector<2x32xf32>)
1122 /// %acc = "some_def"() : () -> (vector<2xf32>)
1123 /// %1 = arith.constant dense<0.0> : vector<2xf32>
1124 /// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1125 /// %3 = ("warp.reduction %2") : f32
1126 /// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1127 /// ... repeat for row 1
1128 /// gpu.yield %1 : vector<2xf32>
1129 /// }
1130 struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1131  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1132  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1133  PatternRewriter &rewriter) const override {
1134  OpOperand *yieldOperand =
1135  getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1136  if (!yieldOperand)
1137  return failure();
1138  auto reductionOp =
1139  cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1140  unsigned operandIdx = yieldOperand->getOperandNumber();
1141  VectorType sourceType = reductionOp.getSourceVectorType();
1142  // Only 2D vectors are supported.
1143  if (sourceType.getRank() != 2)
1144  return rewriter.notifyMatchFailure(warpOp,
1145  "Only 2D reductions are supported.");
1146  ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1147  // Only 1 reduction dimension is supported. This also ensures that the
1148  // result is a vector type.
1149  if (reductionDims.size() != 1)
1150  return rewriter.notifyMatchFailure(
1151  warpOp, "Only 1 reduction dimension is supported.");
1152  int64_t reductionDim = reductionDims[0];
1153  VectorType distributedResultType =
1154  cast<VectorType>(warpOp.getResult(operandIdx).getType());
1155  VectorType resultType = cast<VectorType>(reductionOp.getType());
1156  xegpu::DistributeLayoutAttr sourceLayout =
1157  xegpu::getDistributeLayoutAttr(reductionOp.getSource());
1158 
1159  FailureOr<VectorType> sourceDistTypeOrFailure =
1160  getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1161  if (failed(sourceDistTypeOrFailure))
1162  return rewriter.notifyMatchFailure(
1163  warpOp, "Failed to distribute the source vector type.");
1164  VectorType sourceDistType = sourceDistTypeOrFailure.value();
1165  // Only single dimension distribution is supported.
1166  bool dim0Distributed =
1167  sourceDistType.getShape()[0] != sourceType.getShape()[0];
1168  bool dim1Distributed =
1169  sourceDistType.getShape()[1] != sourceType.getShape()[1];
1170  if (dim0Distributed && dim1Distributed)
1171  return rewriter.notifyMatchFailure(
1172  warpOp, "Expecting source to be distributed in a single dimension.");
1173  int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1174  if (sourceDistDim == -1)
1175  return rewriter.notifyMatchFailure(
1176  warpOp, "Expecting a distributed source vector.");
1177  bool resultDistributed =
1178  distributedResultType.getNumElements() < resultType.getNumElements();
1179  // If the lane owns all the data required for the reduction (i.e. the
1180  // reduction is fully parallel across lanes), then each lane owns part of
1181  // the result (i.e. the result is distributed). If the reduction requires
1182  // cross-lane shuffling, then the result is shared among all lanes
1183  // (broadcasted). Therefore we expect the following cases:
1184  //
1185  // | Source vector | Reduction dim | Result vector |
1186  // |----------------------|----------------|----------------|
1187  // | dim-0 distributed | 0 | broadcasted |
1188  // | dim-0 distributed | 1 | distributed |
1189  // | dim-1 distributed | 0 | distributed |
1190  // | dim-1 distributed | 1 | broadcasted |
1191 
1192  bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1193  (sourceDistDim == 1 && reductionDim == 0);
1194  if (isReductionLaneLocal && !resultDistributed)
1195  return rewriter.notifyMatchFailure(
1196  warpOp, "Expecting a distributed result for lane-local reduction.");
1197 
1198  if (!isReductionLaneLocal && resultDistributed)
1199  return rewriter.notifyMatchFailure(
1200  warpOp,
1201  "Expecting a broadcasted result for non-lane-local reduction.");
1202 
1203  // Handle lane-local reduction case. In this case we fully distribute the
1204  // reduction result.
1205  if (isReductionLaneLocal) {
1206  // Yield the source and acc vectors from the WarpOp.
1207  SmallVector<size_t> newRetIndices;
1208  auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1209  rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1210  {sourceDistType, distributedResultType}, newRetIndices);
1211  rewriter.setInsertionPointAfter(newWarpOp);
1212  Value result = lowerToVectorReductions(
1213  cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1214  cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1215  reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1216  // Replace the warp op result with the final result.
1217  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
1218  return success();
1219  }
1220  // For the non-lane-local case, we simply rewrite the MultiReductionOp in terms
1221  // of multiple ReductionOps. Actual distribution is done by the
1222  // WarpOpReduction pattern.
1223  rewriter.setInsertionPointAfter(reductionOp);
1224  Value result = lowerToVectorReductions(
1225  cast<TypedValue<VectorType>>(reductionOp.getSource()),
1226  cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1227  reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1228  // Replace the warp op result with the final result.
1229  rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1230  return success();
1231  }
1232 };
1233 
1234 /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1235 /// `gpu.warp_execute_on_lane_0` region.
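/// As an illustration (shapes and layouts assumed, shown schematically): with
/// result lane_layout = [1, 16] and the source layout being its dim-0 slice
/// (lane_layout = [16]), a yielded
///   vector.shape_cast %v : vector<16xf32> to vector<1x16xf32>
/// is recreated after the warp op as
///   vector.shape_cast %v_dist : vector<1xf32> to vector<1x1xf32>
/// where %v_dist is the distributed source yielded by the new warp op.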
1236 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1237  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1238  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1239  PatternRewriter &rewriter) const override {
1240  OpOperand *yieldOperand =
1241  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1242  if (!yieldOperand)
1243  return failure();
1244  auto shapeCastOp =
1245  cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1246  unsigned operandNumber = yieldOperand->getOperandNumber();
1247  auto resultDistTy =
1248  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1249  xegpu::DistributeLayoutAttr sourceLayout =
1250  xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
1251  xegpu::DistributeLayoutAttr resultLayout =
1252  xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
1253  if (!sourceLayout || !resultLayout)
1254  return rewriter.notifyMatchFailure(
1255  warpOp,
1256  "the source or result of shape_cast op lacks distribution layout");
1257 
1258  // For rank-reducing or rank-increasing shape_cast ops, the lower-rank layout
1259  // must be a slice of the higher-rank layout.
1260  int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
1261  int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
1262  if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
1263  return rewriter.notifyMatchFailure(
1264  warpOp, "shape_cast is rank increasing but source layout is not a "
1265  "slice of result layout");
1266  if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
1267  return rewriter.notifyMatchFailure(
1268  warpOp, "shape_cast is rank reducing but result layout is not a "
1269  "slice of source layout");
1270 
1271  FailureOr<VectorType> sourceDistTypeOrFailure =
1272  getDistVecTypeBasedOnLaneLayout(sourceLayout,
1273  shapeCastOp.getSourceVectorType());
1274  if (failed(sourceDistTypeOrFailure))
1275  return rewriter.notifyMatchFailure(
1276  warpOp, "failed to get distributed vector type for source");
1277  VectorType sourceDistType = sourceDistTypeOrFailure.value();
1278  // Create a new warp op that yields the source of the shape_cast op.
1279  SmallVector<size_t> newRetIndices;
1280  auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1281  rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1282  newRetIndices);
1283  rewriter.setInsertionPointAfter(newWarpOp);
1284  Value source = newWarpOp.getResult(newRetIndices[0]);
1285  // Create a new shape_cast op outside the warp op.
1286  Value newShapeCast = vector::ShapeCastOp::create(
1287  rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1288  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1289  newShapeCast);
1290  return success();
1291  }
1292 };
1293 
1294 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1295 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1296 /// outside of the warp op.
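/// As a sketch (memref type assumed): a yielded
///   %ptr = memref.extract_aligned_pointer_as_index %src : memref<256xf32> -> index
/// is simply recreated after the warp op, using the memref yielded from the
/// new warp op; the resulting index is uniform across lanes, so no
/// distribution is needed.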
1297 struct MemrefExtractAlignedPointerAsIndexDistribution final
1298  : public gpu::WarpDistributionPattern {
1299  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1300  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1301  PatternRewriter &rewriter) const override {
1302  OpOperand *operand = getWarpResult(
1303  warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1304  if (!operand)
1305  return rewriter.notifyMatchFailure(
1306  warpOp,
1307  "warp result is not a memref::ExtractAlignedPointerAsIndex op");
1308  auto extractOp =
1309  operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1310  unsigned operandIdx = operand->getOperandNumber();
1311  SmallVector<size_t> newRetIndices;
1312  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1313  rewriter, warpOp, extractOp.getSource(),
1314  TypeRange{extractOp.getSource().getType()}, newRetIndices);
1315  rewriter.setInsertionPointAfter(newWarpOp);
1316  auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1317  rewriter, newWarpOp.getLoc(), extractOp.getType(),
1318  newWarpOp.getResult(newRetIndices[0]));
1319  Value distributedVal = newWarpOp.getResult(operandIdx);
1320  rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
1321  return success();
1322  }
1323 };
1324 
1325 /// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1326 /// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1327 /// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1328 /// created outside of the warp op with distributed source vector type (computed
1329 /// using assigned layout).
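/// As an illustration (shapes and layout assumed): with lane_layout = [1, 16],
///   vector.bitcast %v : vector<32x32xi8> to vector<32x16xi16>
/// becomes, outside the warp op,
///   vector.bitcast %v_dist : vector<32x2xi8> to vector<32x1xi16>
/// i.e. each lane bitcasts its own two i8 elements into one i16 element.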
1330 struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1331  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1332  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1333  PatternRewriter &rewriter) const override {
1334  OpOperand *operand =
1335  getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1336  if (!operand)
1337  return rewriter.notifyMatchFailure(
1338  warpOp, "warp result is not a vector::BitCast op");
1339  auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1340  unsigned operandIdx = operand->getOperandNumber();
1341  VectorType distributedSourceType =
1342  getDistVecTypeBasedOnLaneLayout(
1343  xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
1344  bitcastOp.getSourceVectorType())
1345  .value_or(VectorType());
1346  if (!distributedSourceType)
1347  return rewriter.notifyMatchFailure(
1348  bitcastOp, "Failed to distribute the source vector type in "
1349  "vector::BitCast op");
1350  VectorType distributedResultType =
1351  cast<VectorType>(warpOp.getResult(operandIdx).getType());
1352  SmallVector<size_t> newRetIndices;
1353  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1354  rewriter, warpOp, bitcastOp.getSource(),
1355  TypeRange{distributedSourceType}, newRetIndices);
1356  rewriter.setInsertionPointAfter(newWarpOp);
1357  auto newBitcastOp = vector::BitCastOp::create(
1358  rewriter, newWarpOp.getLoc(), distributedResultType,
1359  newWarpOp.getResult(newRetIndices[0]));
1360  Value distributedVal = newWarpOp.getResult(operandIdx);
1361  rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1362  return success();
1363  }
1364 };
1365 
1366 /// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1367 /// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1368 /// supported. In most cases, the transpose is a no-op because it is entirely
1369 /// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1370 /// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1371 /// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1372 /// vector::TransposeOp outside of the warp op with distributed source vector
1373 /// type (computed using assigned layout).
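/// As an illustration (shapes and layouts assumed): for
///   vector.transpose %v, [1, 0] : vector<16x32xf32> to vector<32x16xf32>
/// with source lane_layout = [1, 16] (a 16x2 slice per lane) and result
/// lane_layout = [16, 1] (a 2x16 slice per lane), each lane must transpose its
/// own 16x2 slice, so a vector.transpose on vector<16x2xf32> is created
/// outside the warp op.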
1374 struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1375  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1376  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1377  PatternRewriter &rewriter) const override {
1378  OpOperand *operand =
1379  getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1380  if (!operand)
1381  return rewriter.notifyMatchFailure(
1382  warpOp, "warp result is not a vector::Transpose op");
1383  auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1384  unsigned operandIdx = operand->getOperandNumber();
1385  xegpu::DistributeLayoutAttr sourceLayout =
1386  xegpu::getDistributeLayoutAttr(transposeOp.getVector());
1387  xegpu::DistributeLayoutAttr resultLayout =
1388  xegpu::getDistributeLayoutAttr(transposeOp.getResult());
1389  if (!sourceLayout || !resultLayout)
1390  return rewriter.notifyMatchFailure(
1391  transposeOp,
1392  "the source or result vector of the transpose op lacks layout "
1393  "attribute");
1394  int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1395  int64_t resultRank = transposeOp.getResultVectorType().getRank();
1396  // Only 2D transposes are supported for now.
1397  // TODO: Support nD transposes.
1398  if (sourceRank != 2 || resultRank != 2)
1399  return rewriter.notifyMatchFailure(
1400  transposeOp, "the source or result vector of the transpose op "
1401  "is not 2D");
1402  ArrayRef<int64_t> perm = transposeOp.getPermutation();
1403  // Result layout must be a transpose of source layout.
1404  if (!resultLayout.isTransposeOf(sourceLayout, perm))
1405  return rewriter.notifyMatchFailure(
1406  transposeOp,
1407  "the source or result vector layouts must be 2D transposes of each "
1408  "other");
1409  FailureOr<VectorType> distributedSourceTypeOrFailure =
1410  getDistVecTypeBasedOnLaneLayout(sourceLayout,
1411  transposeOp.getSourceVectorType());
1412  if (failed(distributedSourceTypeOrFailure))
1413  return rewriter.notifyMatchFailure(
1414  transposeOp, "Failed to distribute the source vector type in "
1415  "vector::Transpose op");
1416  SmallVector<size_t> newRetIndices;
1417  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1418  rewriter, warpOp, transposeOp.getVector(),
1419  TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1420  rewriter.setInsertionPointAfter(newWarpOp);
1421  auto newTransposeOp = vector::TransposeOp::create(
1422  rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1423  perm);
1424  Value distributedVal = newWarpOp.getResult(operandIdx);
1425  rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
1426  return success();
1427  }
1428 };
1429 
1430 } // namespace
1431 
1432 namespace {
1433 struct XeGPUSubgroupDistributePass final
1434  : public xegpu::impl::XeGPUSubgroupDistributeBase<
1435  XeGPUSubgroupDistributePass> {
1436  void runOnOperation() override;
1437 };
1438 } // namespace
1439 
1440 void xegpu::populateXeGPUSubgroupDistributePatterns(
1441  RewritePatternSet &patterns) {
1442  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
1443  LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
1444  GpuBarrierDistribution, VectorMultiReductionDistribution,
1445  LoadDistribution, StoreDistribution, VectorTransposeDistribution,
1446  VectorBitcastDistribution,
1447  MemrefExtractAlignedPointerAsIndexDistribution>(
1448  patterns.getContext(),
1449  /*pattern benefit=*/regularPatternBenefit);
1450  patterns.add<VectorShapeCastDistribution>(
1451  patterns.getContext(),
1452  /*pattern benefit=*/highPatternBenefit);
1453 }
1454 
1455 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
1456  RewritePatternSet &patterns) {
1457  patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
1458 }
1459 
1460 void XeGPUSubgroupDistributePass::runOnOperation() {
1461  // Step 1: Attach layouts to op operands.
1462  // TODO: The following assumptions are made:
1463  // 1) It is assumed that there are no layout conflicts.
1464  // 2) Any existing layout attributes attached to the operands are ignored.
1465  Operation *op = getOperation();
1466  op->walk([&](Operation *op) {
1467  for (OpOperand &operand : op->getOpOperands()) {
1468  // Layouts are needed for vector type only.
1469  if (!isa<VectorType>(operand.get().getType()))
1470  continue;
1471 
1472  auto layout = xegpu::getDistributeLayoutAttr(operand.get());
1473  if (!layout) {
1474  op->emitError("Could not find layout attribute for operand ")
1475  << operand.getOperandNumber() << " of operation " << op->getName();
1476  signalPassFailure();
1477  return;
1478  }
1479  xegpu::setDistributeLayoutAttr(operand, layout);
1480  }
1481  });
1482  // Step 2: Move all operations of a GPU function inside
1483  // gpu.warp_execute_on_lane_0 operation.
1484  {
1485  RewritePatternSet patterns(&getContext());
1486  xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
1487 
1488  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1489  signalPassFailure();
1490  return;
1491  }
1492  // At this point, we have moved the entire function body inside the
1493  // warpOp. Now move any scalar uniform code outside of the warpOp (like
1494  // GPU index ops, scalar constants, etc.). This will simplify the
1495  // later lowering and avoid custom patterns for these ops.
1496  getOperation()->walk([&](Operation *op) {
1497  if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
1498  vector::moveScalarUniformCode(warpOp);
1499  });
1500  }
1501  // Step 3: Apply subgroup to workitem distribution patterns.
1502  RewritePatternSet patterns(&getContext());
1503  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
1504  // distributionFn is used by vector distribution patterns to determine the
1505  // distributed vector type for a given vector value. In XeGPU subgroup
1506  // distribution context, we compute this based on lane layout.
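  // For example (layout assumed): a vector<16x32xf32> value with
  // lane_layout = [1, 16] maps to (d0, d1) -> (d1), i.e. only dim 1 is
  // distributed, because dim 0 is assigned a single lane while dim 1 divides
  // evenly across 16 lanes.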
1507  auto distributionFn = [](Value val) {
1508  VectorType vecType = dyn_cast<VectorType>(val.getType());
1509  int64_t vecRank = vecType ? vecType.getRank() : 0;
1510  if (vecRank == 0)
1511  return AffineMap::get(val.getContext());
1512  // Get the layout of the vector type.
1513  xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
1514  // If no layout is specified, that means no distribution.
1515  if (!layout)
1516  return AffineMap::getMultiDimMapWithTargets(vecRank, {},
1517  val.getContext());
1518  // Expecting vector and layout rank to match.
1519  assert(layout.getRank() == vecRank &&
1520  "Expecting vector and layout rank to match");
1521  // A dimension is distributed only if the layout assigns more than one
1522  // lane to that dimension and the shape divides evenly across those
1523  // lanes.
1524  SmallVector<unsigned int> distributedDims;
1525  for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
1526  if (v > 1 && vecType.getShape()[i] % v == 0)
1527  distributedDims.push_back(i);
1528  }
1529  return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
1530  val.getContext());
1531  };
1532  // TODO: shuffleFn is not used.
1533  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
1534  int64_t warpSz) { return Value(); };
1535 
1536  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
1537  vector::CombiningKind kind, uint32_t size) {
1538  // First reduce the lane-local vector to get the per-lane reduction value.
1539  Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
1540  // Parallel reduction using butterfly shuffles.
1541  for (uint64_t i = 1; i < size; i <<= 1) {
1542  Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
1543  /*width=*/size,
1544  /*mode=*/gpu::ShuffleMode::XOR)
1545  .getShuffleResult();
1546  laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
1547  }
1548  return laneVal;
1549  };
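  // For instance, with size = 16 the loop above runs 4 rounds with XOR
  // offsets 1, 2, 4, 8; after the final round every lane holds the reduction
  // of all 16 per-lane values.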
1550 
1551  vector::populateDistributeReduction(
1552  patterns, warpReduction,
1553  /*pattern benefit=*/regularPatternBenefit);
1554 
1555  vector::populatePropagateWarpVectorDistributionPatterns(
1556  patterns, distributionFn, shuffleFn,
1557  /*pattern benefit=*/regularPatternBenefit);
1558  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1559  signalPassFailure();
1560  return;
1561  }
1562 
1563  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
1564  // due to tensor desc type mismatches created by using upstream distribution
1565  // patterns (scf.for). This cleanup should only be done if all the ops are
1566  // distributed successfully. If some ops are still not distributed and remain
1567  // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
1568  // breaking the IR.
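  // A schematic sketch (types assumed for illustration) of the two cast
  // directions resolved below, both tagged with the resolveSIMTTypeMismatch
  // attribute:
  //   %0 = builtin.unrealized_conversion_cast %arg
  //          : !xegpu.tensor_desc<16x16xf16, #layout> to !xegpu.tensor_desc<16x16xf16>
  //   %1 = builtin.unrealized_conversion_cast %0
  //          : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16, #layout>
  // where #layout stands for a distribution layout attribute.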
1569  bool foundWarpOp = false;
1570  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
1571  // Look for WarpOps that are not trivially dead.
1572  if (isOpTriviallyDead(warpOp))
1573  return WalkResult::advance();
1574  foundWarpOp = true;
1575  return WalkResult::interrupt();
1576  });
1577  if (foundWarpOp)
1578  return;
1579 
1580  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
1581  // We are only interested in UnrealizedConversionCastOps that were added
1582  // for resolving SIMT type mismatches.
1583  if (!op->getAttr(resolveSIMTTypeMismatch))
1584  return WalkResult::skip();
1585 
1586  Value input = op.getOperand(0);
1587  Value output = op.getResult(0);
1588 
1589  // Both input and output must have tensor descriptor types.
1590  xegpu::TensorDescType inputDescType =
1591  mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
1592  xegpu::TensorDescType outputDescType =
1593  mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
1594  assert(inputDescType && outputDescType &&
1595  "Unrealized conversion cast must have tensor descriptor types");
1596 
1597  // Conversions of the form tensor_desc<shape, layout> -> tensor_desc<shape>.
1598  // These occur inside the scf.for body to resolve the block argument type to
1599  // the SIMT type.
1600  if (inputDescType.getLayout()) {
1601  auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
1602  if (argument) {
1603  argument.setType(output.getType());
1604  output.replaceAllUsesWith(argument);
1605  if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
1606  argument.getOwner()->getParentOp())) {
1607  auto result = loopOp.getTiedLoopResult(argument);
1608  result.setType(output.getType());
1609  }
1610  }
1611  }
1612 
1613  // Conversions of the form tensor_desc<shape> -> tensor_desc<shape, layout>.
1614  // These occur at the yield op of the scf.for body to go back from the
1615  // SIMT type to the original type.
1616  if (outputDescType.getLayout())
1617  output.replaceAllUsesWith(input);
1618 
1619  if (op->use_empty())
1620  op->erase();
1621  return WalkResult::advance();
1622  });
1623 }