1 //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/TypeRange.h"
30 #include "mlir/IR/Value.h"
31 #include "mlir/IR/Visitors.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "llvm/Support/InterleavedRange.h"
41 #include "llvm/Support/raw_ostream.h"
42 
43 namespace mlir {
44 namespace xegpu {
45 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
46 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
47 } // namespace xegpu
48 } // namespace mlir
49 
50 #define DEBUG_TYPE "xegpu-subgroup-distribute"
51 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
52 
53 using namespace mlir;
54 using namespace mlir::dataflow;
55 
56 /// HW-dependent constants.
57 /// TODO: These constants should be queried from the target information.
58 constexpr unsigned subgroupSize = 16; // How many lanes are in a subgroup.
59 /// If DPAS A or B operands have low-precision element types, they must be
60 /// packed according to the following sizes.
61 constexpr unsigned packedSizeInBitsForDefault =
62  16; // Minimum packing size per register for DPAS A.
63 constexpr unsigned packedSizeInBitsForDpasB =
64  32; // Minimum packing size per register for DPAS B.
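/// For example (illustration based on the constants above): with 16-bit
/// elements (f16/bf16), the DPAS B operand is packed by a factor of
/// packedSizeInBitsForDpasB / 16 = 2 (lane_data [2, 1]), while the A operand
/// needs no extra packing since packedSizeInBitsForDefault / 16 = 1.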
65 
66 namespace {
67 
68 //===----------------------------------------------------------------------===//
69 // Layout
70 //===----------------------------------------------------------------------===//
71 
72 /// Helper class to store the ND layout of lanes within a subgroup and data
73 /// owned by each lane.
74 struct Layout {
75  SmallVector<int64_t, 3> layout;
76  Layout() = default;
77  Layout(std::initializer_list<int64_t> list) : layout(list) {}
78  void print(llvm::raw_ostream &os) const;
79  size_t size() const { return layout.size(); }
80  int64_t operator[](size_t idx) const;
81 };
82 
83 void Layout::print(llvm::raw_ostream &os) const {
84  os << llvm::interleaved_array(layout);
85 }
86 
87 int64_t Layout::operator[](size_t idx) const {
88  assert(idx < layout.size() && "Index out of bounds.");
89  return layout[idx];
90 }
91 
92 /// LaneLayout represents the logical layout of lanes within a subgroup when it
93 /// accesses some value. LaneData represents the logical layout of data owned by
94 /// each work item.
95 using LaneLayout = Layout;
96 using LaneData = Layout;
97 
98 //===----------------------------------------------------------------------===//
99 // LayoutInfo
100 //===----------------------------------------------------------------------===//
101 
102 /// Helper class for tracking the analysis state of an MLIR value. For layout
103 /// propagation, the analysis state is simply the lane_layout and lane_data of
104 /// each value. The purpose of this analysis is to propagate a unique layout
105 /// for each value in the program starting from a set of anchor operations
106 /// (like DPAS, StoreNd, etc.).
107 ///
108 /// Given this, LayoutInfo satisfies the following properties:
109 /// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
110 /// assigned`.
111 /// 2) Two LayoutInfo values are equal if they are both assigned or
112 /// both not assigned. The concrete value of the assigned state does not matter.
113 /// 3) The meet operator works as follows:
114 /// - If the current state is assigned, return the current state (a unique
115 /// layout is already assigned; don't change it).
116 /// - Otherwise, return the other state.
117 
118 struct LayoutInfo {
119 private:
120  LaneLayout laneLayout;
121  LaneData laneData;
122 
123 public:
124  LayoutInfo() = default;
125  LayoutInfo(const LaneLayout &layout, const LaneData &data)
126  : laneLayout(layout), laneData(data) {}
127 
128  // Two lattice values are equal if they have `some` layout. The actual
129  // content of the layout does not matter.
130  bool operator==(const LayoutInfo &other) const {
131  return this->isAssigned() == other.isAssigned();
132  }
133 
134  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
135 
136  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
137 
138  void print(raw_ostream &os) const;
139 
140  bool isAssigned() const {
141  return laneLayout.size() > 0 && laneData.size() > 0;
142  }
143 
144  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
145 
146  const LaneLayout &getLayout() const { return laneLayout; }
147  const LaneData &getData() const { return laneData; }
148  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
149  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
150 };
151 
152 void LayoutInfo::print(raw_ostream &os) const {
153  if (isAssigned()) {
154  os << "lane_layout: ";
155  laneLayout.print(os);
156  os << ", lane_data: ";
157  laneData.print(os);
158  } else {
159  os << "Not assigned.";
160  }
161 }
162 
163 LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
164  if (!lhs.isAssigned())
165  return rhs;
166  return lhs;
167 }
168 
169 /// Since this is a backward analysis, the join method is not used.
170 LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
171  llvm_unreachable("Join should not be triggered by layout propagation.");
172 }
173 
174 /// Get the transposed layout according to the given permutation.
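/// For example, a layout with lane_layout [1, 16] and lane_data [1, 1]
/// transposed with permutation [1, 0] becomes lane_layout [16, 1] with
/// lane_data [1, 1].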
175 LayoutInfo
176 LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
177  if (!isAssigned())
178  return {};
179  LaneLayout newLayout;
180  LaneData newData;
181  for (int64_t idx : permutation) {
182  newLayout.layout.push_back(laneLayout.layout[idx]);
183  newData.layout.push_back(laneData.layout[idx]);
184  }
185  return LayoutInfo(newLayout, newData);
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // LayoutInfoLattice
190 //===----------------------------------------------------------------------===//
191 
192 /// Lattice holding the LayoutInfo for each value.
193 struct LayoutInfoLattice : public Lattice<LayoutInfo> {
194  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
195  using Lattice::Lattice;
196 };
197 
198 /// Helper Functions to get default layouts. A `default layout` is a layout that
199 /// is assigned to a value when the layout is not fixed by some anchor operation
200 /// (like DPAS).
201 
202 /// Helper Function to get the default layout for uniform values like constants.
203 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
204 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
205 static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
206  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
207  if (rank == 1)
208  return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
209  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
210 }
211 
212 /// Helper to get the default layout for a vector type.
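/// For example, vector<8x16xf16> (16-bit elements) gets lane_layout [1, 16]
/// and lane_data [1, 1], while vector<8x32xi8> (8-bit elements) gets
/// lane_layout [1, 16] and lane_data [1, 2] (packing factor 16 / 8 = 2).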
213 static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
214  // Expecting a 1D or 2D vector.
215  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
216  "Expected 1D or 2D vector.");
217  // Expecting int or float element type.
218  assert(vectorTy.getElementType().isIntOrFloat() &&
219  "Expected int or float element type.");
220  // If the rank is 1, then return default layout for 1D vector.
221  if (vectorTy.getRank() == 1)
222  return getDefaultLayoutInfo(1);
223  // Packing factor is determined by the element type bitwidth.
224  int packingFactor = 1;
225  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
226  if (bitwidth < packedSizeInBitsForDefault)
227  packingFactor = packedSizeInBitsForDefault / bitwidth;
228  return LayoutInfo(LaneLayout({1, subgroupSize}),
229  LaneData({1, packingFactor}));
230 }
231 
232 /// Helper function to get the expected layouts for DPAS operands. `lane_data`
233 /// is set according to the following criteria:
234 /// * For the A operand, the data must be packed in at least
235 /// `packedSizeInBitsForDefault`-bit units.
236 /// * For the B operand, the data must be packed in at least
237 /// `packedSizeInBitsForDpasB`-bit units.
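/// For example, a B operand of type vector<16x16xf16> (16-bit elements) gets
/// lane_layout [1, 16] and lane_data [2, 1] (VNNI packing factor 32 / 16 = 2),
/// while a 32-bit B operand falls through to the default layout.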
238 static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
239  unsigned operandNum) {
240  Type elementTy = vectorTy.getElementType();
241  assert(elementTy.isIntOrFloat() &&
242  "Expected int or float type in DPAS operands");
243  LaneLayout layout({1, subgroupSize});
244  // For the B operand, data must be packed in minimum
245  // `packedSizeInBitsForDpasB` and must have the VNNI format.
246  if (operandNum == 1 &&
247  elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
248  LaneData data(
249  {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
250  return LayoutInfo(layout, data);
251  }
252  // Otherwise, return the default layout for the vector type.
253  return getDefaultLayoutInfo(vectorTy);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // LayoutInfoPropagation
258 //===----------------------------------------------------------------------===//
259 
260 /// Backward data flow analysis to propagate the lane_layout and lane_data of
261 /// each value in the program. Currently, the layouts for the operands of DPAS,
262 /// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
263 /// of this analysis is to propagate those known layouts to all their producers
264 /// and (other) consumers.
265 class LayoutInfoPropagation
266  : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
267 private:
268  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
269  ArrayRef<const LayoutInfoLattice *> results);
270 
271  void visitStoreNdOp(xegpu::StoreNdOp store,
272  ArrayRef<LayoutInfoLattice *> operands,
273  ArrayRef<const LayoutInfoLattice *> results);
274 
275  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
276  ArrayRef<LayoutInfoLattice *> operands,
277  ArrayRef<const LayoutInfoLattice *> results);
278 
279  void visitLoadNdOp(xegpu::LoadNdOp load,
280  ArrayRef<LayoutInfoLattice *> operands,
281  ArrayRef<const LayoutInfoLattice *> results);
282 
283  void visitLoadGatherOp(xegpu::LoadGatherOp load,
284  ArrayRef<LayoutInfoLattice *> operands,
285  ArrayRef<const LayoutInfoLattice *> results);
286 
287  void visitTransposeOp(vector::TransposeOp transpose,
288  ArrayRef<LayoutInfoLattice *> operands,
289  ArrayRef<const LayoutInfoLattice *> results);
290 
291  void visitVectorBitcastOp(vector::BitCastOp bitcast,
292  ArrayRef<LayoutInfoLattice *> operands,
293  ArrayRef<const LayoutInfoLattice *> results);
294 
295  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
296  ArrayRef<LayoutInfoLattice *> operands,
297  ArrayRef<const LayoutInfoLattice *> results);
298 
299  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
300  ArrayRef<LayoutInfoLattice *> operands,
301  ArrayRef<const LayoutInfoLattice *> results);
302 
303  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
304  ArrayRef<LayoutInfoLattice *> operands,
305  ArrayRef<const LayoutInfoLattice *> results);
306 
307  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
308  ArrayRef<LayoutInfoLattice *> operands,
309  ArrayRef<const LayoutInfoLattice *> results);
310 
311 public:
312  LayoutInfoPropagation(DataFlowSolver &solver,
313  SymbolTableCollection &symbolTable)
314  : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
315  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
316 
317  LogicalResult
318  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
319  ArrayRef<const LayoutInfoLattice *> results) override;
320 
321  void visitBranchOperand(OpOperand &operand) override {};
322 
323  void visitCallOperand(OpOperand &operand) override {};
324 
325  void visitExternalCall(CallOpInterface call,
326  ArrayRef<LayoutInfoLattice *> operands,
327  ArrayRef<const LayoutInfoLattice *> results) override {
328  };
329 
330  void setToExitState(LayoutInfoLattice *lattice) override {
331  (void)lattice->meet(LayoutInfo());
332  }
333 };
334 } // namespace
335 
336 LogicalResult LayoutInfoPropagation::visitOperation(
337  Operation *op, ArrayRef<LayoutInfoLattice *> operands,
338  ArrayRef<const LayoutInfoLattice *> results) {
339  TypeSwitch<Operation *>(op)
340  .Case<xegpu::DpasOp>(
341  [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
342  .Case<xegpu::StoreNdOp>(
343  [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
344  .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
345  visitStoreScatterOp(storeScatterOp, operands, results);
346  })
347  .Case<xegpu::LoadNdOp>(
348  [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
349  .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
350  visitLoadGatherOp(loadGatherOp, operands, results);
351  })
352  .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
353  visitCreateDescOp(createDescOp, operands, results);
354  })
355  .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
356  visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
357  })
358  .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
359  visitPrefetchNdOp(prefetchNdOp, operands, results);
360  })
361  // No need to propagate the layout to operands in CreateNdDescOp because
362  // they are scalars (offsets, sizes, etc.).
363  .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
364  .Case<vector::TransposeOp>([&](auto transposeOp) {
365  visitTransposeOp(transposeOp, operands, results);
366  })
367  .Case<vector::BitCastOp>([&](auto bitcastOp) {
368  visitVectorBitcastOp(bitcastOp, operands, results);
369  })
370  .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
371  visitVectorMultiReductionOp(reductionOp, operands, results);
372  })
373  // All other ops.
374  .Default([&](Operation *op) {
375  for (const LayoutInfoLattice *r : results) {
376  for (LayoutInfoLattice *operand : operands) {
377  // Propagate the layout of the result to the operand.
378  if (r->getValue().isAssigned())
379  meet(operand, *r);
380  }
381  }
382  });
383  // Add a dependency from each result to program point after the operation.
384  for (const LayoutInfoLattice *r : results) {
385  addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
386  }
387  return success();
388 }
389 
390 void LayoutInfoPropagation::visitPrefetchNdOp(
391  xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
392  ArrayRef<const LayoutInfoLattice *> results) {
393  // Here we assign the default layout to the tensor descriptor operand of
394  // prefetch.
395  auto tdescTy = prefetch.getTensorDescType();
396  auto prefetchLayout = getDefaultLayoutInfo(
397  VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
398  // Propagate the layout to the source tensor descriptor.
399  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
400 }
401 
402 void LayoutInfoPropagation::visitVectorMultiReductionOp(
403  vector::MultiDimReductionOp reduction,
404  ArrayRef<LayoutInfoLattice *> operands,
405  ArrayRef<const LayoutInfoLattice *> results) {
406  // The layout of the result must be present.
407  LayoutInfo resultLayout = results[0]->getValue();
408  if (!resultLayout.isAssigned())
409  return;
410  // We only consider 2D -> 1D reductions at this point.
411  assert(resultLayout.getLayout().size() == 1 &&
412  "Expected 1D layout for reduction result.");
413  // Given that the result is 1D, the layout of the operand should be 2D with
414  // default layout.
415  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
416  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
417  // Accumulator should have the same layout as the result.
418  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
419 }
420 
421 /// Propagate the layout of the result tensor to the source tensor descriptor in
422 /// UpdateNdOffsetOp.
423 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
424  xegpu::UpdateNdOffsetOp updateNdOffset,
425  ArrayRef<LayoutInfoLattice *> operands,
426  ArrayRef<const LayoutInfoLattice *> results) {
427  // The layout of the result must be present.
428  LayoutInfo resultLayout = results[0]->getValue();
429  if (!resultLayout.isAssigned())
430  return;
431  // Propagate the layout to the source operand.
432  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
433 }
434 
435 /// Set the layouts for DPAS A, B, and C operands.
436 void LayoutInfoPropagation::visitDpasOp(
437  xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
438  ArrayRef<const LayoutInfoLattice *> results) {
439  VectorType aTy = dpas.getLhsType();
440  VectorType bTy = dpas.getRhsType();
441  propagateIfChanged(operands[0],
442  operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
443  propagateIfChanged(operands[1],
444  operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
445  if (operands.size() > 2) {
446  VectorType cTy = dpas.getAccType();
447  propagateIfChanged(operands[2],
448  operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
449  }
450 }
451 
452 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
453 void LayoutInfoPropagation::visitStoreNdOp(
454  xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
455  ArrayRef<const LayoutInfoLattice *> results) {
456  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
457  // Both operands should have the same layout
458  for (LayoutInfoLattice *operand : operands) {
459  propagateIfChanged(operand, operand->meet(storeLayout));
460  }
461 }
462 
463 /// Propagate the layout of the value to the tensor descriptor operand in
464 /// LoadNdOp.
465 void LayoutInfoPropagation::visitLoadNdOp(
466  xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
467  ArrayRef<const LayoutInfoLattice *> results) {
468  LayoutInfo valueLayout = results[0]->getValue();
469  // Need the layout of the value to propagate to the tensor descriptor.
470  if (!valueLayout.isAssigned())
471  return;
472  LayoutInfo tensorDescLayout = valueLayout;
473  // LoadNdOp has the transpose effect. However, at the stage of this analysis
474  // this effect is not expected and should be abstracted away. Emit a warning.
475  if (auto transpose = load.getTranspose()) {
476  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
477  "LayoutInfoPropagation stage.");
478  tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
479  }
480  // Propagate the new layout to the tensor descriptor operand.
481  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
482 }
483 
484 /// For vector::TransposeOp, the layout of the result is transposed and
485 /// propagated to the operand.
486 void LayoutInfoPropagation::visitTransposeOp(
487  vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
488  ArrayRef<const LayoutInfoLattice *> results) {
489  // Need the layout of the transpose result to propagate to the operands.
490  LayoutInfo resultLayout = results[0]->getValue();
491  if (!resultLayout.isAssigned())
492  return;
493  LayoutInfo newLayout =
494  resultLayout.getTransposedLayout(transpose.getPermutation());
495  // Propagate the new layout to the vector operand.
496  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
497 }
498 
499 /// For vector::BitCastOp, the lane_data of the source layout is changed based
500 /// on the bit width of the source and result types.
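/// For example, if a vector<8x8xi32> result with lane_data [1, 1] comes from
/// bitcasting a vector<8x16xf16> source, the source lane_data becomes [1, 2]
/// (ratio 32 / 16 = 2); the lane_layout itself is unchanged.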
501 void LayoutInfoPropagation::visitVectorBitcastOp(
502  vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
503  ArrayRef<const LayoutInfoLattice *> results) {
504  // Need the layout of the bitcast result to propagate to the operands.
505  LayoutInfo resultLayout = results[0]->getValue();
506  if (!resultLayout.isAssigned())
507  return;
508  int inElemTyBitWidth =
509  bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
510  int outElemTyBitWidth =
511  bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
512 
513  // LaneLayout does not change.
514  const LaneLayout &newLaneLayout = resultLayout.getLayout();
515  const LaneData &currData = resultLayout.getData();
516  LaneData newLaneData;
517  // It's a widening bitcast
518  if (inElemTyBitWidth < outElemTyBitWidth) {
519  int ratio = outElemTyBitWidth / inElemTyBitWidth;
520  newLaneData = resultLayout.getData()[0] == 1
521  ? LaneData({1, currData[1] * ratio})
522  : LaneData({currData[0] * ratio, 1});
523  } else {
524  // It's a narrowing bitcast
525  int ratio = inElemTyBitWidth / outElemTyBitWidth;
526  newLaneData = resultLayout.getData()[0] == 1
527  ? LaneData({1, currData[1] / ratio})
528  : LaneData({currData[0] / ratio, 1});
529  }
530 
531  propagateIfChanged(operands[0],
532  operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
533 }
534 
535 /// Propagate the layout of the result to the tensor descriptor and mask
536 /// operands in LoadGatherOp.
537 void LayoutInfoPropagation::visitLoadGatherOp(
538  xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
539  ArrayRef<const LayoutInfoLattice *> results) {
540  LayoutInfo valueLayout = results[0]->getValue();
541  // Need the layout of the value to propagate to the tensor descriptor.
542  if (!valueLayout.isAssigned())
543  return;
544 
545  LayoutInfo tensorDescLayout = valueLayout;
546  if (load.getTranspose()) {
547  // LoadGatherOp has the transpose effect. However, at the stage of this
548  // analysis this effect is not expected and should be abstracted away. Emit
549  // a warning.
550  load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
551  "LayoutInfoPropagation stage.");
552  tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
553  }
554  // Mask operand should have 1D default layout.
555  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
556  // Propagate the new layout to the tensor descriptor operand.
557  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
558  // Propagate the new layout to the mask operand.
559  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
560 }
561 
562 /// Propagate the layout of the descriptor to the vector offset operand in
563 /// CreateDescOp.
564 void LayoutInfoPropagation::visitCreateDescOp(
565  xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
566  ArrayRef<const LayoutInfoLattice *> results) {
567  LayoutInfo descLayout = results[0]->getValue();
568  // Need the layout of the descriptor to propagate to the operands.
569  if (!descLayout.isAssigned())
570  return;
571  // For offset operand propagate 1D default layout.
572  LayoutInfo layout = getDefaultLayoutInfo(1);
573  propagateIfChanged(operands[1], operands[1]->meet(layout));
574 }
575 
576 /// Set the layout for the value, tensor descriptor, and mask operands in the
577 /// StoreScatterOp.
578 void LayoutInfoPropagation::visitStoreScatterOp(
579  xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
580  ArrayRef<const LayoutInfoLattice *> results) {
581  // Currently, for 2D StoreScatterOp we expect that the height dimension of
582  // the tensor descriptor is equal to the subgroup size. This is ensured by
583  // the op verifier.
584  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
585  if (tdescShape.size() > 1)
586  assert(
587  tdescShape[0] == subgroupSize &&
588  "Expected the first dimension of 2D tensor descriptor to be equal to "
589  "subgroup size.");
590 
591  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
592  LayoutInfo storeScatterLayout = valueLayout;
593  if (storeScatter.getTranspose()) {
594  // StoreScatterOp allows the transpose effect. However, at the stage of this
595  // analysis this effect is not expected and should be abstracted away. Emit
596  // a warning.
597  storeScatter.emitWarning("Transpose effect is not expected for "
598  "StoreScatterOp at LayoutInfoPropagation stage.");
599  storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
600  }
601  // Propagate the value layout.
602  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
603  // Propagate the tensor descriptor layout.
604  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
605  // Use default 1D layout for mask operand.
606  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
607  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
608 }
609 
610 namespace {
611 
612 //===----------------------------------------------------------------------===//
613 // RunLayoutInfoPropagation
614 //===----------------------------------------------------------------------===//
615 
616 /// Driver class for running the LayoutInfoPropagation analysis.
617 class RunLayoutInfoPropagation {
618 public:
619  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
620 
621  RunLayoutInfoPropagation(Operation *op) : target(op) {
622  SymbolTableCollection symbolTable;
623  solver.load<DeadCodeAnalysis>();
624  solver.load<SparseConstantPropagation>();
625  solver.load<LayoutInfoPropagation>(symbolTable);
626  (void)solver.initializeAndRun(op);
627  }
628 
629  LayoutInfo getLayoutInfo(Value val);
630 
631  void printAnalysisResult(llvm::raw_ostream &os);
632 
633 private:
634  DataFlowSolver solver;
635  const Operation *target;
636 };
637 } // namespace
638 
639 LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
640  auto *state = solver.lookupState<LayoutInfoLattice>(val);
641  if (!state)
642  return {};
643  return state->getValue();
644 }
645 
646 void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
647  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
648  os << "function: " << funcOp.getName() << ":\n";
649  // Function arguments
650  for (BlockArgument arg : funcOp.getArguments()) {
651  LayoutInfo layout = getLayoutInfo(arg);
652  os << "argument: " << arg << "\n";
653  os << "layout : ";
654  layout.print(os);
655  os << "\n";
656  }
657  // Function ops
658  funcOp.walk([&](Operation *op) {
659  // Skip ops that do not have results
660  if (op->getResults().empty())
661  return;
662  os << "op : ";
663  // For control-flow ops, print the op name only.
664  if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
665  os << op->getName();
666  else
667  op->print(os);
668  os << "\n";
669  // Print the layout for each result.
670  for (auto [i, r] : llvm::enumerate(op->getResults())) {
671  LayoutInfo layout = getLayoutInfo(r);
672  os << "layout for result #" << i << ": ";
673  layout.print(os);
674  os << "\n";
675  }
676  });
677  };
678 
679  SmallVector<FunctionOpInterface> funcOps;
680  if (auto modOp = dyn_cast<ModuleOp>(target)) {
681  for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
682  funcOps.push_back(funcOp);
683  }
684  // Collect all GpuFuncOps in the module.
685  for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
686  for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
687  funcOps.push_back(gpuFuncOp);
688  }
689  }
690  }
691  // Print the analysis result for each function.
692  for (FunctionOpInterface funcOp : funcOps) {
693  printFunctionResult(funcOp);
694  }
695 }
696 
697 namespace {
698 
699 //===----------------------------------------------------------------------===//
700 // LayoutAttrAssignment
701 //===----------------------------------------------------------------------===//
702 
703 /// This class is responsible for assigning the layout attributes to the ops and
704 /// their users based on the layout propagation analysis result.
705 class LayoutAttrAssignment {
706 public:
707  LayoutAttrAssignment(Operation *top,
708  function_ref<LayoutInfo(Value)> getLayout)
709  : getAnalysisResult(getLayout), top(top) {}
710 
711  LogicalResult run();
712 
713 private:
714  LogicalResult assign(Operation *op);
715  void assignToUsers(Value v, xegpu::LayoutAttr layout);
716  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
717  LogicalResult resolveConflicts();
718  // Callable to get the layout of a value based on the layout propagation
719  // analysis.
720  function_ref<LayoutInfo(Value)> getAnalysisResult;
721  Operation *top;
722 };
723 
724 } // namespace
725 
726 /// Helper to assign the layout attribute to the users of the value.
727 void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
728  for (OpOperand &user : v.getUses()) {
729  Operation *owner = user.getOwner();
730  std::string attrName = xegpu::getLayoutName(user);
731  owner->setAttr(attrName, layout);
732  }
733 }
734 
735 /// Convert the layout assigned to a value to xegpu::LayoutAttr.
736 xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
737  LayoutInfo layout = getAnalysisResult(v);
738  if (!layout.isAssigned())
739  return {};
740  SmallVector<int, 2> laneLayout, laneData;
741  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
742  layout.getDataAsArrayRef())) {
743  laneLayout.push_back(static_cast<int>(layout));
744  laneData.push_back(static_cast<int>(data));
745  }
746  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
747 }
748 
749 /// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
750 /// based on the layout propagation analysis result.
751 LogicalResult LayoutAttrAssignment::assign(Operation *op) {
752  // For function ops, propagate the function argument layout to the users.
753  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
754  for (BlockArgument arg : func.getArguments()) {
755  xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
756  if (layoutInfo) {
757  assignToUsers(arg, layoutInfo);
758  }
759  }
760  return success();
761  }
762  // If no results, move on.
763  if (op->getNumResults() == 0)
764  return success();
765  // If all the results are scalars, move on.
766  if (llvm::all_of(op->getResultTypes(),
767  [](Type t) { return t.isIntOrIndexOrFloat(); }))
768  return success();
769  // If the op has more than one result and at least one result is a tensor
770  // descriptor, exit. This case is not supported yet.
771  // TODO: Support this case.
772  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
773  return isa<xegpu::TensorDescType>(t);
774  })) {
775  LLVM_DEBUG(
776  DBGS() << op->getName()
777  << " op has more than one result and at least one is a tensor "
778  "descriptor. This case is not handled.\n");
779  return failure();
780  }
781  // If the result is a tensor descriptor, attach the layout to the tensor
782  // descriptor itself.
783  if (auto tensorDescTy =
784  dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
785  xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
786  if (!layoutInfo) {
787  LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
788  return failure();
789  }
790 
791  // Clone the op, attach the layout to the result tensor descriptor, and
792  // remove the original op.
793  OpBuilder builder(op);
794  Operation *newOp = builder.clone(*op);
795  auto newTensorDescTy = xegpu::TensorDescType::get(
796  tensorDescTy.getContext(), tensorDescTy.getShape(),
797  tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
798  newOp->getResult(0).setType(newTensorDescTy);
799  op->replaceAllUsesWith(newOp->getResults());
800  op->erase();
801  return success();
802  }
803  // Otherwise simply attach the layout to the op itself.
804  for (auto r : op->getOpResults()) {
805  xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
806  if (layoutInfo) {
807  std::string attrName = xegpu::getLayoutName(r);
808  op->setAttr(attrName, layoutInfo);
809  // Attach the layout attribute to the users of the result.
810  assignToUsers(r, layoutInfo);
811  }
812  }
813  return success();
814 }
815 
816 /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
817 LogicalResult LayoutAttrAssignment::run() {
818  auto walkResult = top->walk([&](Operation *op) {
819  if (failed(assign(op)))
820  return WalkResult::interrupt();
821  return WalkResult::advance();
822  });
823 
824  if (walkResult.wasInterrupted())
825  return failure();
826 
827  return resolveConflicts();
828 }
829 
830 /// TODO: Implement the layout conflict resolution. This must ensure mainly two
831 /// things:
832 /// 1) Is a given layout supported by the op? (need to query the target
833 /// HW info). Otherwise can we achieve this layout using a layout conversion?
834 /// 2) Do all the operands have the required layout? If not, can it
835 /// be resolved using a layout conversion?
836 LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
837 
838 namespace {
839 
840 //===----------------------------------------------------------------------===//
841 // SIMT Distribution Patterns
842 //===----------------------------------------------------------------------===//
843 
844 /// Helper function to get the distributed vector type for a source vector type
845 /// according to the lane_layout. We simply divide each dimension of the tensor
846 /// descriptor shape by the corresponding lane_layout dimension. If
847 /// array_length > 1, that is appended to the front of the distributed shape.
848 /// NOTE: This is the vector type that will be returned by the
849 /// gpu.warp_execute_on_lane0 op.
850 ///
851 /// Examples:
852 /// | original vector shape | lane_layout | distributed vector shape |
853 /// |-----------------------|-------------|--------------------------|
854 /// | 32x16 | [1, 16] | 32x1 |
855 /// | 32x16 | [2, 8] | 16x2 |
856 /// | 2x32x16 | [1, 16] | 2x32x1 |
857 static FailureOr<VectorType>
858 getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
859  VectorType originalType) {
860  if (!layout)
861  return failure();
862 
863  auto laneLayout = layout.getLaneLayout().asArrayRef();
864  assert(originalType.getShape().size() >= laneLayout.size() &&
865  "Rank of the original vector type should be greater than or equal to the "
866  "size of the lane layout to distribute the vector type.");
867  SmallVector<int64_t> distributedShape(originalType.getShape());
868  // Only distribute the last `laneLayout.size()` dimensions. The remaining
869  // dimensions are not distributed.
870  unsigned distributionStart = originalType.getRank() - laneLayout.size();
871  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
872  if (i < distributionStart) {
873  continue;
874  }
875  // Check if the dimension can be distributed evenly.
876  if (dim % laneLayout[i - distributionStart] != 0)
877  return failure();
878  distributedShape[i] = dim / laneLayout[i - distributionStart];
879  }
880  return VectorType::get(distributedShape, originalType.getElementType());
881 }
882 
883 /// Helper function to resolve types if the distributed type out of
884 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
885 /// Example 1:
886 /// distributed type: vector<8x1xf32>
887 /// expected type: vector<8xf32>
888 /// resolved using,
889 /// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
890 /// Example 2:
891 /// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
892 /// expected type: xegpu.tensor_desc<8x16xf32>
893 /// resolved using,
894 /// %0 = unrealized_conversion_cast %1 :
895 /// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
896 /// xegpu.tensor_desc<8x16xf32>
897 template <typename T>
898 static Value resolveDistributedTy(Value orig, T expected,
899  PatternRewriter &rewriter) {
900  // If orig and expected types are the same, return orig.
901  if (orig.getType() == expected)
902  return orig;
903  // If orig is a vector type, create a shape cast op to reconcile the types.
904  if (isa<VectorType>(orig.getType())) {
905  auto castOp =
906  rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
907  return castOp.getResult();
908  }
909  // If orig is a tensor descriptor type, create an unrealized conversion cast
910  // op to reconcile the types.
911  if (isa<xegpu::TensorDescType>(orig.getType())) {
912  auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
913  expected, orig);
914  return castOp.getResult(0);
915  }
916  llvm_unreachable("Unsupported type for reconciliation");
917  return orig;
918 }
919 
920 /// Helper function to filter out the temporary layout attributes attached
921 /// during the layout assignment process. These are not needed after going to
922 /// SIMT.
923 static SmallVector<NamedAttribute>
924 removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
925  SmallVector<NamedAttribute> newAttrs;
926  for (NamedAttribute attr : attrs) {
927  if (!isa<xegpu::LayoutAttr>(attr.getValue()))
928  newAttrs.push_back(attr);
929  }
930  return newAttrs;
931 }
932 
933 /// Helper function to check if the layout is packed. Layout is packed if it is
934 /// 2D and lane_data[0] != 1 (data packed from col dimension).
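/// For example, lane_data [2, 1] is packed, while [1, 2] and any 1D lane_data
/// are not.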
935 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
936  if (layout == xegpu::LayoutAttr())
937  return false;
938  DenseI32ArrayAttr laneData = layout.getLaneData();
939  if (!laneData || laneData.size() != 2)
940  return false;
941  return laneData.asArrayRef()[0] != 1;
942 }
943 
944 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
945 /// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
946 /// contained within a WarpExecuteOnLane0Op.
947 /// Example:
948 ///
949 /// ```
950 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
951 /// ...
952 /// ...
953 /// gpu.return %result: vector<8x16xf32>
954 /// }
955 /// ```
956 /// To
957 /// ```
958 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
959 /// %laneid = gpu.lane_id : index
960 /// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
961 /// ...
962 /// ...
963 /// gpu.yield %result: vector<8x16xf32>
964 /// }
965 /// return %0
966 /// }
967 struct MoveFuncBodyToWarpExecuteOnLane0
968  : public OpRewritePattern<gpu::GPUFuncOp> {
969  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
970  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
971  PatternRewriter &rewriter) const override {
972  // If the function only contains a single void return, skip.
973  if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
974  return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
975  }))
976  return failure();
977  // If the function already moved inside a warp_execute_on_lane0, skip.
978  if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
979  return isa<gpu::WarpExecuteOnLane0Op>(op);
980  }))
981  return failure();
982  // Create a new function with the same signature.
983  auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
984  gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
985  // Create a WarpExecuteOnLane0Op with same arguments and results as the
986  // original gpuFuncOp.
987  rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
988  auto laneId = rewriter.create<gpu::LaneIdOp>(
989  newGpuFunc.getLoc(), rewriter.getIndexType(),
990  /** upperBound = **/ mlir::IntegerAttr());
991  ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
992  auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
993  laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
994  newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
995  Block &warpBodyBlock = warpOp.getBodyRegion().front();
996  // Replace the ReturnOp of the original gpu function with a YieldOp.
997  auto origRetunOp =
998  cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
999  rewriter.setInsertionPointAfter(origRetunOp);
1000  rewriter.create<gpu::YieldOp>(origRetunOp.getLoc(),
1001  origRetunOp.getOperands());
1002  rewriter.eraseOp(origRetunOp);
1003  // Move the original function body to the WarpExecuteOnLane0Op body.
1004  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
1005  warpOp.getBodyRegion().begin());
1006  rewriter.eraseBlock(&warpBodyBlock);
1007  // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
1008  rewriter.setInsertionPointAfter(warpOp);
1009  rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
1010  rewriter.replaceOp(gpuFuncOp, newGpuFunc);
1011  return success();
1012  }
1013 };
1014 
1015 /// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
1016 /// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
1017 /// still contain the original op that will not be used by the yield op (and
1018 /// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
1019 /// arguments. Tensor descriptor shape is not distributed because it is a
1020 /// uniform value across all work items within the subgroup. However, the
1021 /// layout information is dropped in the new tensor descriptor type.
1022 ///
1023 /// Example:
1024 ///
1025 /// ```
1026 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1027 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1028 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
1029 /// ...
1030 /// %td = xegpu.create_nd_tdesc %arg0[0, 0]
1031 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
1032 /// vector.yield %td
1033 /// }
1034 /// ```
1035 /// To
1036 /// ```
1037 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
1038 /// ...
1039 /// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
1040 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
1041 /// vector.yield %arg0, %dead
1042 /// }
1043 /// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
1044 /// -> !xegpu.tensor_desc<4x8xf32>
1045 ///
1046 /// ```
1047 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
1048  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1049  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1050  PatternRewriter &rewriter) const override {
1051  OpOperand *operand =
1052  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
1053  if (!operand)
1054  return rewriter.notifyMatchFailure(
1055  subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
1056  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
1057  unsigned operandIdx = operand->getOperandNumber();
1058 
1059  xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
1060  if (!layout)
1061  return rewriter.notifyMatchFailure(
1062  descOp, "the tensor descriptor lacks layout attribute");
1063 
1064  SmallVector<size_t> newRetIndices;
1065  SmallVector<Value> newYieldValues;
1066  SmallVector<Type> newYieldTypes;
1067 
1068  for (Value operand : descOp->getOperands()) {
1069  newYieldValues.push_back(operand);
1070  newYieldTypes.push_back(operand.getType());
1071  }
1072  rewriter.setInsertionPoint(subgroupOp);
1073  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1074  rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
1075  /* new yielded types = */ newYieldTypes, newRetIndices);
1076 
1077  SmallVector<Value> newDescOperands;
1078  for (size_t i : newRetIndices) {
1079  newDescOperands.push_back(newWarpOp.getResult(i));
1080  }
1081  rewriter.setInsertionPointAfter(newWarpOp);
1082  xegpu::TensorDescType distributedTensorDescTy =
1083  descOp.getType().dropLayouts(); // Distributed tensor descriptor type
1084  // does not contain layout info.
1085  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
1086  newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
1087  descOp->getAttrs());
1088 
1089  Value distributedVal = newWarpOp.getResult(operandIdx);
1090  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
1091  return success();
1092  }
1093 };
1094 
1095 /// Distribute a store_nd op at the end of enclosing
1096 /// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
1097 /// through the warp op interface they would be propagated as returned values.
1098 /// The source vector is distributed based on lane layout. Appropriate cast ops
1099 /// are inserted if the distributed types do not match the expected xegpu SIMT types.
1100 ///
1101 /// Example:
1102 ///
1103 /// ```
1104 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1105 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
1106 /// ...
1107 /// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
1108 /// !xegpu.tensor_desc<4x8xf32, #layout0>
1109 /// }
1110 /// ```
1111 /// To
1112 /// ```
1113 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1114 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1115 /// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
1116 /// #layout0>
1117 /// }
1118 /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
1119 /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1120 /// #layout0>
1121 /// -> !xegpu.tensor_desc<4x8xf32>
1122 /// xegpu.store_nd %0, %1: vector<4xf32>,
1123 /// !xegpu.tensor_desc<4x8xf32>
1124 ///
1125 /// ```
1126 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
1127  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1128  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1129  PatternRewriter &rewriter) const override {
1130  auto yield = cast<gpu::YieldOp>(
1131  subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
1132  Operation *lastNode = yield->getPrevNode();
1133  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
1134  if (!storeOp)
1135  return failure();
1136 
1137  xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
1138  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
1139  if (!layout)
1140  return rewriter.notifyMatchFailure(
1141  storeOp, "the source tensor descriptor lacks layout attribute");
1142 
1143  FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
1144  getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
1145  if (failed(distributedTypeByWarpOpOrFailure))
1146  return rewriter.notifyMatchFailure(storeOp,
1147  "Failed to distribute the type");
1148  VectorType distributedTypeByWarpOp =
1149  distributedTypeByWarpOpOrFailure.value();
1150 
1151  SmallVector<size_t> newRetIndices;
1152  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1153  rewriter, subgroupOp,
1154  /* new yielded values = */
1155  ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
1156  /* new yielded types = */
1157  TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
1158  newRetIndices);
1159  // Create a new store op outside the warp op with the distributed vector
1160  // type. Tensor descriptor is not distributed.
1161  rewriter.setInsertionPointAfter(newWarpOp);
1162  SmallVector<Value> newStoreOperands;
1163 
1164  // For the value operand, there can be a mismatch between the vector type
1165  // distributed by the warp op and the (xegpu-specific) distributed type
1166  // supported by the store op. The type mismatch must be resolved using an
1167  // appropriate cast op.
1168  FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
1169  xegpu::getDistributedVectorType(storeOp.getTensorDescType());
1170  if (failed(storeNdDistributedValueTyOrFailure))
1171  return rewriter.notifyMatchFailure(
1172  storeOp, "Failed to get distributed vector type for the store op");
1173  newStoreOperands.push_back(resolveDistributedTy(
1174  newWarpOp.getResult(newRetIndices[0]),
1175  storeNdDistributedValueTyOrFailure.value(), rewriter));
1176  // For the tensor descriptor operand, the layout attribute is dropped after
1177  // distribution. Types need to be resolved in this case as well.
1178  xegpu::TensorDescType distributedTensorDescTy =
1179  storeOp.getTensorDescType().dropLayouts();
1180  newStoreOperands.push_back(
1181  resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
1182  distributedTensorDescTy, rewriter));
1183 
1184  rewriter.create<xegpu::StoreNdOp>(
1185  newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
1186  removeTemporaryLayoutAttributes(storeOp->getAttrs()));
1187  rewriter.eraseOp(storeOp);
1188  return success();
1189  }
1190 };
1191 
1192 /// Distribute a load_nd op feeding into vector.yield op for the enclosing
1193 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
1194 /// The warp op will still contain the original op that will not be used by
1195 /// the yield op (and should be cleaned up later). The yield op will
1196 /// bypass the load's arguments. Only the loaded vector is distributed
1197 /// according to the lane layout; the tensor descriptor type is not
1198 /// distributed. Appropriate cast ops are inserted if the distributed types do
1199 /// not match the expected xegpu SIMT types.
1200 ///
1201 /// Example:
1202 ///
1203 /// ```
1204 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1205 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1206 /// (vector<4x1xf32>) {
1207 /// ...
1208 /// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
1209 /// ->
1210 /// vector<4x8xf32>
1211 /// gpu.yield %ld
1212 /// }
1213 /// ```
1214 /// To
1215 /// ```
1216 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1217 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1218 /// ...
1219 /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
1220 /// vector<4x8xf32> gpu.yield %dead, %arg0
1221 /// }
1222 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1223 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1224 /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
1225 /// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
1226 ///
1227 /// ```
1228 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
1229  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1230  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1231  PatternRewriter &rewriter) const override {
1232  OpOperand *operand =
1233  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
1234  if (!operand)
1235  return rewriter.notifyMatchFailure(
1236  subgroupOp, "warp result is not a xegpu::LoadNd op");
1237 
1238  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
1239  xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
1240  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
1241  if (!layout)
1242  return rewriter.notifyMatchFailure(
1243  loadOp, "the source tensor descriptor lacks layout attribute");
1244 
1245  unsigned operandIdx = operand->getOperandNumber();
1246  VectorType distributedTypeByWarpOp =
1247  cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
1248 
1249  SmallVector<size_t> newRetIndices;
1250  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1251  rewriter, subgroupOp,
1252  /* new yielded values = */ loadOp.getTensorDesc(),
1253  /* new yielded types = */ tensorDescTy, newRetIndices);
1254 
1255  // Create a new load op outside the warp op with the distributed vector
1256  // type.
1257  rewriter.setInsertionPointAfter(newWarpOp);
1258  FailureOr<VectorType> loadNdDistValueTyOrFailure =
1259  xegpu::getDistributedVectorType(loadOp.getTensorDescType());
1260  if (failed(loadNdDistValueTyOrFailure))
1261  return rewriter.notifyMatchFailure(
1262  loadOp, "Failed to get distributed vector type for the load op");
1263  xegpu::TensorDescType distributedTensorDescTy =
1264  loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
1265  // descriptor type does not
1266  // contain layout info.
1267  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
1268  newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
1269  resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
1270  distributedTensorDescTy, rewriter),
1271  removeTemporaryLayoutAttributes(loadOp->getAttrs()));
1272  // Set the packed attribute if the layout requires it.
1273  newLoadOp.setPacked(hasPackedLayout(layout));
1274  Value distributedVal = newWarpOp.getResult(operandIdx);
1275  // There can be a conflict between the vector type distributed by the
1276  // warp op and (xegpu-specific) distributed type supported by the load
1277  // op. Resolve these mismatches by inserting a cast.
1278  Value tyResolvedVal = resolveDistributedTy(
1279  newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
1280  rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
1281  return success();
1282  }
1283 };
1284 
1285 /// Distribute a dpas op feeding into vector.yield op for the enclosing
1286 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
1287 /// The warp op will still contain the original op that will not be used by
1288 /// the yield op (and should be cleaned up later). The yield op will
1289 /// bypass the dpas's arguments. Appropriate cast ops are inserted if the
1290 /// distributed types do not match the expected xegpu SIMT types.
1291 /// Example:
1292 /// ```
1293 /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
1294 /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
1295 /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
1296 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1297 /// (vector<8x1xf32>) {
1298 /// ...
1299 /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
1300 /// vector<8x16xf32>
1301 /// gpu.yield %dpas
1302 /// }
1303 /// ```
1304 /// To
1305 /// ```
1306 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
1307 /// vector<8x1xf16>, vector<16x1xf16>) {
1308 /// ...
1309 /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
1310 /// -> vector<8x16xf32>
1311 /// gpu.yield %dead, %arg0, %arg1
1312 /// }
1313 /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
1314 /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
1315 /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
1316 /// vector<8xf32>
1317 /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
1318 /// ```
1319 struct DpasDistribution final : public gpu::WarpDistributionPattern {
1320  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1321  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1322  PatternRewriter &rewriter) const override {
1323  OpOperand *operand =
1324  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
1325  if (!operand)
1326  return rewriter.notifyMatchFailure(subgroupOp,
1327  "warp result is not a xegpu::Dpas op");
1328 
1329  auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
1330  unsigned operandIdx = operand->getOperandNumber();
1331  std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
1332  std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
1333  std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
1334 
1335  xegpu::LayoutAttr layoutA =
1336  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
1337  xegpu::LayoutAttr layoutB =
1338  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
1339  xegpu::LayoutAttr layoutOut =
1340  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
1341  if (!layoutA || !layoutB || !layoutOut)
1342  return rewriter.notifyMatchFailure(
1343  dpasOp,
1344  "the xegpu::Dpas op lacks layout attribute for A, B or output");
1345 
1346  FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
1347  getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
1348  FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
1349  getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
1350  FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
1351  getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
1352  if (failed(distLhsTypeByWarpOpOrFailure) ||
1353  failed(distRhsTypeByWarpOpOrFailure) ||
1354  failed(distResultTypeByWarpOpOrFailure))
1355  return rewriter.notifyMatchFailure(
1356  dpasOp,
1357  "Failed to distribute the A, B or output types in xegpu::Dpas op");
1358 
1359  llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
1360  dpasOp.getRhs()};
1361  llvm::SmallVector<Type, 3> newYieldTypes{
1362  distLhsTypeByWarpOpOrFailure.value(),
1363  distRhsTypeByWarpOpOrFailure.value()};
1364  // Dpas acc operand is optional.
1365  if (dpasOp.getAcc()) {
1366  newYieldValues.push_back(dpasOp.getAcc());
1367  newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
1368  }
1369  // Create a new warp op without the dpas.
1370  SmallVector<size_t> newRetIndices;
1371  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1372  rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1373 
1374  FailureOr<VectorType> expectedDistLhsTyOrFailure =
1375  xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
1376  FailureOr<VectorType> expectedDistRhsTyOrFailure =
1377  xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
1378  FailureOr<VectorType> expectedDistResultTyOrFailure =
1379  xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
1380  if (failed(expectedDistLhsTyOrFailure) ||
1381  failed(expectedDistRhsTyOrFailure) ||
1382  failed(expectedDistResultTyOrFailure))
1383  return rewriter.notifyMatchFailure(
1384  dpasOp,
1385  "Failed to get distributed vector type for the dpas operands.");
1386  // Create a new dpas op outside the warp op.
1387  rewriter.setInsertionPointAfter(newWarpOp);
1388  SmallVector<Value> newDpasOperands;
1389  SmallVector<VectorType> newDpasOperandExpectedTypes;
1390 
1391  // Resolve the distributed types with the original types.
1392  newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
1393  newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
1394  VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
1395  if (dpasOp.getAcc())
1396  newDpasOperandExpectedTypes.push_back(distributedResultTy);
1397 
1398  for (unsigned i = 0; i < newRetIndices.size(); i++) {
1399  newDpasOperands.push_back(
1400  resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
1401  newDpasOperandExpectedTypes[i], rewriter));
1402  }
1403  Value newDpasOp = rewriter.create<xegpu::DpasOp>(
1404  newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
1405  removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
1406  Value distributedVal = newWarpOp.getResult(operandIdx);
1407  // Resolve the output type.
1408  newDpasOp = resolveDistributedTy(
1409  newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
1410  rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
1411  return success();
1412  }
1413 };
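
// Editor's note (illustrative sketch, not part of the original pass): the
// resolveDistributedTy calls above reconcile the types yielded by the new warp
// op with the SIMT-level types the distributed xegpu op expects. For
// tensor-descriptor values this reconciliation essentially amounts to bridging
// any mismatch with an unrealized conversion cast; the helper name below is
// hypothetical and relies only on headers this file already includes.
static Value castToExpectedTypeSketch(Value val, Type expected,
                                      PatternRewriter &rewriter) {
  // No cast is needed when the types already agree.
  if (val.getType() == expected)
    return val;
  // Otherwise bridge the mismatch; a later cleanup is expected to fold or
  // erase the cast once both sides agree on the SIMT-level type.
  auto castOp = rewriter.create<UnrealizedConversionCastOp>(val.getLoc(),
                                                            expected, val);
  return castOp->getResult(0);
}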
1414 
1415 /// Sink an update_nd_offset op feeding into the yield op of an enclosing
1416 /// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
1417 /// original op, which is no longer used by the yield op (and should be cleaned
1418 /// up later). The yield op will bypass the updateOp's arguments. The tensor
1419 /// descriptor type is not distributed. Appropriate cast ops are inserted if
1420 /// the distributed types do not match the expected XeGPU SIMT types.
1421 /// Example:
1422 /// ```
1423 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1424 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1425 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
1426 /// ...
1427 /// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1428 /// !xegpu.tensor_desc<4x8xf32, #layout0>
1429 /// gpu.yield %update
1430 /// }
1431 /// ...
1432 /// ```
1433 /// To
1434 /// ```
1435 /// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
1436 /// !xegpu.tensor_desc<4x8xf32, #layout0>,
1437 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
1438 /// ...
1439 /// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1440 /// !xegpu.tensor_desc<4x8xf32, #layout0>
1441 /// gpu.yield %dead, %arg0, %c32, %c16
1442 /// }
1443 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1444 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1445 /// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
1446 /// !xegpu.tensor_desc<4x8xf32>
1447 /// ...
1448 /// ```
1449 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
1450  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1451  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1452  PatternRewriter &rewriter) const override {
1453  OpOperand *operand =
1454  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
1455  if (!operand)
1456  return rewriter.notifyMatchFailure(
1457  subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
1458  auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
1459  unsigned operandIdx = operand->getOperandNumber();
1460  // The new update op does not carry a layout attribute.
1461  xegpu::TensorDescType newTensorDescTy =
1462  updateOp.getTensorDescType().dropLayouts();
1463 
1464  SmallVector<Value, 3> newYieldValues;
1465  SmallVector<Type, 3> newYieldTypes;
1466  for (Value operand : updateOp->getOperands()) {
1467  newYieldValues.push_back(operand);
1468  if (isa<xegpu::TensorDescType>(operand.getType())) {
1469  newYieldTypes.push_back(newTensorDescTy);
1470  } else {
1471  newYieldTypes.push_back(operand.getType());
1472  }
1473  }
1474  SmallVector<size_t> newRetIndices;
1475  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1476  rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1477  rewriter.setInsertionPointAfter(newWarpOp);
1478  SmallVector<Value> newUpdateOperands;
1479  for (size_t i : newRetIndices) {
1480  // For the tensor descriptor operand, the layout attribute is dropped
1481  // after distribution. Types need to be resolved in this case.
1482  if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
1483  newUpdateOperands.push_back(resolveDistributedTy(
1484  newWarpOp.getResult(i), newTensorDescTy, rewriter));
1485  } else {
1486  newUpdateOperands.push_back(newWarpOp.getResult(i));
1487  }
1488  }
1489  // Create a new update op outside the warp op.
1490  auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
1491  newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
1492  removeTemporaryLayoutAttributes(updateOp->getAttrs()));
1493  Value distributedVal = newWarpOp.getResult(operandIdx);
1494  rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
1495  return success();
1496  }
1497 };
1498 
1499 /// Distribute a prefetch_nd op at the end of the enclosing
1500 /// `gpu.warp_execute_on_lane_0`. If the prefetch's arguments are passed
1501 /// through the warp op interface, they are propagated as returned values.
1502 /// The tensor descriptor shape is not distributed because it is uniform
1503 /// across all work items within the subgroup. Appropriate cast ops are
1504 /// inserted if the distributed types do not match the expected XeGPU SIMT types.
1505 ///
1506 /// Example:
1507 ///
1508 /// ```
1509 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1510 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
1511 /// ...
1512 /// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
1513 /// }
1514 /// ```
1515 /// To
1516 /// ```
1517 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
1518 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1519 /// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
1520 /// }
1521 /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
1522 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1523 /// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
1524 ///
1525 /// ```
1526 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
1527  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1528  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1529  PatternRewriter &rewriter) const override {
1530  auto yield = cast<gpu::YieldOp>(
1531  subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
1532  Operation *lastNode = yield->getPrevNode();
1533  auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
1534  if (!prefetchOp)
1535  return failure();
1536  xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
1537  if (!layout)
1538  return rewriter.notifyMatchFailure(
1539  prefetchOp, "the source tensor descriptor lacks layout attribute");
1540 
1541  SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
1542  SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
1543  SmallVector<size_t> newRetIndices;
1544  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1545  rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1546  // Create a new prefetch op outside the warp op with updated tensor
1547  // descriptor type. The source tensor descriptor requires type resolution.
1548  xegpu::TensorDescType newTensorDescTy =
1549  prefetchOp.getTensorDescType().dropLayouts();
1550  rewriter.setInsertionPointAfter(newWarpOp);
1551  SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
1552  newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
1553  rewriter.create<xegpu::PrefetchNdOp>(
1554  newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
1555  removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
1556  rewriter.eraseOp(prefetchOp);
1557  return success();
1558  }
1559 };
1560 
1561 } // namespace
1562 
1563 namespace {
1564 struct XeGPUSubgroupDistributePass final
1565  : public xegpu::impl::XeGPUSubgroupDistributeBase<
1566  XeGPUSubgroupDistributePass> {
1567  XeGPUSubgroupDistributePass() = default;
1568  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
1569  default;
1570  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
1571  : XeGPUSubgroupDistributeBase(options) {}
1572  void runOnOperation() override;
1573 };
1574 } // namespace
1575 
1576 void xegpu::populateXeGPUSubgroupDistributePatterns(
1577  RewritePatternSet &patterns) {
1578  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
1579  LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
1580  UpdateNdOffsetDistribution>(patterns.getContext());
1581 }
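
// Editor's note (a minimal sketch, not part of this file): a client that only
// wants the SIMT distribution patterns could combine the populate helper above
// with the greedy driver already used by this pass. The wrapper name
// applyXeGPUSIMTDistribution is hypothetical; note that the full pass also
// assigns layouts and moves the function body into gpu.warp_execute_on_lane_0
// first (see runOnOperation below), which these patterns rely on.
static LogicalResult applyXeGPUSIMTDistribution(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}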
1582 
1583 void XeGPUSubgroupDistributePass::runOnOperation() {
1584  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
1585  // Print the analysis result and exit (for testing purposes).
1586  if (printOnly) {
1587  auto &os = llvm::outs();
1588  analysis.printAnalysisResult(os);
1589  return;
1590  }
1591  auto getPropagatedLayout = [&](Value val) {
1592  return analysis.getLayoutInfo(val);
1593  };
1594 
1595  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
1596  // propagation analysis result.
1597  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
1598  if (failed(layoutAssignment.run())) {
1599  signalPassFailure();
1600  return;
1601  }
1602 
1603  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
1604  // operation.
1605  {
1606  RewritePatternSet patterns(&getContext());
1607  patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
1608 
1609  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1610  signalPassFailure();
1611  return;
1612  }
1613  // At this point, we have moved the entire function body inside the warpOp.
1614  // Now move any scalar uniform code outside of the warpOp (like GPU index
1615  // ops, scalar constants, etc.). This will simplify the later lowering and
1616  // avoid custom patterns for these ops.
1617  getOperation()->walk([&](Operation *op) {
1618  if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
1619  vector::moveScalarUniformCode(warpOp);
1620  }
1621  });
1622  }
1623  // Finally, do the SIMD to SIMT distribution.
1624  RewritePatternSet patterns(&getContext());
1625  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
1626  // TODO: distributionFn and shuffleFn are not used at this point.
1627  auto distributionFn = [](Value val) {
1628  VectorType vecType = dyn_cast<VectorType>(val.getType());
1629  int64_t vecRank = vecType ? vecType.getRank() : 0;
1630  OpBuilder builder(val.getContext());
1631  if (vecRank == 0)
1632  return AffineMap::get(val.getContext());
1633  return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
1634  };
1635  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
1636  int64_t warpSz) { return Value(); };
1637  vector::populatePropagateWarpVectorDistributionPatterns(
1638  patterns, distributionFn, shuffleFn);
1639  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1640  signalPassFailure();
1641  return;
1642  }
1643 }
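
// Editor's usage note (an assumption based on the DEBUG_TYPE above and the
// usual pass-mnemonic conventions, not stated in this file): the pass is
// normally exercised through mlir-opt, e.g. `mlir-opt -xegpu-subgroup-distribute`,
// and the printOnly branch in runOnOperation() corresponds to an option that
// only dumps the layout-propagation analysis instead of rewriting the IR.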