1 //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeRange.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/IR/Visitors.h"
30 #include "mlir/Support/LLVM.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 
38 namespace mlir {
39 namespace xegpu {
40 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
41 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
42 } // namespace xegpu
43 } // namespace mlir
44 
45 #define DEBUG_TYPE "xegpu-subgroup-distribute"
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
47 
48 using namespace mlir;
49 
50 static const char *const resolveSIMTTypeMismatch =
51  "resolve_simt_type_mismatch"; // Attribute name for identifying
52  // UnrealizedConversionCastOp added to resolve
53  // SIMT type mismatches.
54 
55 namespace {
56 
57 //===----------------------------------------------------------------------===//
58 // SIMT Distribution Patterns
59 //===----------------------------------------------------------------------===//
60 
61 /// In certain cases, we may need to favor XeGPU specific distribution patterns
62 /// over generic vector distribution patterns. In such cases, we can assign
63 /// priorities to patterns.
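/// For example (illustrative only; not necessarily how this pass registers
/// its patterns), a pattern added as
/// patterns.add<SomeXeGPUPattern>(ctx, /*benefit=*/highPatternBenefit)
/// is tried before one registered with regularPatternBenefit.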
64 static constexpr unsigned regularPatternBenefit = 1;
65 static constexpr unsigned highPatternBenefit = 2;
66 
67 /// Helper function to get distributed vector type for a source vector type
68 /// according to the lane_layout. We simply divide each dimension of the
69 /// source vector shape by the corresponding lane_layout dimension. If
70 /// array_length > 1, that is appended to the front of the distributed shape.
71 /// NOTE: This is the vector type that will be returned by the
72 /// gpu.warp_execute_on_lane0 op.
73 ///
74 /// Examples:
75 /// | original vector shape | lane_layout | distributed vector shape |
76 /// |-----------------------|-------------|--------------------------|
77 /// | 32x16                 | [1, 16]     | 32x1                     |
78 /// | 32x16                 | [2, 8]      | 16x2                     |
79 /// | 2x32x16               | [1, 16]     | 2x32x1                   |
80 static FailureOr<VectorType>
81 getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
82  VectorType originalType) {
83  if (!layout)
84  return failure();
85  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
86  "Expecting a valid layout.");
87  SmallVector<int64_t> effectiveLaneLayout =
88  layout.getEffectiveLaneLayoutAsInt();
89  assert(static_cast<size_t>(originalType.getRank()) >=
90  effectiveLaneLayout.size() &&
91  "Rank of the original vector type should be greater or equal to the "
92  "size of the lane layout to distribute the vector type.");
93  SmallVector<int64_t> distributedShape(originalType.getShape());
94  // Only distribute the last `laneLayout.size()` dimensions. The remaining
95  // dimensions are not distributed.
96  unsigned distributionStart =
97  originalType.getRank() - effectiveLaneLayout.size();
98  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
99  if (i < distributionStart)
100  continue;
101 
102  // Check if the dimension can be distributed evenly.
103  if (dim % effectiveLaneLayout[i - distributionStart] != 0)
104  return failure();
105  distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
106  }
107  return VectorType::get(distributedShape, originalType.getElementType());
108 }
109 
110 /// Helper function to resolve types if the distributed type out of
111 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
112 /// Example 1:
113 /// distributed type: vector<8x1xf32>
114 /// expected type: vector<8xf32>
115 /// resolved using,
116 /// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
117 /// Example 2:
118 /// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
119 /// expected type: xegpu.tensor_desc<8x16xf32>
120 /// resolved using,
121 /// %0 = unrealized_conversion_cast %1 :
122 /// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
123 /// xegpu.tensor_desc<8x16xf32>
124 template <typename T>
125 static Value resolveDistributedTy(Value orig, T expected,
126  PatternRewriter &rewriter) {
127  // If orig and expected types are the same, return orig.
128  if (orig.getType() == expected)
129  return orig;
130  // If orig is a vector type, create a shape cast op to reconcile the types.
131  if (isa<VectorType>(orig.getType())) {
132  auto castOp =
133  vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig);
134  return castOp.getResult();
135  }
136  // If orig is a tensor descriptor type, create an unrealized conversion cast
137  // op to reconcile the types.
138  if (isa<xegpu::TensorDescType>(orig.getType())) {
139  auto castOp = UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
140  expected, orig);
141  castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
142  return castOp.getResult(0);
143  }
144  llvm_unreachable("Unsupported type for reconciliation");
145  return orig;
146 }
147 
148 /// Helper function to check if the layout is packed. Layout is packed if it is
149 /// 2D and lane_data[0] != 1 (data packed from col dimension).
150 /// TODO: Move to target info.
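/// For example (illustrative), a layout with lane_data = [2, 1] is considered
/// packed, while a layout with lane_data = [1, 1] is not.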
151 static bool requirePacked(const xegpu::LayoutAttr layout) {
152  if (!layout)
153  return false;
154  auto laneData = layout.getEffectiveLaneDataAsInt();
155  if (laneData.size() != 2)
156  return false;
157  return laneData[0] != 1;
158 }
159 
160 /// Helper function to check if the layout requires a transpose effect.
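/// For example (illustrative, assuming a subgroup size of 16 on a supported
/// target), a layout with lane_layout = [16, 1] requires a transpose, while
/// lane_layout = [1, 16] does not.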
161 static bool requireTranspose(const xegpu::LayoutAttr layout,
162  const std::string &chipStr) {
163  // Return false for unsupported targets.
164  // TODO: Add more support or move to target info.
165  if (chipStr != "pvc" && chipStr != "bmg")
166  return false;
167  if (!layout)
168  return false;
169  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
170  if (laneLayout.size() != 2)
171  return false;
172  return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1;
173 }
174 
175 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
176 /// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
177 /// contained within a WarpExecuteOnLane0Op.
178 /// Example:
179 ///
180 /// ```
181 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
182 /// ...
183 /// ...
184 /// gpu.return %result: vector<8x16xf32>
185 /// }
186 /// ```
187 /// To
188 /// ```
189 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
190 /// %laneid = gpu.lane_id : index
191 /// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
192 /// ...
193 /// ...
194 /// gpu.yield %result: vector<8x16xf32>
195 /// }
196 /// return %0
197 /// }
/// ```
198 struct MoveFuncBodyToWarpExecuteOnLane0
199  : public OpRewritePattern<gpu::GPUFuncOp> {
200  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
201  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
202  PatternRewriter &rewriter) const override {
203  // If the function only contains a single void return, skip.
204  if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
205  return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
206  }))
207  return failure();
208  // If the function body was already moved inside a warp_execute_on_lane0, skip.
209  if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
210  return isa<gpu::WarpExecuteOnLane0Op>(op);
211  }))
212  return failure();
213  // Create a new function with the same signature and same attributes.
214  SmallVector<Type> workgroupAttributionsTypes =
215  llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
216  [](BlockArgument arg) { return arg.getType(); });
217  SmallVector<Type> privateAttributionsTypes =
218  llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
219  [](BlockArgument arg) { return arg.getType(); });
220  auto newGpuFunc = gpu::GPUFuncOp::create(
221  rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(),
222  gpuFuncOp.getFunctionType(), workgroupAttributionsTypes,
223  privateAttributionsTypes);
224  newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
225  // Create a WarpExecuteOnLane0Op with same arguments and results as the
226  // original gpuFuncOp.
227  rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
228  auto laneId = gpu::LaneIdOp::create(
229  rewriter, newGpuFunc.getLoc(), rewriter.getIndexType(),
230  /** upperBound = **/ mlir::IntegerAttr());
231  ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
232  auto warpOp = gpu::WarpExecuteOnLane0Op::create(
233  rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
234  xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
235  newGpuFunc.getArgumentTypes());
236  Block &warpBodyBlock = warpOp.getBodyRegion().front();
237  // Replace the ReturnOp of the original gpu function with a YieldOp.
238  auto origReturnOp =
239  cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
240  rewriter.setInsertionPointAfter(origReturnOp);
241  gpu::YieldOp::create(rewriter, origReturnOp.getLoc(),
242  origReturnOp.getOperands());
243  rewriter.eraseOp(origReturnOp);
244  // Move the original function body to the WarpExecuteOnLane0Op body.
245  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
246  warpOp.getBodyRegion().begin());
247  rewriter.eraseBlock(&warpBodyBlock);
248  // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
249  rewriter.setInsertionPointAfter(warpOp);
250  gpu::ReturnOp::create(rewriter, newGpuFunc.getLoc(), warpOp.getResults());
251  rewriter.replaceOp(gpuFuncOp, newGpuFunc);
252  return success();
253  }
254 };
255 
256 /// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
257 /// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
258 /// still contain the original op that will not be used by the yield op (and
259 /// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
260 /// arguments. Tensor descriptor shape is not distributed because it is a
261 /// uniform value across all work items within the subgroup. However, the
262 /// layout information is dropped in the new tensor descriptor type.
263 ///
264 /// Example:
265 ///
266 /// ```
267 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
268 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
269 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
270 /// ...
271 /// %td = xegpu.create_nd_tdesc %arg0
272 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
273 /// vector.yield %td
274 /// }
275 /// ```
276 /// To
277 /// ```
278 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
279 /// ...
280 /// %dead = xegpu.create_nd_tdesc %arg0
281 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
282 /// vector.yield %arg0, %dead
283 /// }
284 /// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
285 /// -> !xegpu.tensor_desc<4x8xf32>
286 ///
287 /// ```
288 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
289  using gpu::WarpDistributionPattern::WarpDistributionPattern;
290  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
291  PatternRewriter &rewriter) const override {
292  OpOperand *operand =
293  getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
294  if (!operand)
295  return rewriter.notifyMatchFailure(
296  warpOp, "warp result is not a xegpu::CreateNdDesc op");
297  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
298  unsigned operandIdx = operand->getOperandNumber();
299 
300  xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
301  if (!layout)
302  return rewriter.notifyMatchFailure(
303  descOp, "the tensor descriptor lacks layout attribute");
304  // CreateNdOp must not have offsets.
305  if (descOp.getMixedOffsets().size())
306  return rewriter.notifyMatchFailure(
307  descOp, "xegpu::CreateNdDescOp must not have offsets");
308 
309  SmallVector<size_t> newRetIndices;
310  rewriter.setInsertionPoint(warpOp);
311  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
312  rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
313  /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
314 
315  SmallVector<Value> newDescOperands = llvm::map_to_vector(
316  newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
317  rewriter.setInsertionPointAfter(newWarpOp);
318  xegpu::TensorDescType distributedTensorDescTy =
319  descOp.getType().dropLayouts(); // Distributed tensor descriptor type
320  // does not contain layout info.
321  Value newDescOp = xegpu::CreateNdDescOp::create(
322  rewriter, newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
323  descOp->getAttrs());
324 
325  Value distributedVal = newWarpOp.getResult(operandIdx);
326  // Resolve the distributed type to the expected type.
327  newDescOp =
328  resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
329  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
330  return success();
331  }
332 };
333 
334 /// Distribute a store_nd op at the end of enclosing
335 /// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
336 /// through the warp op interface they would be propagated as returned values.
337 /// Source vector is distributed based on lane layout. Appropriate cast ops are
338 /// inserted if the distributed types do not match the expected xegpu SIMT types.
339 ///
340 /// Example:
341 ///
342 /// ```
343 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
344 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
345 /// ...
346 /// xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
347 /// !xegpu.tensor_desc<4x8xf32, #layout0>
348 /// }
349 /// ```
350 /// To
351 /// ```
352 /// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
353 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
354 /// ...
355 /// gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
356 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
357 /// }
358 /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
359 /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
360 /// #layout0>
361 /// -> !xegpu.tensor_desc<4x8xf32>
362 /// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
363 /// !xegpu.tensor_desc<4x8xf32>
364 ///
365 /// ```
366 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
367  using gpu::WarpDistributionPattern::WarpDistributionPattern;
368  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
369  PatternRewriter &rewriter) const override {
370  gpu::YieldOp yield = warpOp.getTerminator();
371  Operation *lastNode = yield->getPrevNode();
372  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
373  if (!storeOp)
374  return failure();
375 
376  SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
377  // Expecting offsets to be present.
378  if (offsets.empty())
379  return rewriter.notifyMatchFailure(storeOp,
380  "the store op must have offsets");
381  SmallVector<Value> offsetsAsValues =
382  vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
383  SmallVector<Type> offsetTypes = llvm::to_vector(
384  llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
385  xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
386  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
387  if (!layout)
388  return rewriter.notifyMatchFailure(
389  storeOp, "the source tensor descriptor lacks layout attribute");
390 
391  FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
392  getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
393  if (failed(distributedTypeByWarpOpOrFailure))
394  return rewriter.notifyMatchFailure(storeOp,
395  "Failed to distribute the type");
396  VectorType distributedTypeByWarpOp =
397  distributedTypeByWarpOpOrFailure.value();
398 
399  SmallVector<size_t> newRetIndices;
400  SmallVector<Value> newYieldedValues = {storeOp.getValue(),
401  storeOp.getTensorDesc()};
402  SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
403  newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
404  newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
405  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
406  rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
407  // Create a new store op outside the warp op with the distributed vector
408  // type. Tensor descriptor is not distributed.
409  rewriter.setInsertionPointAfter(newWarpOp);
410  SmallVector<Value> newStoreOperands;
411 
412  // For the value operand, there can be a mismatch between the vector type
413  // distributed by the warp op and (xegpu-specific) distributed type
414  // supported by the store op. Type mismatch must be resolved using
415  // appropriate cast op.
416  FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
417  xegpu::getDistributedVectorType(storeOp.getTensorDescType());
418  if (failed(storeNdDistributedValueTyOrFailure))
419  return rewriter.notifyMatchFailure(
420  storeOp, "Failed to get distributed vector type for the store op");
421  newStoreOperands.push_back(resolveDistributedTy(
422  newWarpOp.getResult(newRetIndices[0]),
423  storeNdDistributedValueTyOrFailure.value(), rewriter));
424  // For the tensor descriptor operand, the layout attribute is dropped after
425  // distribution. Types need to be resolved in this case as well.
426  xegpu::TensorDescType distributedTensorDescTy =
427  storeOp.getTensorDescType().dropLayouts();
428  newStoreOperands.push_back(
429  resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
430  distributedTensorDescTy, rewriter));
431  // Collect offsets.
432  for (size_t i = 2; i < newRetIndices.size(); ++i)
433  newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
434 
435  auto newStoreOp =
436  xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
437  newStoreOperands, storeOp->getAttrs());
438  xegpu::removeLayoutAttrs(newStoreOp);
439  rewriter.eraseOp(storeOp);
440  return success();
441  }
442 };
443 
444 /// Distribute a load_nd op feeding into vector.yield op for the enclosing
445 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
446 /// The warp op will still contain the original op that will not be used by
447 /// the yield op (and should be cleaned up later). The yield op will
448 /// bypass the load's arguments. Only the loaded vector is distributed
449 /// according to lane layout; the tensor descriptor type is not
450 /// distributed. Appropriate cast ops are inserted if the distributed types do
451 /// not match the expected xegpu SIMT types.
452 ///
453 /// Example:
454 ///
455 /// ```
456 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
457 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
458 /// (vector<4x1xf32>) {
459 /// ...
460 /// %ld = xegpu.load_nd %arg0, %arg1
461 ///   : !xegpu.tensor_desc<4x8xf32, #layout0>
462 ///   -> vector<4x8xf32>
463 /// gpu.yield %ld
464 /// }
465 /// ```
466 /// To
467 /// ```
468 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
469 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
470 /// ...
471 /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> -> vector<4x8xf32>
472 /// gpu.yield %dead, %arg0
473 /// }
474 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
475 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
476 /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
477 /// %2 = vector.shape_cast %1: vector<4xf32> to vector<4x1xf32>
478 ///
479 /// ```
480 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
481  using gpu::WarpDistributionPattern::WarpDistributionPattern;
482  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
483  PatternRewriter &rewriter) const override {
484  OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
485  if (!isa<xegpu::LoadNdOp>(op))
486  return false;
487  // Make sure the same load op is the last operation in the warp op body.
488  // This ensures that the load op is not sunk earlier, violating any barrier
489  // synchronizations.
490  gpu::YieldOp yield = warpOp.getTerminator();
491  return yield->getPrevNode() == op;
492  });
493 
494  if (!operand)
495  return rewriter.notifyMatchFailure(
496  warpOp, "warp result is not a xegpu::LoadNd op");
497 
498  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
499  // Chip information is required to decide if the layout requires a transpose
500  // effect.
501  auto chipStr = xegpu::getChipStr(loadOp);
502  if (!chipStr)
503  return rewriter.notifyMatchFailure(
504  loadOp,
505  "xegpu::LoadNdOp require chip information to determine transpose "
506  "requirement");
507  // Expecting offsets to be present.
508  SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
509  if (offsets.empty())
510  return rewriter.notifyMatchFailure(loadOp,
511  "the load op must have offsets");
512  SmallVector<Value> offsetsAsValues =
513  vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
514  SmallVector<Type> offsetTypes = llvm::to_vector(
515  llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
516 
517  xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
518  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
519  if (!layout)
520  return rewriter.notifyMatchFailure(
521  loadOp, "the source tensor descriptor lacks layout attribute");
522 
523  unsigned operandIdx = operand->getOperandNumber();
524  VectorType distributedTypeByWarpOp =
525  cast<VectorType>(warpOp.getResult(operandIdx).getType());
526 
527  SmallVector<size_t> newRetIndices;
528  SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
529  SmallVector<Type> newYieldedTypes = {tensorDescTy};
530  newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
531  newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
532  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
533  rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
534 
535  // Create a new load op outside the warp op with the distributed vector
536  // type.
537  rewriter.setInsertionPointAfter(newWarpOp);
538  FailureOr<VectorType> loadNdDistValueTyOrFailure =
539  xegpu::getDistributedVectorType(loadOp.getTensorDescType());
540  if (failed(loadNdDistValueTyOrFailure))
541  return rewriter.notifyMatchFailure(
542  loadOp, "Failed to get distributed vector type for the load op");
543  xegpu::TensorDescType distributedTensorDescTy =
544  loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
545  // descriptor type does not
546  // contain layout info.
547  SmallVector<Value> newLoadOperands{
548  resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
549  distributedTensorDescTy, rewriter)};
550  // Collect offsets.
551  for (size_t i = 1; i < newRetIndices.size(); ++i)
552  newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
553  auto newLoadOp = xegpu::LoadNdOp::create(
554  rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
555  newLoadOperands, loadOp->getAttrs());
556  xegpu::removeLayoutAttrs(newLoadOp);
557  // Set the packed attribute if the layout requires it.
558  newLoadOp.setPacked(requirePacked(layout));
559  // Set the transpose attribute if the layout requires it.
560  if (requireTranspose(layout, chipStr.value()))
561  newLoadOp.setTranspose(
562  DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
563  Value distributedVal = newWarpOp.getResult(operandIdx);
564  // There can be a conflict between the vector type distributed by the
565  // warp op and (xegpu-specific) distributed type supported by the load
566  // op. Resolve these mismatches by inserting a cast.
567  Value tyResolvedVal = resolveDistributedTy(
568  newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
569  rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
570  return success();
571  }
572 };
573 
574 /// Distribute a dpas op feeding into vector.yield op for the enclosing
575 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
576 /// The warp op will still contain the original op that will not be used by
577 /// the yield op (and should be cleaned up later). The yield op will
578 /// bypass the dpas's arguments. Appropriate cast ops are inserted if the
579 /// distributed types do not match the expected xegpu SIMT types.
580 /// Example:
581 /// ```
582 /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
583 /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
584 /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
585 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
586 /// (vector<8x1xf32>) {
587 /// ...
588 /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
589 /// vector<8x16xf32>
590 /// gpu.yield %dpas
591 /// }
592 /// ```
593 /// To
594 /// ```
595 /// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
596 /// vector<8x1xf16>, vector<16x1xf16>) {
597 /// ...
598 /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
599 /// -> vector<8x16xf32>
600 /// gpu.yield %dead, %arg0, %arg1
601 /// }
602 /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
603 /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
604 /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
605 /// vector<8xf32>
606 /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
607 /// ```
608 struct DpasDistribution final : public gpu::WarpDistributionPattern {
609  using gpu::WarpDistributionPattern::WarpDistributionPattern;
610  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
611  PatternRewriter &rewriter) const override {
612  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
613  if (!operand)
614  return rewriter.notifyMatchFailure(warpOp,
615  "warp result is not a xegpu::Dpas op");
616 
617  auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
618  unsigned operandIdx = operand->getOperandNumber();
619  std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
620  std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
621  std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
622 
623  xegpu::LayoutAttr layoutA =
624  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
625  xegpu::LayoutAttr layoutB =
626  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
627  xegpu::LayoutAttr layoutOut =
628  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
629  if (!layoutA || !layoutB || !layoutOut)
630  return rewriter.notifyMatchFailure(
631  dpasOp,
632  "the xegpu::Dpas op lacks layout attribute for A, B or output");
633 
634  FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
635  getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
636  FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
637  getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
638  FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
639  getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
640  if (failed(distLhsTypeByWarpOpOrFailure) ||
641  failed(distRhsTypeByWarpOpOrFailure) ||
642  failed(distResultTypeByWarpOpOrFailure))
643  return rewriter.notifyMatchFailure(
644  dpasOp,
645  "Failed to distribute the A, B or output types in xegpu::Dpas op");
646 
647  llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
648  dpasOp.getRhs()};
649  llvm::SmallVector<Type, 3> newYieldTypes{
650  distLhsTypeByWarpOpOrFailure.value(),
651  distRhsTypeByWarpOpOrFailure.value()};
652  // Dpas acc operand is optional.
653  if (dpasOp.getAcc()) {
654  newYieldValues.push_back(dpasOp.getAcc());
655  newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
656  }
657  // Create a new warp op without the dpas.
658  SmallVector<size_t> newRetIndices;
659  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
660  rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
661 
662  FailureOr<VectorType> expectedDistLhsTyOrFailure =
663  xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
664  FailureOr<VectorType> expectedDistRhsTyOrFailure =
665  xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
666  FailureOr<VectorType> expectedDistResultTyOrFailure =
667  xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
668  if (failed(expectedDistLhsTyOrFailure) ||
669  failed(expectedDistRhsTyOrFailure) ||
670  failed(expectedDistResultTyOrFailure))
671  return rewriter.notifyMatchFailure(
672  dpasOp,
673  "Failed to get distributed vector type for the dpas operands.");
674  // Create a new dpas op outside the warp op.
675  rewriter.setInsertionPointAfter(newWarpOp);
676  SmallVector<Value> newDpasOperands;
677  SmallVector<VectorType> newDpasOperandExpectedTypes;
678 
679  // Resolve the distributed types with the original types.
680  newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
681  newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
682  VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
683  if (dpasOp.getAcc())
684  newDpasOperandExpectedTypes.push_back(distributedResultTy);
685 
686  for (unsigned i = 0; i < newRetIndices.size(); i++) {
687  newDpasOperands.push_back(
688  resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
689  newDpasOperandExpectedTypes[i], rewriter));
690  }
691  auto newDpasOp = xegpu::DpasOp::create(rewriter, newWarpOp->getLoc(),
692  distributedResultTy, newDpasOperands,
693  dpasOp->getAttrs());
694  xegpu::removeLayoutAttrs(newDpasOp);
695  Value distributedVal = newWarpOp.getResult(operandIdx);
696  // Resolve the output type.
697  Value typeResolved =
698  resolveDistributedTy(newDpasOp.getResult(),
699  distResultTypeByWarpOpOrFailure.value(), rewriter);
700  rewriter.replaceAllUsesWith(distributedVal, typeResolved);
701  return success();
702  }
703 };
704 
705 /// Distribute a prefetch_nd op at the end of enclosing
706 /// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
707 /// through the warp op interface they would be propagated as returned values.
708 /// Tensor descriptor shape is not distributed because it is a uniform value
709 /// across all work items within the subgroup. Appropriate cast ops are inserted
710 /// if the distributed types do not match the expected xegpu SIMT types.
711 ///
712 /// Example:
713 ///
714 /// ```
715 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
716 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
717 /// ...
718 /// xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
719 /// }
720 /// ```
721 /// To
722 /// ```
723 /// %r:3 = gpu.warp_execute_on_lane_0(%laneid) -> (
724 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
725 /// gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
726 /// index
727 /// }
728 /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
729 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
730 /// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
731 ///
732 /// ```
733 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
734  using gpu::WarpDistributionPattern::WarpDistributionPattern;
735  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
736  PatternRewriter &rewriter) const override {
737  gpu::YieldOp yield = warpOp.getTerminator();
738  Operation *lastNode = yield->getPrevNode();
739  auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
740  if (!prefetchOp)
741  return failure();
742 
743  SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
744  // PrefetchNdOp must have offsets.
745  if (offsets.empty())
746  return rewriter.notifyMatchFailure(prefetchOp,
747  "the prefetch op must have offsets");
748  SmallVector<Value> offsetsAsValues =
749  vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
750  SmallVector<Type> offsetTypes = llvm::to_vector(
751  llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
752 
753  xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
754  if (!layout)
755  return rewriter.notifyMatchFailure(
756  prefetchOp, "the source tensor descriptor lacks layout attribute");
757 
758  SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
759  SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
760  newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
761  newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
762  SmallVector<size_t> newRetIndices;
763  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
764  rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
765  // Create a new prefetch op outside the warp op with updated tensor
766  // descriptor type. The source tensor descriptor requires type resolution.
767  xegpu::TensorDescType newTensorDescTy =
768  prefetchOp.getTensorDescType().dropLayouts();
769  rewriter.setInsertionPointAfter(newWarpOp);
770  SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
771  newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
772  // Collect offsets.
773  for (size_t i = 1; i < newRetIndices.size(); ++i)
774  newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
775  xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
776  newPrefetchOperands, prefetchOp->getAttrs());
777  xegpu::removeLayoutAttrs(prefetchOp);
778  rewriter.eraseOp(prefetchOp);
779  return success();
780  }
781 };
782 
783 /// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
784 /// region. This will simply move the barrier op outside of the warp op.
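///
/// Example (an illustrative sketch of the transformation):
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   gpu.barrier
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
/// }
/// gpu.barrier
/// ```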
785 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
786  using gpu::WarpDistributionPattern::WarpDistributionPattern;
787  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
788  PatternRewriter &rewriter) const override {
789  gpu::YieldOp yield = warpOp.getTerminator();
790  Operation *lastNode = yield->getPrevNode();
791  // The last node must be a gpu::BarrierOp.
792  auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
793  if (!barrierOp)
794  return failure();
795  // Move the barrier op outside of the warp op.
796  rewriter.setInsertionPointAfter(warpOp);
797  gpu::BarrierOp::create(rewriter, barrierOp.getLoc(),
798  barrierOp->getResultTypes(),
799  barrierOp->getOperands(), barrierOp->getAttrs());
800  rewriter.eraseOp(barrierOp);
801  return success();
802  }
803 };
804 
805 /// Distribute a scattered store op. The offsets argument is required.
806 /// Both offset and mask vectors must be 1D and have #subgroup_size elements.
807 /// The layouts are fixed and implicit: one offset/mask per lane.
808 /// The pass changes the offset/mask vector shapes to a
809 /// single-element vector; **it is assumed that their producers will also be
810 /// distributed**. The payload vector also has a fixed distribution:
811 /// no chunk size -> vector of one element.
812 /// chunk size -> vector of the innermost dimension of the SG-payload.
813 /// Example 1 (no chunk size):
814 /// %mask = producer_op : vector<16xi1>
815 /// %offset = producer_op : vector<16xindex>
816 /// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
817 /// memref<256xf16>, vector<16xindex>, vector<16xi1>
818 /// To
819 /// %mask = producer_op : vector<1xi1>
820 /// %offset = producer_op : vector<1xindex>
821 /// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
822 /// memref<256xf16>, vector<1xindex>, vector<1xi1>
823 /// Example 2 (chunk size, same mask and offsets):
824 /// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
825 /// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
826 /// To
827 /// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
828 /// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
829 struct StoreDistribution final : public gpu::WarpDistributionPattern {
830  using gpu::WarpDistributionPattern::WarpDistributionPattern;
831  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
832  PatternRewriter &rewriter) const override {
833  Operation *lastNode = warpOp.getTerminator()->getPrevNode();
834  auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
835  if (!storeScatterOp)
836  return failure();
837  auto offsets = storeScatterOp.getOffsets();
838  if (!offsets || !isa<VectorType>(offsets.getType()))
839  return rewriter.notifyMatchFailure(
840  storeScatterOp, "Store op must have a vector of offsets argument");
841  VectorType offsetsTy = cast<VectorType>(offsets.getType());
842  VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
843  if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
844  return rewriter.notifyMatchFailure(storeScatterOp,
845  "Expected 1D offsets and mask vector");
846  VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
847  if (storeVecTy.getRank() > 2)
848  return rewriter.notifyMatchFailure(
849  storeScatterOp, "Expected at most 2D result at SG level");
850 
851  std::string layoutPayloadName =
852  xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
853  std::string layoutOffsetsName =
854  xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
855  std::string layoutMaskName =
856  xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
857 
858  xegpu::LayoutAttr layoutPayload =
859  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
860  xegpu::LayoutAttr layoutOffsets =
861  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
862  xegpu::LayoutAttr layoutMask =
863  storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
864 
865  FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
866  getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
867  FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
868  getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
869  FailureOr<VectorType> distMaskByWarpOpOrFailure =
870  getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
871  if (failed(distStoreVecByWarpOpOrFailure) ||
872  failed(distOffsetsByWarpOpOrFailure) ||
873  failed(distMaskByWarpOpOrFailure)) {
874  return rewriter.notifyMatchFailure(
875  storeScatterOp,
876  "Some vector operands have no layouts, using defaults instead.");
877  }
878  VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
879  VectorType expectedPayloadTy = VectorType::get(
880  {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
881 
882  SmallVector<size_t> newRetIndices;
883  SmallVector<Value> operands = storeScatterOp->getOperands();
884  SmallVector<Type> operandTypesToYield = {
885  expectedPayloadTy, operands[1].getType(),
886  distOffsetsByWarpOpOrFailure.value(),
887  distMaskByWarpOpOrFailure.value()};
888 
889  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
890  rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
891  SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
892  newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
893 
894  rewriter.setInsertionPointAfter(newWarpOp);
895  xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
896  rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
897  storeScatterOp->getAttrs());
898  xegpu::removeLayoutAttrs(newOp);
899  rewriter.eraseOp(storeScatterOp);
900  return success();
901  }
902 };
903 
904 /// Distribute a scattered load op. The logic and requirements are the same as
905 /// for the scattered store distribution. The warpOp's payload vector is
906 /// expected to be distributed by the load's result consumer.
907 /// Example 1 (no chunk size):
908 /// %mask = producer_op : vector<16xi1>
909 /// %offset = producer_op : vector<16xindex>
910 /// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
911 /// vector<16xindex>, vector<16xi1> -> vector<16xf16>
912 /// To
913 /// %mask = producer_op : vector<1xi1>
914 /// %offset = producer_op : vector<1xindex>
915 /// %0 = xegpu.load %payload, %src[%offset], %mask : memref<256xf16>,
916 /// vector<1xindex>, vector<1xi1> -> vector<1xf16>
917 /// Example 2 (chunk size, same mask and offsets):
918 /// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
919 /// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
920 /// To
921 /// %0 = xegpu.load %payload, %src[%offset], %mask <{chunk_size=8}> :
922 /// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
923 struct LoadDistribution final : public gpu::WarpDistributionPattern {
924  using gpu::WarpDistributionPattern::WarpDistributionPattern;
925  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
926  PatternRewriter &rewriter) const override {
927  OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
928  // Check that the yield operand was produced by the *last* scattered
929  // load op, to avoid sinking it before barriers (maintain memory order).
930  return isa<xegpu::LoadGatherOp>(op) &&
931  warpOp.getTerminator()->getPrevNode() == op;
932  });
933  if (!producedByLastLoad)
934  return rewriter.notifyMatchFailure(
935  warpOp, "The last op is not xegpu::LoadGatherOp");
936 
937  auto loadGatherOp =
938  producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
939  auto offsets = loadGatherOp.getOffsets();
940  if (!offsets || !isa<VectorType>(offsets.getType()) ||
941  !isa<VectorType>(loadGatherOp.getMask().getType()))
942  return rewriter.notifyMatchFailure(
943  loadGatherOp,
944  "Load op must have a vector arguments for offsets and mask");
945  VectorType offsetsTy = cast<VectorType>(offsets.getType());
946  VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
947  if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
948  return rewriter.notifyMatchFailure(loadGatherOp,
949  "Expected 1D offsets and mask vector");
950  // Assume offset and mask producers will be distributed as well.
951  std::string layoutOffsetsName =
952  xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
953  std::string layoutMaskName =
954  xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
955 
956  xegpu::LayoutAttr layoutOffsets =
957  loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
958  xegpu::LayoutAttr layoutMask =
959  loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
960 
961  FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
962  getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
963  FailureOr<VectorType> distMaskByWarpOpOrFailure =
964  getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
965  if (failed(distOffsetsByWarpOpOrFailure) ||
966  failed(distMaskByWarpOpOrFailure)) {
967  return rewriter.notifyMatchFailure(
968  loadGatherOp,
969  "Some vector operands have no layouts, using defaults instead.");
970  }
971 
972  SmallVector<size_t> newRetIndices;
973  SmallVector<Value> operands = loadGatherOp->getOperands();
974  SmallVector<Type> operandTypesToYield = {
975  operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
976  distMaskByWarpOpOrFailure.value()};
977 
978  const unsigned operandIdx = producedByLastLoad->getOperandNumber();
979  VectorType loadVecTy =
980  cast<VectorType>(warpOp.getResult(operandIdx).getType());
981 
982  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
983  rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
984 
985  SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
986  newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
987 
988  rewriter.setInsertionPointAfter(newWarpOp);
989  xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
990  rewriter, newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
991  loadGatherOp->getAttrs());
992  xegpu::removeLayoutAttrs(newOp);
993  Value distributedVal = newWarpOp.getResult(operandIdx);
994  rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
995  return success();
996  }
997 };
998 
999 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
1000 /// VectorReductionOps.
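/// For illustration (a rough sketch; the value names %src and %acc are
/// placeholders), reducing a vector<2x4xf32> source along dim 1 with a
/// vector<2xf32> accumulator produces IR along these lines:
/// ```
/// %c = arith.constant dense<0.0> : vector<2xf32>
/// %s0 = vector.extract_strided_slice %src {offsets = [0, 0], sizes = [1, 4],
///   strides = [1, 1]} : vector<2x4xf32> to vector<1x4xf32>
/// %f0 = vector.shape_cast %s0 : vector<1x4xf32> to vector<4xf32>
/// %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
/// %r0 = vector.reduction <add>, %f0, %a0 : vector<4xf32> into f32
/// %0 = vector.insert %r0, %c [0] : f32 into vector<2xf32>
/// ... repeat for row 1
/// ```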
1001 static Value lowerToVectorReductions(TypedValue<VectorType> src,
1002  TypedValue<VectorType> acc,
1003  vector::CombiningKind kind,
1004  int64_t reductionDim, Location loc,
1005  PatternRewriter &rewriter) {
1006  // Expecting a 2D source vector.
1007  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
1008  VectorType sourceType = src.getType();
1009  int64_t sourceH = sourceType.getShape()[0];
1010  int64_t sourceW = sourceType.getShape()[1];
1011  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
1012  // Create a constant vector to hold the result of the reduction.
1013  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
1014  Value reductionResult = arith::ConstantOp::create(
1015  rewriter, loc, acc.getType(),
1016  DenseElementsAttr::get(acc.getType(), zeroAttr));
1017  // For each slice of the source, extract the slice vector, do a reduction
1018  // and insert the reduced value back into the result vector.
1019  for (int i = 0; i < nSlices; ++i) {
1020  SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
1021  if (reductionDim == 1) {
1022  sliceOffsets = {i, 0};
1023  sliceSizes = {1, sourceW};
1024  } else {
1025  sliceOffsets = {0, i};
1026  sliceSizes = {sourceH, 1};
1027  }
1028  vector::ExtractStridedSliceOp extractOp =
1029  vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
1030  sliceSizes, {1, 1});
1031  int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
1032  Value slice = vector::ShapeCastOp::create(
1033  rewriter, loc,
1034  VectorType::get({nSliceElements}, sourceType.getElementType()),
1035  extractOp.getResult());
1036  Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
1037  Value reduction =
1038  vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
1039  reductionResult =
1040  vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
1041  }
1042  return reductionResult;
1043 }
1044 
1045 /// This pattern distributes the `vector.multi_reduction` operation across
1046 /// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
1047 /// layouts for the source and accumulator vectors,
1048 /// * If the reduction dimension is distributed across lanes, the reduction is
1049 /// non-lane-local and the reduction is done using warp shuffles. Here we
1050 /// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
1051 /// the warp op body.
1052 /// * If the reduction dimension is not distributed across lanes, the reduction
1053 /// is lane-local. In this case, we yield the source and accumulator vectors
1054 /// from the warp op and perform the lane-local reduction outside the warp op
1055 /// using a sequence of ReductionOps.
1056 /// Example 1 (Reduction is lane-local):
1057 /// ```
1058 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1059 /// %0 = "some_def"() : () -> (vector<16x32xf32>)
1060 /// %acc = "some_def"() : () -> (vector<32xf32>)
1061 /// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
1062 /// vector<32xf32> gpu.yield %1 : vector<32xf32>
1063 /// }
1064 /// ```
1065 /// is lowered to:
1066 /// ```
1067 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
1068 /// vector<1xf32>) {
1069 /// %0 = "some_def"() : () -> (vector<16x32xf32>)
1070 /// %acc = "some_def"() : () -> (vector<32xf32>)
1071 /// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
1072 /// }
1073 /// %c = arith.constant dense<0.0> : vector<1xf32>
1074 /// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
1075 /// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> into f32
1076 /// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
1077 /// ```
1078 /// Example 2 (Reduction is non-lane-local):
1079 /// ```
1080 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1081 /// %0 = "some_def"() : () -> (vector<2x32xf32>)
1082 /// %acc = "some_def"() : () -> (vector<2xf32>)
1083 /// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
1084 /// vector<2xf32>
1085 /// gpu.yield %1 : vector<2xf32>
1086 /// }
1087 /// ```
1088 /// is lowered to:
1089 /// ```
1090 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
1091 /// %0 = "some_def"() : () -> (vector<2x32xf32>)
1092 /// %acc = "some_def"() : () -> (vector<2xf32>)
1093 /// %1 = arith.constant dense<0.0> : vector<2xf32>
1094 /// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
1095 /// %3 = ("warp.reduction %2") : f32
1096 /// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
1097 /// ... repeat for row 1
1098 /// gpu.yield %1 : vector<2xf32>
1099 /// }
/// ```
1100 struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
1101  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1102  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1103  PatternRewriter &rewriter) const override {
1104  OpOperand *yieldOperand =
1105  getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
1106  if (!yieldOperand)
1107  return failure();
1108  auto reductionOp =
1109  cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
1110  unsigned operandNumber = yieldOperand->getOperandNumber();
1111  VectorType sourceType = reductionOp.getSourceVectorType();
1112  // Only 2D vectors are supported.
1113  if (sourceType.getRank() != 2)
1114  return rewriter.notifyMatchFailure(warpOp,
1115  "Only 2D reductions are supported.");
1116  ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
1117  // Only 1 reduction dimension supported. This also ensures that the result
1118  // is vector type.
1119  if (reductionDims.size() != 1)
1120  return rewriter.notifyMatchFailure(
1121  warpOp, "Only 1 reduction dimension is supported.");
1122  int64_t reductionDim = reductionDims[0];
1123  VectorType distributedResultType =
1124  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1125  VectorType resultType = cast<VectorType>(reductionOp.getType());
1126  xegpu::DistributeLayoutAttr sourceLayout =
1127  xegpu::getDistributeLayoutAttr(reductionOp.getSource());
1128 
1129  FailureOr<VectorType> sourceDistTypeOrFailure =
1130  getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
1131  if (failed(sourceDistTypeOrFailure))
1132  return rewriter.notifyMatchFailure(
1133  warpOp, "Failed to distribute the source vector type.");
1134  VectorType sourceDistType = sourceDistTypeOrFailure.value();
1135  // Only single dimension distribution is supported.
1136  bool dim0Distributed =
1137  sourceDistType.getShape()[0] != sourceType.getShape()[0];
1138  bool dim1Distributed =
1139  sourceDistType.getShape()[1] != sourceType.getShape()[1];
1140  if (dim0Distributed && dim1Distributed)
1141  return rewriter.notifyMatchFailure(
1142  warpOp, "Expecting source to be distributed in a single dimension.");
1143  int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
1144  if (sourceDistDim == -1)
1145  return rewriter.notifyMatchFailure(
1146  warpOp, "Expecting a distributed source vector.");
1147  bool resultDistributed =
1148  distributedResultType.getNumElements() < resultType.getNumElements();
1149  // If the lane owns all the data required for reduction (i.e. reduction is
1150  // fully parallel across lanes), then each lane owns part of the result
1151  // (i.e. result is distributed). If the reduction requires cross-lane
1152  // shuffling, then the result is shared among all lanes (broadcasted).
1153  // Therefore we expect the following cases:
1154  //
1155  // | Source vector | Reduction dim | Result vector |
1156  // |----------------------|----------------|----------------|
1157  // | dim-0 distributed | 0 | broadcasted |
1158  // | dim-0 distributed | 1 | distributed |
1159  // | dim-1 distributed | 0 | distributed |
1160  // | dim-1 distributed | 1 | broadcasted |
1161 
1162  bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
1163  (sourceDistDim == 1 && reductionDim == 0);
1164  if (isReductionLaneLocal && !resultDistributed)
1165  return rewriter.notifyMatchFailure(
1166  warpOp, "Expecting a distributed result for lane-local reduction.");
1167 
1168  if (!isReductionLaneLocal && resultDistributed)
1169  return rewriter.notifyMatchFailure(
1170  warpOp,
1171  "Expecting a broadcasted result for non-lane-local reduction.");
1172 
1173  // Handle lane-local reduction case. In this case we fully distribute the
1174  // reduction result.
1175  if (isReductionLaneLocal) {
1176  // Yield the source and acc vectors from the WarpOp.
1177  SmallVector<size_t> newRetIndices;
1178  auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1179  rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
1180  {sourceDistType, distributedResultType}, newRetIndices);
1181  rewriter.setInsertionPointAfter(newWarpOp);
1182  Value result = lowerToVectorReductions(
1183  cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
1184  cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
1185  reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1186  // Replace the warp op result with the final result.
1187  rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1188  return success();
1189  }
1190  // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
1191  // of multiple ReductionOps. Actual distribution is done by the
1192  // WarpOpReduction pattern.
1193  rewriter.setInsertionPointAfter(reductionOp);
1194  Value result = lowerToVectorReductions(
1195  cast<TypedValue<VectorType>>(reductionOp.getSource()),
1196  cast<TypedValue<VectorType>>(reductionOp.getAcc()),
1197  reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
1198  // Replace the warp op result with the final result.
1199  rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
1200  return success();
1201  }
1202 };
1203 
1204 /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
1205 /// `gpu.warp_execute_on_lane_0` region.
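///
/// Example (an illustrative sketch; layout attributes omitted, assuming a
/// 16-lane subgroup):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) {
///   %0 = "some_def"() : () -> (vector<1x16xf32>)
///   %1 = vector.shape_cast %0 : vector<1x16xf32> to vector<16xf32>
///   gpu.yield %1 : vector<16xf32>
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>,
///     vector<1x1xf32>) {
///   %0 = "some_def"() : () -> (vector<1x16xf32>)
///   %dead = vector.shape_cast %0 : vector<1x16xf32> to vector<16xf32>
///   gpu.yield %dead, %0 : vector<16xf32>, vector<1x16xf32>
/// }
/// %1 = vector.shape_cast %r#1 : vector<1x1xf32> to vector<1xf32>
/// ```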
1206 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
1207  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1208  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1209  PatternRewriter &rewriter) const override {
1210  OpOperand *yieldOperand =
1211  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1212  if (!yieldOperand)
1213  return failure();
1214  auto shapeCastOp =
1215  cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
1216  unsigned operandNumber = yieldOperand->getOperandNumber();
1217  auto resultDistTy =
1218  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1219  xegpu::DistributeLayoutAttr sourceLayout =
1220  xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
1221  xegpu::DistributeLayoutAttr resultLayout =
1222  xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
1223  if (!sourceLayout || !resultLayout)
1224  return rewriter.notifyMatchFailure(
1225  warpOp,
1226  "the source or result of shape_cast op lacks distribution layout");
1227 
1228  // For rank reducing or increasing shape_cast ops, the lower rank layout
1229  // must be a slice of higher rank layout.
1230  int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
1231  int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
1232  if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
1233  return rewriter.notifyMatchFailure(
1234  warpOp, "shape_cast is rank increasing but source layout is not a "
1235  "slice of result layout");
1236  if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
1237  return rewriter.notifyMatchFailure(
1238  warpOp, "shape_cast is rank reducing but result layout is not a "
1239  "slice of source layout");
1240 
1241  FailureOr<VectorType> sourceDistTypeOrFailure =
1242  getDistVecTypeBasedOnLaneLayout(sourceLayout,
1243  shapeCastOp.getSourceVectorType());
1244  if (failed(sourceDistTypeOrFailure))
1245  return rewriter.notifyMatchFailure(
1246  warpOp, "failed to get distributed vector type for source");
1247  VectorType sourceDistType = sourceDistTypeOrFailure.value();
1248  // Create a new warp op that yields the source of the shape_cast op.
1249  SmallVector<size_t> newRetIndices;
1250  auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1251  rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
1252  newRetIndices);
1253  rewriter.setInsertionPointAfter(newWarpOp);
1254  Value source = newWarpOp.getResult(newRetIndices[0]);
1255  // Create a new shape_cast op outside the warp op.
1256  Value newShapeCast = vector::ShapeCastOp::create(
1257  rewriter, shapeCastOp.getLoc(), resultDistTy, source);
1258  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
1259  newShapeCast);
1260  return success();
1261  }
1262 };
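// Illustrative sketch (not from the original source; types and layouts are
// assumed): for a rank-increasing cast such as
//   %c = vector.shape_cast %s : vector<256xf32> to vector<16x16xf32>
// with result lane_layout [1, 16] (distributed result type vector<16x1xf32>)
// and a source layout that is a slice of it (distributed source type
// vector<16xf32>), the pattern yields %s from the warp op as vector<16xf32>
// and recreates
//   vector.shape_cast %s_dist : vector<16xf32> to vector<16x1xf32>
// outside of it.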
1263 
1264 /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
1265 /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
1266 /// outside of the warp op.
1267 struct MemrefExtractAlignedPointerAsIndexDistribution final
1268  : public gpu::WarpDistributionPattern {
1269  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1270  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1271  PatternRewriter &rewriter) const override {
1272  OpOperand *operand = getWarpResult(
1273  warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
1274  if (!operand)
1275  return rewriter.notifyMatchFailure(
1276  warpOp,
1277  "warp result is not a memref::ExtractAlignedPointerAsIndex op");
1278  auto extractOp =
1279  operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
1280  unsigned operandIdx = operand->getOperandNumber();
1281  SmallVector<size_t> newRetIndices;
1282  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1283  rewriter, warpOp, extractOp.getSource(),
1284  TypeRange{extractOp.getSource().getType()}, newRetIndices);
1285  rewriter.setInsertionPointAfter(newWarpOp);
1286  auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
1287  rewriter, newWarpOp.getLoc(), extractOp.getType(),
1288  newWarpOp.getResult(newRetIndices[0]));
1289  Value distributedVal = newWarpOp.getResult(operandIdx);
1290  rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
1291  return success();
1292  }
1293 };
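// Illustrative sketch (not from the original source; types are assumed):
//   %p = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
//     %0 = memref.extract_aligned_pointer_as_index %m : memref<256xf32> -> index
//     gpu.yield %0 : index
//   }
// becomes a warp op that yields the uniform memref instead, with the
// extraction recreated outside:
//   %m2 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<256xf32>) { ... }
//   %p = memref.extract_aligned_pointer_as_index %m2 : memref<256xf32> -> index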
1294 
1295 /// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1296 /// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1297 /// dimension of the source/result vectors. An equivalent vector::BitCastOp is
1298 /// created outside of the warp op with the distributed source vector type
1299 /// (computed using the assigned layout).
1300 struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
1301  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1302  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1303  PatternRewriter &rewriter) const override {
1304  OpOperand *operand =
1305  getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
1306  if (!operand)
1307  return rewriter.notifyMatchFailure(
1308  warpOp, "warp result is not a vector::BitCast op");
1309  auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
1310  unsigned operandIdx = operand->getOperandNumber();
1311  VectorType distributedSourceType =
1312  getDistVecTypeBasedOnLaneLayout(
1313  xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
1314  bitcastOp.getSourceVectorType())
1315  .value_or(VectorType());
1316  if (!distributedSourceType)
1317  return rewriter.notifyMatchFailure(
1318  bitcastOp, "Failed to distribute the source vector type in "
1319  "vector::BitCast op");
1320  VectorType distributedResultType =
1321  cast<VectorType>(warpOp.getResult(operandIdx).getType());
1322  SmallVector<size_t> newRetIndices;
1323  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1324  rewriter, warpOp, bitcastOp.getSource(),
1325  TypeRange{distributedSourceType}, newRetIndices);
1326  rewriter.setInsertionPointAfter(newWarpOp);
1327  auto newBitcastOp = vector::BitCastOp::create(
1328  rewriter, newWarpOp.getLoc(), distributedResultType,
1329  newWarpOp.getResult(newRetIndices[0]));
1330  Value distributedVal = newWarpOp.getResult(operandIdx);
1331  rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
1332  return success();
1333  }
1334 };
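// Illustrative sketch (not from the original source; types and layouts are
// assumed): for
//   %b = vector.bitcast %s : vector<2x32xi16> to vector<2x16xi32>
// with lane_layout [1, 16] on both source and result, the source distributes
// to vector<2x2xi16> and the result to vector<2x1xi32>; the pattern yields the
// distributed source from the warp op and recreates
//   vector.bitcast %s_dist : vector<2x2xi16> to vector<2x1xi32>
// outside of it.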
1335 
1336 /// Distribute a vector::TransposeOp feeding into yield op of an enclosing
1337 /// `gpu.warp_execute_on_lane_0` region. Currently only 2D transposes are
1338 /// supported. In most cases, the transpose is a no-op because it is entirely
1339 /// handled using the layouts (e.g. 16x1 -> 1x16). However, if each lane owns
1340 /// multiple slices of data after distribution (e.g. 16x2 -> 2x16), a lane-local
1341 /// transpose (i.e. shuffle) is needed. Therefore, we create an equivalent
1342 /// vector::TransposeOp outside of the warp op with the distributed source
1343 /// vector type (computed using the assigned layout).
1344 struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
1345  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1346  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
1347  PatternRewriter &rewriter) const override {
1348  OpOperand *operand =
1349  getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
1350  if (!operand)
1351  return rewriter.notifyMatchFailure(
1352  warpOp, "warp result is not a vector::Transpose op");
1353  auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
1354  unsigned operandIdx = operand->getOperandNumber();
1355  xegpu::DistributeLayoutAttr sourceLayout =
1356  xegpu::getDistributeLayoutAttr(transposeOp.getVector());
1357  xegpu::DistributeLayoutAttr resultLayout =
1358  xegpu::getDistributeLayoutAttr(transposeOp.getResult());
1359  if (!sourceLayout || !resultLayout)
1360  return rewriter.notifyMatchFailure(
1361  transposeOp,
1362  "the source or result vector of the transpose op lacks layout "
1363  "attribute");
1364  int64_t sourceRank = transposeOp.getSourceVectorType().getRank();
1365  int64_t resultRank = transposeOp.getResultVectorType().getRank();
1366  // Only 2D transposes are supported for now.
1367  // TODO: Support nD transposes.
1368  if (sourceRank != 2 || resultRank != 2)
1369  return rewriter.notifyMatchFailure(
1370  transposeOp, "the source or result vector of the transpose op "
1371  "does not have 2D layout");
1372  ArrayRef<int64_t> perm = transposeOp.getPermutation();
1373  // Result layout must be a transpose of source layout.
1374  if (!resultLayout.isTransposeOf(sourceLayout, perm))
1375  return rewriter.notifyMatchFailure(
1376  transposeOp,
1377  "the source or result vector layouts must be 2D transposes of each "
1378  "other");
1379  FailureOr<VectorType> distributedSourceTypeOrFailure =
1380  getDistVecTypeBasedOnLaneLayout(sourceLayout,
1381  transposeOp.getSourceVectorType());
1382  if (failed(distributedSourceTypeOrFailure))
1383  return rewriter.notifyMatchFailure(
1384  transposeOp, "Failed to distribute the source vector type in "
1385  "vector::Transpose op");
1386  SmallVector<size_t> newRetIndices;
1387  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1388  rewriter, warpOp, transposeOp.getVector(),
1389  TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
1390  rewriter.setInsertionPointAfter(newWarpOp);
1391  auto newTransposeOp = vector::TransposeOp::create(
1392  rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1393  perm);
1394  Value distributedVal = newWarpOp.getResult(operandIdx);
1395  rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
1396  return success();
1397  }
1398 };
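// Illustrative sketch (not from the original source), using the 16x2 -> 2x16
// case mentioned above:
//   %t = vector.transpose %v, [1, 0] : vector<16x2xf32> to vector<2x16xf32>
// With source lane_layout [16, 1] and result lane_layout [1, 16], the source
// distributes to vector<1x2xf32> and the result to vector<2x1xf32>; the
// pattern recreates the transpose on the lane-local data outside the warp op:
//   vector.transpose %v_dist, [1, 0] : vector<1x2xf32> to vector<2x1xf32>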
1399 
1400 } // namespace
1401 
1402 namespace {
1403 struct XeGPUSubgroupDistributePass final
1404  : public xegpu::impl::XeGPUSubgroupDistributeBase<
1405  XeGPUSubgroupDistributePass> {
1406  XeGPUSubgroupDistributePass() = default;
1407  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
1408  default;
1409  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
1410  : XeGPUSubgroupDistributeBase(options) {}
1411  void runOnOperation() override;
1412 };
1413 } // namespace
1414 
1415 void xegpu::populateXeGPUSubgroupDistributePatterns(
1416     RewritePatternSet &patterns) {
1417   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
1418  LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
1419  GpuBarrierDistribution, VectorMultiReductionDistribution,
1420  LoadDistribution, StoreDistribution, VectorTransposeDistribution,
1421  VectorBitcastDistribution,
1422  MemrefExtractAlignedPointerAsIndexDistribution>(
1423  patterns.getContext(),
1424  /*pattern benefit=*/regularPatternBenefit);
1425  patterns.add<VectorShapeCastDistribution>(
1426  patterns.getContext(),
1427  /*pattern benefit=*/highPatternBenefit);
1428 }
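// Minimal usage sketch (illustrative, not part of this file; `ctx` and `funcOp`
// are assumed to be provided by the caller), mirroring what the pass below does:
//
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     /* handle failure */;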
1429 
1430 void XeGPUSubgroupDistributePass::runOnOperation() {
1431  // Step 1: Attach layouts to op operands.
1432  // TODO: The following assumptions are made:
1433  // 1) It is assumed that there are no layout conflicts.
1434  // 2) Any existing layout attributes attached to the operands are ignored.
1435  Operation *op = getOperation();
1436  op->walk([&](Operation *op) {
1437  for (OpOperand &operand : op->getOpOperands()) {
1438  // Layouts are needed for vector types only.
1439  if (!isa<VectorType>(operand.get().getType()))
1440  continue;
1441 
1442  auto layout = xegpu::getDistributeLayoutAttr(operand.get());
1443  if (!layout) {
1444  op->emitError("Could not find layout attribute for operand ")
1445  << operand.getOperandNumber() << " of operation " << op->getName();
1446  signalPassFailure();
1447  return;
1448  }
1449  xegpu::setDistributeLayoutAttr(operand, layout);
1450  }
1451  });
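  // Illustrative sketch (attribute names and layout values are assumptions;
  // the exact name comes from xegpu::getLayoutName): after this step a vector
  // operand carries its layout in the owning op's attribute dictionary, e.g.
  //   %r = arith.addf %a, %b
  //     {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
  //      layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
  //     : vector<16x16xf32>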
1452  // Step 2: Move all operations of a GPU function inside
1453  // gpu.warp_execute_on_lane_0 operation.
1454  {
1455  RewritePatternSet patterns(&getContext());
1456  patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
1457 
1458  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1459  signalPassFailure();
1460  return;
1461  }
1462  // At this point, we have moved the entire function body inside the
1463  // warpOp. Now move any scalar uniform code outside of the warpOp (like
1464  // GPU index ops, scalar constants, etc.). This will simplify the
1465  // later lowering and avoid custom patterns for these ops.
1466  getOperation()->walk([&](Operation *op) {
1467  if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
1468  vector::moveScalarUniformCode(warpOp);
1469  });
1470  }
1471  // Step 3: Apply subgroup to workitem distribution patterns.
1472  RewritePatternSet patterns(&getContext());
1473  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
1474  // distributionFn is used by vector distribution patterns to determine the
1475  // distributed vector type for a given vector value. In the XeGPU subgroup
1476  // distribution context, we compute this based on the lane layout.
1477  auto distributionFn = [](Value val) {
1478  VectorType vecType = dyn_cast<VectorType>(val.getType());
1479  int64_t vecRank = vecType ? vecType.getRank() : 0;
1480  if (vecRank == 0)
1481  return AffineMap::get(val.getContext());
1482  // Get the layout of the vector type.
1483  xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
1484  // If no layout is specified, assume the innermost dimension is distributed
1485  // for now.
1486  if (!layout)
1487      return AffineMap::getMultiDimMapWithTargets(
1488          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
1489  SmallVector<unsigned int> distributedDims;
1490  for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
1491  if (v > 1)
1492  distributedDims.push_back(i);
1493  }
1494  return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
1495  val.getContext());
1496  };
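  // For example, a vector value of type vector<16x16xf32> with lane_layout
  // [1, 16] has only dimension 1 distributed, so distributionFn returns the
  // map (d0, d1) -> (d1); a value with no layout falls back to distributing
  // the innermost dimension.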
1497  // TODO: shuffleFn is not used.
1498  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
1499  int64_t warpSz) { return Value(); };
1500 
1501  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
1502  vector::CombiningKind kind, uint32_t size) {
1503  // First reduce within each lane to get the per-lane reduction value.
1504  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
1505  // Parallel reduction using butterfly shuffles.
1506  for (uint64_t i = 1; i < size; i <<= 1) {
1507  Value shuffled =
1508  builder
1509  .create<gpu::ShuffleOp>(loc, laneVal, i,
1510  /*width=*/size,
1511  /*mode=*/gpu::ShuffleMode::XOR)
1512  .getShuffleResult();
1513  laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
1514  }
1515  return laneVal;
1516  };
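  // For a subgroup of size 16, the loop above performs log2(16) = 4 butterfly
  // (XOR) shuffles with offsets 1, 2, 4 and 8; after the last step every lane
  // holds the full reduction value.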
1517 
1518  if (enableSGReductions)
1519  vector::populateDistributeReduction(
1520  patterns, warpReduction,
1521  /*pattern benefit=*/regularPatternBenefit);
1522 
1523  vector::populatePropagateWarpVectorDistributionPatterns(
1524  patterns, distributionFn, shuffleFn,
1525  /*pattern benefit=*/regularPatternBenefit);
1526  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1527  signalPassFailure();
1528  return;
1529  }
1530 
1531  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
1532  // to resolve tensor descriptor type mismatches created by using the upstream
1533  // distribution patterns (scf.for). This cleanup should only be done if all
1534  // ops were distributed successfully; if some ops are still not distributed
1535  // and remain inside any WarpExecuteOnLane0Op, we skip this simplification
1536  // step to avoid breaking the IR.
1537  bool foundWarpOp = false;
1538  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
1539  // Look for WarpOps that are not trivially dead.
1540  if (isOpTriviallyDead(warpOp))
1541  return WalkResult::advance();
1542  foundWarpOp = true;
1543  return WalkResult::interrupt();
1544  });
1545  if (foundWarpOp)
1546  return;
1547 
1548  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
1549  // We are only interested in UnrealizedConversionCastOps that were added
1550  // for resolving SIMT type mismatches.
1551  if (!op->getAttr(resolveSIMTTypeMismatch))
1552  return WalkResult::skip();
1553 
1554  Value input = op.getOperand(0);
1555  Value output = op.getResult(0);
1556 
1557  // Both input and output must have tensor descriptor types.
1558  xegpu::TensorDescType inputDescType =
1559  mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
1560  xegpu::TensorDescType outputDescType =
1561  mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
1562  assert(inputDescType && outputDescType &&
1563  "Unrealized conversion cast must have tensor descriptor types");
1564 
1565  // tensor_desc<shape, layout> -> tensor_desc<shape> type of conversions.
1566  // These occur inside the scf.for body to resolve the block argument type to
1567  // the SIMT type.
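    // Illustrative sketch of such a cast (types abridged; the op also carries
    // the resolve_simt_type_mismatch attribute):
    //   %0 = builtin.unrealized_conversion_cast %iter_arg
    //          : !xegpu.tensor_desc<16x16xf32, #layout> to !xegpu.tensor_desc<16x16xf32>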
1568  if (inputDescType.getLayout()) {
1569  auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
1570  if (argument) {
1571  argument.setType(output.getType());
1572  output.replaceAllUsesWith(argument);
1573  if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
1574  argument.getOwner()->getParentOp())) {
1575  auto result = loopOp.getTiedLoopResult(argument);
1576  result.setType(output.getType());
1577  }
1578  }
1579  }
1580 
1581  // tensor_desc<shape> -> tensor_desc<shape, layout> type of conversions.
1582  // These occur at the yield op of the scf.for body to go back from the SIMT
1583  // type to the original type.
1584  if (outputDescType.getLayout())
1585  output.replaceAllUsesWith(input);
1586 
1587  if (op->use_empty())
1588  op->erase();
1589  return WalkResult::advance();
1590  });
1591 }