XeGPUSubgroupDistribute.cpp
1 //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/TypeRange.h"
30 #include "mlir/IR/Value.h"
31 #include "mlir/IR/Visitors.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "llvm/Support/InterleavedRange.h"
41 #include "llvm/Support/raw_ostream.h"
42 
43 namespace mlir {
44 namespace xegpu {
45 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
46 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
47 } // namespace xegpu
48 } // namespace mlir
49 
50 #define DEBUG_TYPE "xegpu-subgroup-distribute"
51 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
52 
53 using namespace mlir;
54 using namespace mlir::dataflow;
55 
56 /// HW dependent constants.
57 /// TODO: These constants should be queried from the target information.
58 constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
59 /// If DPAS A or B operands have low-precision element types, they must be
60 /// packed according to the following sizes.
61 constexpr unsigned packedSizeInBitsForDefault =
62  16; // Minimum packing size per register for DPAS A.
63 constexpr unsigned packedSizeInBitsForDpasB =
64  32; // Minimum packing size per register for DPAS B.
65 static const char *const operandLayoutNamePrefix = "layout_operand_";
66 static const char *const resultLayoutNamePrefix = "layout_result_";
67 
68 namespace {
69 
70 //===----------------------------------------------------------------------===//
71 // Layout
72 //===----------------------------------------------------------------------===//
73 
74 /// Helper class to store the ND layout of lanes within a subgroup and data
75 /// owned by each lane.
76 struct Layout {
77  SmallVector<int64_t, 3> layout;
78  Layout() = default;
79  Layout(std::initializer_list<int64_t> list) : layout(list) {}
80  void print(llvm::raw_ostream &os) const;
81  size_t size() const { return layout.size(); }
82  int64_t operator[](size_t idx) const;
83 };
84 
85 void Layout::print(llvm::raw_ostream &os) const {
86  os << llvm::interleaved_array(layout);
87 }
88 
89 int64_t Layout::operator[](size_t idx) const {
90  assert(idx < layout.size() && "Index out of bounds.");
91  return layout[idx];
92 }
93 
94 /// LaneLayout represents the logical layout of lanes within a subgroup when it
95 /// accesses some value. LaneData represents the logical layout of data owned by
96 /// each work item.
97 using LaneLayout = Layout;
98 using LaneData = Layout;
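To make the encoding concrete, here is a minimal sketch (the helper name `printExampleLaneLayout` is illustrative only): one row of 16 lanes where each lane owns a 1x2 chunk of elements, as is typical for packed 16-bit data.
```
static void printExampleLaneLayout(llvm::raw_ostream &os) {
  LaneLayout laneLayout({1, 16}); // A 1 x 16 grid of lanes.
  LaneData laneData({1, 2});      // Each lane owns a 1 x 2 chunk of elements.
  laneLayout.print(os);           // Prints: [1, 16]
  os << " / ";
  laneData.print(os);             // Prints: [1, 2]
  os << "\n";
}
```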
99 
100 //===----------------------------------------------------------------------===//
101 // LayoutInfo
102 //===----------------------------------------------------------------------===//
103 
104 /// Helper class for tracking the analysis state of an MLIR value. For layout
105 /// propagation, the analysis state is simply the lane_layout and lane_data of
106 /// each value. The purpose of this analysis is to propagate a unique layout
107 /// for each value in the program, starting from a set of anchor operations
108 /// (like DPAS, StoreNd, etc.).
109 ///
110 /// Given this, LayoutInfo satisfies the following properties:
111 /// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
112 /// assigned`.
113 /// 2) Two LayoutInfo values are equal if they are both assigned or
114 /// both not assigned. The concrete value of assigned state does not matter.
115 /// 3) The meet operator works as follows:
116 /// - If current state is assigned, return the current state. (already
117 /// a unique layout is assigned. don't change it)
118 /// - Otherwise, return the other state.
119 
120 struct LayoutInfo {
121 private:
122  LaneLayout laneLayout;
123  LaneData laneData;
124 
125 public:
126  LayoutInfo() = default;
127  LayoutInfo(const LaneLayout &layout, const LaneData &data)
128  : laneLayout(layout), laneData(data) {}
129 
130  // Two lattice values are equal if they have `some` layout. The actual
131  // content of the layout does not matter.
132  bool operator==(const LayoutInfo &other) const {
133  return this->isAssigned() == other.isAssigned();
134  }
135 
136  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
137 
138  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
139 
140  void print(raw_ostream &os) const;
141 
142  bool isAssigned() const {
143  return laneLayout.size() > 0 && laneData.size() > 0;
144  }
145 
146  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
147 
148  const LaneLayout &getLayout() const { return laneLayout; }
149  const LaneData &getData() const { return laneData; }
150  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
151  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
152 };
153 
154 void LayoutInfo::print(raw_ostream &os) const {
155  if (isAssigned()) {
156  os << "lane_layout: ";
157  laneLayout.print(os);
158  os << ", lane_data: ";
159  laneData.print(os);
160  } else {
161  os << "Not assigned.";
162  }
163 }
164 
165 LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
166  if (!lhs.isAssigned())
167  return rhs;
168  return lhs;
169 }
170 
171 /// Since this is a backward analysis, the join method is not used.
172 LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
173  llvm_unreachable("Join should not be triggered by layout propagation.");
174 }
175 
176 /// Get the transposed layout according to the given permutation.
177 LayoutInfo
178 LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
179  if (!isAssigned())
180  return {};
181  LaneLayout newLayout;
182  LaneData newData;
183  for (int64_t idx : permutation) {
184  newLayout.layout.push_back(laneLayout.layout[idx]);
185  newData.layout.push_back(laneData.layout[idx]);
186  }
187  return LayoutInfo(newLayout, newData);
188 }
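A small sketch of the permutation logic above (the helper name is illustrative): applying the permutation {1, 0} swaps both the lane layout and the lane data.
```
static LayoutInfo exampleTransposedLayout() {
  LayoutInfo info(LaneLayout({1, 16}), LaneData({1, 2}));
  // Result: lane_layout = [16, 1], lane_data = [2, 1].
  return info.getTransposedLayout({1, 0});
}
```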
189 
190 //===----------------------------------------------------------------------===//
191 // LayoutInfoLattice
192 //===----------------------------------------------------------------------===//
193 
194 /// Lattice holding the LayoutInfo for each value.
195 struct LayoutInfoLattice : public Lattice<LayoutInfo> {
197  using Lattice::Lattice;
198 };
199 
200 /// Helper Functions to get default layouts. A `default layout` is a layout that
201 /// is assigned to a value when the layout is not fixed by some anchor operation
202 /// (like DPAS).
203 
204 /// Helper Function to get the default layout for uniform values like constants.
205 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
206 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
207 static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
208  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
209  if (rank == 1)
210  return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
211  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
212 }
213 
214 /// Helper to get the default layout for a vector type.
215 static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
216  // Expecting a 1D or 2D vector.
217  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
218  "Expected 1D or 2D vector.");
219  // Expecting int or float element type.
220  assert(vectorTy.getElementType().isIntOrFloat() &&
221  "Expected int or float element type.");
222  // If the rank is 1, then return default layout for 1D vector.
223  if (vectorTy.getRank() == 1)
224  return getDefaultLayoutInfo(1);
225  // Packing factor is determined by the element type bitwidth.
226  int packingFactor = 1;
227  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
228  if (bitwidth < packedSizeInBitsForDefault)
229  packingFactor = packedSizeInBitsForDefault / bitwidth;
230  return LayoutInfo(LaneLayout({1, subgroupSize}),
231  LaneData({1, packingFactor}));
232 }
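As a worked example (the helper name is illustrative): an 8-bit element type gets a packing factor of 16 / 8 = 2, so each lane owns [1, 2] elements, whereas 16-bit and wider element types keep lane_data [1, 1].
```
static LayoutInfo exampleDefaultLayoutForI8(MLIRContext *ctx) {
  auto vecTy = VectorType::get({8, 16}, IntegerType::get(ctx, 8));
  // Result: lane_layout = [1, 16], lane_data = [1, 2].
  return getDefaultLayoutInfo(vecTy);
}
```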
233 
234 /// Helper function to get the expected layouts for DPAS operands. `lane_data`
235 /// is set according to the following criteria:
236 /// * For the A operand, the data must be packed in at least
237 ///   `packedSizeInBitsForDefault` bits.
238 /// * For the B operand, the data must be packed in at least
239 ///   `packedSizeInBitsForDpasB` bits.
240 static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
241  unsigned operandNum) {
242  Type elementTy = vectorTy.getElementType();
243  assert(elementTy.isIntOrFloat() &&
244  "Expected int or float type in DPAS operands");
245  LaneLayout layout({1, subgroupSize});
246   // For B operand, data must be packed in minimum `packedSizeInBitsForDpasB`
247   // and must have the VNNI format.
248   if (operandNum == 1 &&
249       elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
250     LaneData data(
251         {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
252  return LayoutInfo(layout, data);
253  }
254  // Otherwise, return the default layout for the vector type.
255  return getDefaultLayoutInfo(vectorTy);
256 }
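For instance (illustrative sketch), a vector<16x16xf16> B operand must be packed to 32 bits, so lane_data becomes [32/16, 1] = [2, 1] (the VNNI layout), while an f16 A operand keeps the default [1, 1]:
```
static LayoutInfo exampleDpasBLayoutForF16(MLIRContext *ctx) {
  auto bTy = VectorType::get({16, 16}, Float16Type::get(ctx));
  // Result: lane_layout = [1, 16], lane_data = [2, 1].
  return getLayoutInfoForDPASOperand(bTy, /*operandNum=*/1);
}
```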
257 
258 //===----------------------------------------------------------------------===//
259 // LayoutInfoPropagation
260 //===----------------------------------------------------------------------===//
261 
262 /// Backward data flow analysis to propagate the lane_layout and lane_data of
263 /// each value in the program. Currently, the layouts for the operands of DPAS,
264 /// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
265 /// of this analysis is to propagate those known layouts to all their producers
266 /// and (other) consumers.
267 class LayoutInfoPropagation
268  : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
269 private:
270  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
271                   ArrayRef<const LayoutInfoLattice *> results);
272 
273  void visitStoreNdOp(xegpu::StoreNdOp store,
274                      ArrayRef<LayoutInfoLattice *> operands,
275                      ArrayRef<const LayoutInfoLattice *> results);
276 
277  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
278                           ArrayRef<LayoutInfoLattice *> operands,
279                           ArrayRef<const LayoutInfoLattice *> results);
280 
281  void visitLoadNdOp(xegpu::LoadNdOp load,
282                     ArrayRef<LayoutInfoLattice *> operands,
283                     ArrayRef<const LayoutInfoLattice *> results);
284 
285  void visitLoadGatherOp(xegpu::LoadGatherOp load,
286                         ArrayRef<LayoutInfoLattice *> operands,
287                         ArrayRef<const LayoutInfoLattice *> results);
288 
289  void visitTransposeOp(vector::TransposeOp transpose,
290                        ArrayRef<LayoutInfoLattice *> operands,
291                        ArrayRef<const LayoutInfoLattice *> results);
292 
293  void visitVectorBitcastOp(vector::BitCastOp bitcast,
294                            ArrayRef<LayoutInfoLattice *> operands,
295                            ArrayRef<const LayoutInfoLattice *> results);
296 
297  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
298                         ArrayRef<LayoutInfoLattice *> operands,
299                         ArrayRef<const LayoutInfoLattice *> results);
300 
301  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
302                             ArrayRef<LayoutInfoLattice *> operands,
303                             ArrayRef<const LayoutInfoLattice *> results);
304 
305  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
306                         ArrayRef<LayoutInfoLattice *> operands,
307                         ArrayRef<const LayoutInfoLattice *> results);
308 
309  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
310                                   ArrayRef<LayoutInfoLattice *> operands,
311                                   ArrayRef<const LayoutInfoLattice *> results);
312 
313 public:
314  LayoutInfoPropagation(DataFlowSolver &solver,
315  SymbolTableCollection &symbolTable)
316  : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
318 
319  LogicalResult
320  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
321  ArrayRef<const LayoutInfoLattice *> results) override;
322 
323  void visitBranchOperand(OpOperand &operand) override {};
324 
325  void visitCallOperand(OpOperand &operand) override {};
326 
327  void visitExternalCall(CallOpInterface call,
328                         ArrayRef<LayoutInfoLattice *> operands,
329  ArrayRef<const LayoutInfoLattice *> results) override {
330  };
331 
332  void setToExitState(LayoutInfoLattice *lattice) override {
333  (void)lattice->meet(LayoutInfo());
334  }
335 };
336 } // namespace
337 
338 LogicalResult LayoutInfoPropagation::visitOperation(
339     Operation *op, ArrayRef<LayoutInfoLattice *> operands,
340     ArrayRef<const LayoutInfoLattice *> results) {
341   TypeSwitch<Operation *>(op)
342  .Case<xegpu::DpasOp>(
343  [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
344  .Case<xegpu::StoreNdOp>(
345  [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
346  .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
347  visitStoreScatterOp(storeScatterOp, operands, results);
348  })
349  .Case<xegpu::LoadNdOp>(
350  [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
351  .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
352  visitLoadGatherOp(loadGatherOp, operands, results);
353  })
354  .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
355  visitCreateDescOp(createDescOp, operands, results);
356  })
357  .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
358  visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
359  })
360  .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
361  visitPrefetchNdOp(prefetchNdOp, operands, results);
362  })
363  // No need to propagate the layout to operands in CreateNdDescOp because
364  // they are scalars (offsets, sizes, etc.).
365  .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
366  .Case<vector::TransposeOp>([&](auto transposeOp) {
367  visitTransposeOp(transposeOp, operands, results);
368  })
369  .Case<vector::BitCastOp>([&](auto bitcastOp) {
370  visitVectorBitcastOp(bitcastOp, operands, results);
371  })
372  .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
373  visitVectorMultiReductionOp(reductionOp, operands, results);
374  })
375  // All other ops.
376  .Default([&](Operation *op) {
377  for (const LayoutInfoLattice *r : results) {
378  for (LayoutInfoLattice *operand : operands) {
379  // Propagate the layout of the result to the operand.
380  if (r->getValue().isAssigned())
381  meet(operand, *r);
382  }
383  }
384  });
385  // Add a dependency from each result to program point after the operation.
386  for (const LayoutInfoLattice *r : results) {
387  addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
388  }
389  return success();
390 }
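For example, for an elementwise op such as arith.addf, the default case above simply propagates the layout assigned to the result, unchanged, to each of its operands.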
391 
392 void LayoutInfoPropagation::visitPrefetchNdOp(
393  xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
394     ArrayRef<const LayoutInfoLattice *> results) {
395  // Here we assign the default layout to the tensor descriptor operand of
396  // prefetch.
397  auto tdescTy = prefetch.getTensorDescType();
398  auto prefetchLayout = getDefaultLayoutInfo(
399  VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
400  // Propagate the layout to the source tensor descriptor.
401  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
402 }
403 
404 void LayoutInfoPropagation::visitVectorMultiReductionOp(
405  vector::MultiDimReductionOp reduction,
406     ArrayRef<LayoutInfoLattice *> operands,
407     ArrayRef<const LayoutInfoLattice *> results) {
408  // The layout of the result must be present.
409  LayoutInfo resultLayout = results[0]->getValue();
410  if (!resultLayout.isAssigned())
411  return;
412  // We only consider 2D -> 1D reductions at this point.
413  assert(resultLayout.getLayout().size() == 1 &&
414  "Expected 1D layout for reduction result.");
415  // Given that the result is 1D, the layout of the operand should be 2D with
416  // default layout.
417  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
418  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
419  // Accumulator should have the same layout as the result.
420  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
421 }
422 
423 /// Propagate the layout of the result tensor to the source tensor descriptor in
424 /// UpdateNdOffsetOp.
425 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
426  xegpu::UpdateNdOffsetOp updateNdOffset,
427     ArrayRef<LayoutInfoLattice *> operands,
428     ArrayRef<const LayoutInfoLattice *> results) {
429  // The layout of the result must be present.
430  LayoutInfo resultLayout = results[0]->getValue();
431  if (!resultLayout.isAssigned())
432  return;
433  // Propagate the layout to the source operand.
434  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
435 }
436 
437 /// Set the layouts for DPAS A, B, and C operands.
438 void LayoutInfoPropagation::visitDpasOp(
439  xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
440     ArrayRef<const LayoutInfoLattice *> results) {
441  VectorType aTy = dpas.getLhsType();
442  VectorType bTy = dpas.getRhsType();
443  propagateIfChanged(operands[0],
444  operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
445  propagateIfChanged(operands[1],
446  operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
447  if (operands.size() > 2) {
448  VectorType cTy = dpas.getAccType();
449  propagateIfChanged(operands[2],
450  operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
451  }
452 }
453 
454 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
455 void LayoutInfoPropagation::visitStoreNdOp(
456  xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
457     ArrayRef<const LayoutInfoLattice *> results) {
458  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
459  // Both operands should have the same layout
460  for (LayoutInfoLattice *operand : operands) {
461  propagateIfChanged(operand, operand->meet(storeLayout));
462  }
463 }
464 
465 /// Propagate the layout of the value to the tensor descriptor operand in
466 /// LoadNdOp.
467 void LayoutInfoPropagation::visitLoadNdOp(
468  xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
469     ArrayRef<const LayoutInfoLattice *> results) {
470  LayoutInfo valueLayout = results[0]->getValue();
471  // Need the layout of the value to propagate to the tensor descriptor.
472  if (!valueLayout.isAssigned())
473  return;
474  LayoutInfo tensorDescLayout = valueLayout;
475  // LoadNdOp has a transpose effect. However, at this stage of the analysis
476  // this effect is not expected and should be abstracted away. Emit a warning.
477  if (auto transpose = load.getTranspose()) {
478  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
479  "LayoutInfoPropagation stage.");
480  tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
481  }
482  // Propagate the new layout to the tensor descriptor operand.
483  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
484 }
485 
486 /// For vector::TransposeOp, the layout of the result is transposed and
487 /// propagated to the operand.
488 void LayoutInfoPropagation::visitTransposeOp(
489  vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
490     ArrayRef<const LayoutInfoLattice *> results) {
491  // Need the layout of transpose result to propagate to the operands.
492  LayoutInfo resultLayout = results[0]->getValue();
493  if (!resultLayout.isAssigned())
494  return;
495  LayoutInfo newLayout =
496  resultLayout.getTransposedLayout(transpose.getPermutation());
497  // Propagate the new layout to the vector operand.
498  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
499 }
500 
501 /// For vector::BitCastOp, the lane_data of the source layout is changed based
502 /// on the bit width of the source and result types.
503 void LayoutInfoPropagation::visitVectorBitcastOp(
504  vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
505     ArrayRef<const LayoutInfoLattice *> results) {
506  // Need the layout of bitcast result to propagate to the operands.
507  LayoutInfo resultLayout = results[0]->getValue();
508  if (!resultLayout.isAssigned())
509  return;
510  int inElemTyBitWidth =
511  bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
512  int outElemTyBitWidth =
513  bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
514 
515  // LaneLayout does not change.
516  const LaneLayout &newLaneLayout = resultLayout.getLayout();
517  const LaneData &currData = resultLayout.getData();
518  LaneData newLaneData;
519  // It's a widening bitcast
520  if (inElemTyBitWidth < outElemTyBitWidth) {
521  int ratio = outElemTyBitWidth / inElemTyBitWidth;
522  newLaneData = resultLayout.getData()[0] == 1
523  ? LaneData({1, currData[1] * ratio})
524  : LaneData({currData[0] * ratio, 1});
525  } else {
526  // It's a narrowing bitcast
527  int ratio = inElemTyBitWidth / outElemTyBitWidth;
528  newLaneData = resultLayout.getData()[0] == 1
529  ? LaneData({1, currData[1] / ratio})
530  : LaneData({currData[0] / ratio, 1});
531  }
532 
533  propagateIfChanged(operands[0],
534  operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
535 }
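As a worked example, consider a vector.bitcast from vector<8x32xi16> to vector<8x16xi32> whose result carries lane_data [1, 1]: inElemTyBitWidth = 16 and outElemTyBitWidth = 32, so this is a widening bitcast with ratio = 2, and the i16 source operand receives lane_data [1, 1 * 2] = [1, 2].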
536 
537 /// Propagate the layout of the result to the tensor descriptor and mask
538 /// operands in LoadGatherOp.
539 void LayoutInfoPropagation::visitLoadGatherOp(
540  xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
541     ArrayRef<const LayoutInfoLattice *> results) {
542  LayoutInfo valueLayout = results[0]->getValue();
543  // Need the layout of the value to propagate to the tensor descriptor.
544  if (!valueLayout.isAssigned())
545  return;
546 
547  LayoutInfo tensorDescLayout = valueLayout;
548  if (load.getTranspose()) {
549  // LoadGatherOp has a transpose effect. However, at this stage of the
550  // analysis this effect is not expected and should be abstracted away. Emit
551  // a warning.
552  load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
553  "LayoutInfoPropagation stage.");
554  tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
555  }
556  // Mask operand should have 1D default layout.
557  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
558  // Propagate the new layout to the tensor descriptor operand.
559  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
560  // Propagate the new layout to the mask operand.
561  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
562 }
563 
564 /// Propagate the layout of the descriptor to the vector offset operand in
565 /// CreateDescOp.
566 void LayoutInfoPropagation::visitCreateDescOp(
567  xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
568     ArrayRef<const LayoutInfoLattice *> results) {
569  LayoutInfo descLayout = results[0]->getValue();
570  // Need the layout of the descriptor to propagate to the operands.
571  if (!descLayout.isAssigned())
572  return;
573  // For offset operand propagate 1D default layout.
574  LayoutInfo layout = getDefaultLayoutInfo(1);
575  propagateIfChanged(operands[1], operands[1]->meet(layout));
576 }
577 
578 /// Set the layout for the value, tensor descriptor, and mask operands in the
579 /// StoreScatterOp.
580 void LayoutInfoPropagation::visitStoreScatterOp(
581  xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
582     ArrayRef<const LayoutInfoLattice *> results) {
583  // Currently, for 2D StoreScatterOp we expect that the height dimension of
584  // the tensor descriptor is equal to the subgroup size. This is ensured by
585  // the op verifier.
586  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
587  if (tdescShape.size() > 1)
588  assert(
589  tdescShape[0] == subgroupSize &&
590  "Expected the first dimension of 2D tensor descriptor to be equal to "
591  "subgroup size.");
592 
593  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
594  LayoutInfo storeScatterLayout = valueLayout;
595  if (storeScatter.getTranspose()) {
596  // StoreScatterOp allows a transpose effect. However, at this stage of the
597  // analysis this effect is not expected and should be abstracted away. Emit
598  // a warning.
599  storeScatter.emitWarning("Transpose effect is not expected for "
600  "StoreScatterOp at LayoutInfoPropagation stage.");
601  storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
602  }
603  // Propagate the value layout.
604  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
605  // Propagate the tensor descriptor layout.
606  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
607  // Use default 1D layout for mask operand.
608  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
609  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
610 }
611 
612 namespace {
613 
614 //===----------------------------------------------------------------------===//
615 // RunLayoutInfoPropagation
616 //===----------------------------------------------------------------------===//
617 
618 /// Driver class for running the LayoutInfoPropagation analysis.
619 class RunLayoutInfoPropagation {
620 public:
621  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
622 
623  RunLayoutInfoPropagation(Operation *op) : target(op) {
624  SymbolTableCollection symbolTable;
625  solver.load<DeadCodeAnalysis>();
626  solver.load<SparseConstantPropagation>();
627  solver.load<LayoutInfoPropagation>(symbolTable);
628  (void)solver.initializeAndRun(op);
629  }
630 
631  LayoutInfo getLayoutInfo(Value val);
632 
633  void printAnalysisResult(llvm::raw_ostream &os);
634 
635 private:
636  DataFlowSolver solver;
637  const Operation *target;
638 };
639 } // namespace
640 
641 LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
642  auto *state = solver.lookupState<LayoutInfoLattice>(val);
643  if (!state)
644  return {};
645  return state->getValue();
646 }
647 
648 void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
649  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
650  os << "function: " << funcOp.getName() << ":\n";
651  // Function arguments
652  for (BlockArgument arg : funcOp.getArguments()) {
653  LayoutInfo layout = getLayoutInfo(arg);
654  os << "argument: " << arg << "\n";
655  os << "layout : ";
656  layout.print(os);
657  os << "\n";
658  }
659  // Function ops
660  funcOp.walk([&](Operation *op) {
661  // Skip ops that do not have results
662  if (op->getResults().empty())
663  return;
664  os << "op : ";
665  // For control-flow ops, print the op name only.
666  if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
667  os << op->getName();
668  else
669  op->print(os);
670  os << "\n";
671  // Print the layout for each result.
672  for (auto [i, r] : llvm::enumerate(op->getResults())) {
673  LayoutInfo layout = getLayoutInfo(r);
674  os << "layout for result #" << i << ": ";
675  layout.print(os);
676  os << "\n";
677  }
678  });
679  };
680 
681   SmallVector<FunctionOpInterface> funcOps;
682  if (auto modOp = dyn_cast<ModuleOp>(target)) {
683  for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
684  funcOps.push_back(funcOp);
685  }
686  // Collect all GpuFuncOps in the module.
687  for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
688  for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
689  funcOps.push_back(gpuFuncOp);
690  }
691  }
692  }
693  // Print the analysis result for each function.
694  for (FunctionOpInterface funcOp : funcOps) {
695  printFunctionResult(funcOp);
696  }
697 }
698 
699 namespace {
700 
701 //===----------------------------------------------------------------------===//
702 // LayoutAttrAssignment
703 //===----------------------------------------------------------------------===//
704 
705 /// This class is responsible for assigning the layout attributes to the ops and
706 /// their users based on the layout propagation analysis result.
707 class LayoutAttrAssignment {
708 public:
709  LayoutAttrAssignment(Operation *top,
710  function_ref<LayoutInfo(Value)> getLayout)
711  : getAnalysisResult(getLayout), top(top) {}
712 
713  LogicalResult run();
714 
715 private:
716  LogicalResult assign(Operation *op);
717  void assignToUsers(Value v, xegpu::LayoutAttr layout);
718  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
719  LogicalResult resolveConflicts();
720  // Callable to get the layout of a value based on the layout propagation
721  // analysis.
722  function_ref<LayoutInfo(Value)> getAnalysisResult;
723  Operation *top;
724 };
725 
726 } // namespace
727 
728 /// Helper to assign the layout attribute to the users of the value.
729 void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
730  for (OpOperand &user : v.getUses()) {
731  Operation *owner = user.getOwner();
732  unsigned operandNumber = user.getOperandNumber();
733  // Use a generic name for ease of querying the layout attribute later.
734  std::string attrName =
735  operandLayoutNamePrefix + std::to_string(operandNumber);
736  owner->setAttr(attrName, layout);
737  }
738 }
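For instance, if a value is used as the second operand of a DPAS op, that DPAS op receives a discardable attribute named `layout_operand_1` whose value is the xegpu::LayoutAttr of the producing value.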
739 
740 /// Convert the layout assigned to a value to xegpu::LayoutAttr.
741 xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
742  LayoutInfo layout = getAnalysisResult(v);
743  if (!layout.isAssigned())
744  return {};
745  SmallVector<int, 2> laneLayout, laneData;
746  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
747  layout.getDataAsArrayRef())) {
748  laneLayout.push_back(static_cast<int>(layout));
749  laneData.push_back(static_cast<int>(data));
750  }
751  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
752 }
753 
754 /// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
755 /// based on the layout propagation analysis result.
756 LogicalResult LayoutAttrAssignment::assign(Operation *op) {
757  // For function ops, propagate the function argument layout to the users.
758  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
759  for (BlockArgument arg : func.getArguments()) {
760  xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
761  if (layoutInfo) {
762  assignToUsers(arg, layoutInfo);
763  }
764  }
765  return success();
766  }
767  // If no results, move on.
768  if (op->getNumResults() == 0)
769  return success();
770  // If all the results are scalars, move on.
771  if (llvm::all_of(op->getResultTypes(),
772  [](Type t) { return t.isIntOrIndexOrFloat(); }))
773  return success();
774  // If the op has more than one result and at least one result is a tensor
775  // descriptor, exit. This case is not supported yet.
776  // TODO: Support this case.
777  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
778  return isa<xegpu::TensorDescType>(t);
779  })) {
780  LLVM_DEBUG(
781  DBGS() << op->getName()
782  << " op has more than one result and at least one is a tensor "
783  "descriptor. This case is not handled.\n");
784  return failure();
785  }
786  // If the result is a tensor descriptor, attach the layout to the tensor
787  // descriptor itself.
788  if (auto tensorDescTy =
789  dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
790  xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
791  if (!layoutInfo) {
792  LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
793  return failure();
794  }
795 
796  // Clone the op, attach the layout to the result tensor descriptor, and
797  // remove the original op.
798  OpBuilder builder(op);
799  Operation *newOp = builder.clone(*op);
800  auto newTensorDescTy = xegpu::TensorDescType::get(
801  tensorDescTy.getContext(), tensorDescTy.getShape(),
802  tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
803  newOp->getResult(0).setType(newTensorDescTy);
804  op->replaceAllUsesWith(newOp->getResults());
805  op->erase();
806  return success();
807  }
808  // Otherwise simply attach the layout to the op itself.
809  for (auto [i, r] : llvm::enumerate(op->getResults())) {
810  xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
811  if (layoutInfo) {
812  std::string attrName = resultLayoutNamePrefix + std::to_string(i);
813  op->setAttr(attrName, layoutInfo);
814  // Attach the layout attribute to the users of the result.
815  assignToUsers(r, layoutInfo);
816  }
817  }
818  return success();
819 }
820 
821 /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
822 LogicalResult LayoutAttrAssignment::run() {
823  auto walkResult = top->walk([&](Operation *op) {
824  if (failed(assign(op)))
825  return WalkResult::interrupt();
826  return WalkResult::advance();
827  });
828 
829  if (walkResult.wasInterrupted())
830  return failure();
831 
832  return resolveConflicts();
833 }
834 
835 /// TODO: Implement the layout conflict resolution. This must ensure mainly two
836 /// things:
837 /// 1) Is a given layout supported by the op? (need to query the target
838 /// HW info). Otherwise can we achieve this layout using a layout conversion?
839 /// 2) Do all the operands have the required layout? If not, can it
840 /// be resolved using a layout conversion?
841 LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
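A minimal sketch of how the two stages above can be chained inside a pass (the surrounding pass is assumed and the helper name is illustrative):
```
static LogicalResult runAnalysisAndAssign(Operation *root) {
  // Stage 1: propagate lane layouts backward from the anchor ops.
  RunLayoutInfoPropagation analysis(root);
  auto getLayout = [&](Value v) { return analysis.getLayoutInfo(v); };
  // Stage 2: materialize the propagated layouts as attributes on the IR.
  LayoutAttrAssignment layoutAssignment(root, getLayout);
  return layoutAssignment.run();
}
```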
842 
843 namespace {
844 
845 //===----------------------------------------------------------------------===//
846 // SIMT Distribution Patterns
847 //===----------------------------------------------------------------------===//
848 
849 /// Helper function to get the distributed vector type for a source vector type
850 /// according to the lane_layout. We simply divide each dimension of the tensor
851 /// descriptor shape by the corresponding lane_layout dimension. If
852 /// array_length > 1, it is appended to the front of the distributed shape.
853 /// NOTE: This is the vector type that will be returned by the
854 /// gpu.warp_execute_on_lane0 op.
855 ///
856 /// Examples:
857 /// | original vector shape | lane_layout | distributed vector shape |
858 /// |-----------------------|-------------|--------------------------|
859 /// | 32x16 | [1, 16] | 32x1 |
860 /// | 32x16 | [2, 8] | 16x2 |
861 /// | 2x32x16 | [1, 16] | 2x32x1 |
862 static FailureOr<VectorType>
863 getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
864  VectorType originalType) {
865  if (!layout)
866  return failure();
867 
868  auto laneLayout = layout.getLaneLayout().asArrayRef();
869  assert(originalType.getShape().size() >= laneLayout.size() &&
870  "Rank of the original vector type should be greater or equal to the "
871  "size of the lane layout to distribute the vector type.");
872  SmallVector<int64_t> distributedShape(originalType.getShape());
873  // Only distribute the last `laneLayout.size()` dimensions. The remaining
874  // dimensions are not distributed.
875  unsigned distributionStart = originalType.getRank() - laneLayout.size();
876  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
877  if (i < distributionStart) {
878  continue;
879  }
880  // Check if the dimension can be distributed evenly.
881  if (dim % laneLayout[i - distributionStart] != 0)
882  return failure();
883  distributedShape[i] = dim / laneLayout[i - distributionStart];
884  }
885  return VectorType::get(distributedShape, originalType.getElementType());
886 }
887 
888 /// Helper function to resolve types if the distributed type out of
889 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
890 /// Example 1:
891 /// distributed type: vector<8x1xf32>
892 /// expected type: vector<8xf32>
893 /// resolved using,
894 /// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
895 /// Example 2:
896 /// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
897 /// expected type: xegpu.tensor_desc<8x16xf32>
898 /// resolved using,
899 /// %0 = unrealized_conversion_cast %1 :
900 /// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
901 /// xegpu.tensor_desc<8x16xf32>
902 template <typename T>
903 static Value resolveDistributedTy(Value orig, T expected,
904  PatternRewriter &rewriter) {
905  // If orig and expected types are the same, return orig.
906  if (orig.getType() == expected)
907  return orig;
908  // If orig is a vector type, create a shape cast op to reconcile the types.
909  if (isa<VectorType>(orig.getType())) {
910  auto castOp =
911  rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
912  return castOp.getResult();
913  }
914  // If orig is a tensor descriptor type, create an unrealized conversion cast
915  // op to reconcile the types.
916  if (isa<xegpu::TensorDescType>(orig.getType())) {
917  auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
918  expected, orig);
919  return castOp.getResult(0);
920  }
921  llvm_unreachable("Unsupported type for reconciliation");
922  return orig;
923 }
924 
925 /// Helper function to filter out the temporary layout attributes attached
926 /// during the layout assignment process. These are not needed after going to
927 /// SIMT.
928 static SmallVector<NamedAttribute>
929 removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
930   SmallVector<NamedAttribute> newAttrs;
931  for (NamedAttribute attr : attrs) {
932  if (attr.getName().strref().contains(operandLayoutNamePrefix) ||
933  attr.getName().strref().contains(resultLayoutNamePrefix)) {
934  continue;
935  }
936  newAttrs.push_back(attr);
937  }
938  return newAttrs;
939 }
940 
941 /// Helper function to check if the layout is packed. A layout is packed if it
942 /// is 2D and lane_data[0] != 1 (i.e. data is packed along the column dimension).
943 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
944  if (layout == xegpu::LayoutAttr())
945  return false;
946  DenseI32ArrayAttr laneData = layout.getLaneData();
947  if (!laneData || laneData.size() != 2)
948  return false;
949  return laneData.asArrayRef()[0] != 1;
950 }
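A usage sketch (illustrative; it reuses the same xegpu::LayoutAttr::get overload used elsewhere in this file): a lane_data of [2, 1], as assigned to DPAS B operands, is considered packed, whereas [1, 1] is not.
```
static bool examplePackedCheck(MLIRContext *ctx) {
  SmallVector<int, 2> laneLayout = {1, 16};
  SmallVector<int, 2> laneData = {2, 1}; // DPAS-B style packing.
  xegpu::LayoutAttr layout = xegpu::LayoutAttr::get(ctx, laneLayout, laneData);
  return hasPackedLayout(layout); // true, since lane_data[0] != 1.
}
```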
951 
952 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
953 /// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
954 /// contained within a WarpExecuteOnLane0Op.
955 /// Example:
956 ///
957 /// ```
958 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
959 /// ...
960 /// ...
961 /// gpu.return %result: vector<8x16xf32>
962 /// }
963 /// ```
964 /// To
965 /// ```
966 /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
967 /// %laneid = gpu.lane_id : index
968 /// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
969 /// ...
970 /// ...
971 /// gpu.yield %result: vector<8x16xf32>
972 /// }
973 /// return %0
974 /// }
975 struct MoveFuncBodyToWarpExecuteOnLane0
976  : public OpRewritePattern<gpu::GPUFuncOp> {
977   using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
978  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
979  PatternRewriter &rewriter) const override {
980  // If the function only contains a single void return, skip.
981  if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
982  return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
983  }))
984  return failure();
985  // If the function already moved inside a warp_execute_on_lane0, skip.
986  if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
987  return isa<gpu::WarpExecuteOnLane0Op>(op);
988  }))
989  return failure();
990  // Create a new function with the same signature.
991  auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
992  gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
993  // Create a WarpExecuteOnLane0Op with same arguments and results as the
994  // original gpuFuncOp.
995  rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
996  auto laneId = rewriter.create<gpu::LaneIdOp>(
997  newGpuFunc.getLoc(), rewriter.getIndexType(),
998  /** upperBound = **/ mlir::IntegerAttr());
999  ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
1000  auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
1001  laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
1002  newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
1003  Block &warpBodyBlock = warpOp.getBodyRegion().front();
1004  // Replace the ReturnOp of the original gpu function with a YieldOp.
1005  auto origRetunOp =
1006  cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
1007  rewriter.setInsertionPointAfter(origRetunOp);
1008  rewriter.create<gpu::YieldOp>(origRetunOp.getLoc(),
1009  origRetunOp.getOperands());
1010  rewriter.eraseOp(origRetunOp);
1011  // Move the original function body to the WarpExecuteOnLane0Op body.
1012  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
1013  warpOp.getBodyRegion().begin());
1014  rewriter.eraseBlock(&warpBodyBlock);
1015  // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
1016  rewriter.setInsertionPointAfter(warpOp);
1017  rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
1018  rewriter.replaceOp(gpuFuncOp, newGpuFunc);
1019  return success();
1020  }
1021 };
1022 
1023 /// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
1024 /// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
1025 /// still contain the original op that will not be used by the yield op (and
1026 /// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
1027 /// arguments. Tensor descriptor shape is not distributed because it is a
1028 /// uniform value across all work items within the subgroup. However, the
1029 /// layout information is dropped in the new tensor descriptor type.
1030 ///
1031 /// Example:
1032 ///
1033 /// ```
1034 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1035 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1036 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
1037 /// ...
1038 /// %td = xegpu.create_nd_tdesc %arg0[0, 0]
1039 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
1040 /// vector.yield %td
1041 /// }
1042 /// ```
1043 /// To
1044 /// ```
1045 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
1046 /// ...
1047 /// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
1048 /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
1049 /// vector.yield %arg0, %dead
1050 /// }
1051 /// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
1052 /// -> !xegpu.tensor_desc<4x8xf32>
1053 ///
1054 /// ```
1055 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
1056  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1057  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1058  PatternRewriter &rewriter) const override {
1059  OpOperand *operand =
1060  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
1061  if (!operand)
1062  return rewriter.notifyMatchFailure(
1063  subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
1064  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
1065  unsigned operandIdx = operand->getOperandNumber();
1066 
1067  xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
1068  if (!layout)
1069  return rewriter.notifyMatchFailure(
1070  descOp, "the tensor descriptor lacks layout attribute");
1071 
1072  SmallVector<size_t> newRetIndices;
1073  SmallVector<Value> newYieldValues;
1074  SmallVector<Type> newYieldTypes;
1075 
1076  for (Value operand : descOp->getOperands()) {
1077  newYieldValues.push_back(operand);
1078  newYieldTypes.push_back(operand.getType());
1079  }
1080  rewriter.setInsertionPoint(subgroupOp);
1081  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1082  rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
1083  /* new yielded types = */ newYieldTypes, newRetIndices);
1084 
1085  SmallVector<Value> newDescOperands;
1086  for (size_t i : newRetIndices) {
1087  newDescOperands.push_back(newWarpOp.getResult(i));
1088  }
1089  rewriter.setInsertionPointAfter(newWarpOp);
1090  xegpu::TensorDescType distributedTensorDescTy =
1091  descOp.getType().dropLayouts(); // Distributed tensor descriptor type
1092  // does not contain layout info.
1093  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
1094  newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
1095  descOp->getAttrs());
1096 
1097  Value distributedVal = newWarpOp.getResult(operandIdx);
1098  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
1099  return success();
1100  }
1101 };
1102 
1103 /// Distribute a store_nd op at the end of the enclosing
1104 /// `gpu.warp_execute_on_lane_0`. If the arguments of the store are passed
1105 /// through the warp op interface, they are propagated as returned values. The
1106 /// source vector is distributed based on the lane layout. Appropriate cast ops
1107 /// are inserted if the distributed types do not match the expected xegpu SIMT types.
1108 ///
1109 /// Example:
1110 ///
1111 /// ```
1112 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1113 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
1114 /// ...
1115 /// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
1116 /// !xegpu.tensor_desc<4x8xf32, #layout0>
1117 /// }
1118 /// ```
1119 /// To
1120 /// ```
1121 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1122 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1123 /// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
1124 /// #layout0>
1125 /// }
1126 /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
1127 /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1128 /// #layout0>
1129 /// -> !xegpu.tensor_desc<4x8xf32>
1130 /// xegpu.store_nd %0, %1: vector<4xf32>,
1131 /// !xegpu.tensor_desc<4x8xf32>
1132 ///
1133 /// ```
1134 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
1135  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1136  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1137  PatternRewriter &rewriter) const override {
1138  auto yield = cast<gpu::YieldOp>(
1139  subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
1140  Operation *lastNode = yield->getPrevNode();
1141  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
1142  if (!storeOp)
1143  return failure();
1144 
1145  xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
1146  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
1147  if (!layout)
1148  return rewriter.notifyMatchFailure(
1149  storeOp, "the source tensor descriptor lacks layout attribute");
1150 
1151  FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
1152  getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
1153  if (failed(distributedTypeByWarpOpOrFailure))
1154  return rewriter.notifyMatchFailure(storeOp,
1155  "Failed to distribute the type");
1156  VectorType distributedTypeByWarpOp =
1157  distributedTypeByWarpOpOrFailure.value();
1158 
1159  SmallVector<size_t> newRetIndices;
1160  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1161  rewriter, subgroupOp,
1162  /* new yielded values = */
1163  ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
1164  /* new yielded types = */
1165  TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
1166  newRetIndices);
1167  // Create a new store op outside the warp op with the distributed vector
1168  // type. Tensor descriptor is not distributed.
1169  rewriter.setInsertionPointAfter(newWarpOp);
1170  SmallVector<Value> newStoreOperands;
1171 
1172  // For the value operand, there can be a mismatch between the vector type
1173  // distributed by the warp op and (xegpu-specific) distributed type
1174  // supported by the store op. Type mismatch must be resolved using
1175  // appropriate cast op.
1176  FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
1177  xegpu::getDistributedVectorType(storeOp.getTensorDescType());
1178  if (failed(storeNdDistributedValueTyOrFailure))
1179  return rewriter.notifyMatchFailure(
1180  storeOp, "Failed to get distributed vector type for the store op");
1181  newStoreOperands.push_back(resolveDistributedTy(
1182  newWarpOp.getResult(newRetIndices[0]),
1183  storeNdDistributedValueTyOrFailure.value(), rewriter));
1184  // For the tensor descriptor operand, the layout attribute is dropped after
1185  // distribution. Types need to be resolved in this case as well.
1186  xegpu::TensorDescType distributedTensorDescTy =
1187  storeOp.getTensorDescType().dropLayouts();
1188  newStoreOperands.push_back(
1189  resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
1190  distributedTensorDescTy, rewriter));
1191 
1192  rewriter.create<xegpu::StoreNdOp>(
1193  newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
1194  removeTemporaryLayoutAttributes(storeOp->getAttrs()));
1195  rewriter.eraseOp(storeOp);
1196  return success();
1197  }
1198 };
1199 
1200 /// Distribute a load_nd op feeding into vector.yield op for the enclosing
1201 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
1202 /// The warp op will still contain the original op that will not be used by
1203 /// the yield op (and should be cleaned up later). The yield op will
1204 /// bypass the load's arguments. Only the loaded vector is distributed
1205 /// according to the lane layout; the tensor descriptor type is not
1206 /// distributed. Appropriate cast ops are inserted if the distributed types do
1207 /// not match the expected xegpu SIMT types.
1208 ///
1209 /// Example:
1210 ///
1211 /// ```
1212 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1213 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1214 /// (vector<4x1xf32>) {
1215 /// ...
1216 /// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
1217 /// ->
1218 /// vector<4x8xf32>
1219 /// gpu.yield %ld
1220 /// }
1221 /// ```
1222 /// To
1223 /// ```
1224 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
1225 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1226 /// ...
1227 /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
1228 /// vector<4x8xf32> gpu.yield %dead, %arg0
1229 /// }
1230 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1231 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1232 /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
1233 /// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
1234 ///
1235 /// ```
1236 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
1237  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1238  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1239  PatternRewriter &rewriter) const override {
1240  OpOperand *operand =
1241  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
1242  if (!operand)
1243  return rewriter.notifyMatchFailure(
1244  subgroupOp, "warp result is not a xegpu::LoadNd op");
1245 
1246  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
1247  xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
1248  xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
1249  if (!layout)
1250  return rewriter.notifyMatchFailure(
1251  loadOp, "the source tensor descriptor lacks layout attribute");
1252 
1253  unsigned operandIdx = operand->getOperandNumber();
1254  VectorType distributedTypeByWarpOp =
1255  cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
1256 
1257  SmallVector<size_t> newRetIndices;
1258  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1259  rewriter, subgroupOp,
1260  /* new yielded values = */ loadOp.getTensorDesc(),
1261  /* new yielded types = */ tensorDescTy, newRetIndices);
1262 
1263  // Create a new load op outside the warp op with the distributed vector
1264  // type.
1265  rewriter.setInsertionPointAfter(newWarpOp);
1266  FailureOr<VectorType> loadNdDistValueTyOrFailure =
1267  xegpu::getDistributedVectorType(loadOp.getTensorDescType());
1268  if (failed(loadNdDistValueTyOrFailure))
1269  return rewriter.notifyMatchFailure(
1270  loadOp, "Failed to get distributed vector type for the load op");
1271  xegpu::TensorDescType distributedTensorDescTy =
1272  loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
1273  // descriptor type does not
1274  // contain layout info.
1275  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
1276  newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
1277  resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
1278  distributedTensorDescTy, rewriter),
1279  removeTemporaryLayoutAttributes(loadOp->getAttrs()));
1280  // Set the packed attribute if the layout requires it.
1281  newLoadOp.setPacked(hasPackedLayout(layout));
1282  Value distributedVal = newWarpOp.getResult(operandIdx);
1283  // There can be a conflict between the vector type distributed by the
1284  // warp op and (xegpu-specific) distributed type supported by the load
1285  // op. Resolve these mismatches by inserting a cast.
1286  Value tyResolvedVal = resolveDistributedTy(
1287  newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
1288  rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
1289  return success();
1290  }
1291 };
1292 
1293 /// Distribute a dpas op feeding into vector.yield op for the enclosing
1294 /// `gpu.warp_execute_on_lane_0` and put it after the warp op.
1295 /// The warp op will still contain the original op that will not be used by
1296 /// the yield op (and should be cleaned up later). The yield op will
1297 /// bypass the dpas's arguments. Appropriate cast ops are inserted if the
1298 /// distributed types do not match the expected xegpu SIMT types.
1299 /// Example:
1300 /// ```
1301 /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
1302 /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
1303 /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
1304 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1305 /// (vector<8x1xf32>) {
1306 /// ...
1307 /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
1308 /// vector<8x16xf32>
1309 /// gpu.yield %dpas
1310 /// }
1311 /// ```
1312 /// To
1313 /// ```
1314 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
1315 /// vector<8x1xf16>, vector<16x1xf16>) {
1316 /// ...
1317 /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
1318 /// -> vector<8x16xf32>
1319 /// gpu.yield %dead, %arg0, %arg1
1320 /// }
1321 /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
1322 /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
1323 /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
1324 /// vector<8xf32>
1325 /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
1326 /// ```
1327 struct DpasDistribution final : public gpu::WarpDistributionPattern {
1328  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1329  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1330  PatternRewriter &rewriter) const override {
1331  OpOperand *operand =
1332  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
1333  if (!operand)
1334  return rewriter.notifyMatchFailure(subgroupOp,
1335  "warp result is not a xegpu::Dpas op");
1336 
1337  auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
1338  unsigned operandIdx = operand->getOperandNumber();
1339  std::string layoutAName =
1340  llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str();
1341  std::string layoutBName =
1342  llvm::formatv("{0}{1}", operandLayoutNamePrefix, 1).str();
1343  auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str();
1344  xegpu::LayoutAttr layoutA =
1345  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
1346  xegpu::LayoutAttr layoutB =
1347  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
1348  xegpu::LayoutAttr layoutOut =
1349  dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
1350  if (!layoutA || !layoutB || !layoutOut)
1351  return rewriter.notifyMatchFailure(
1352  dpasOp,
1353  "the xegpu::Dpas op lacks layout attribute for A, B or output");
1354 
1355  FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
1356  getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
1357  FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
1358  getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
1359  FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
1360  getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
1361  if (failed(distLhsTypeByWarpOpOrFailure) ||
1362  failed(distRhsTypeByWarpOpOrFailure) ||
1363  failed(distResultTypeByWarpOpOrFailure))
1364  return rewriter.notifyMatchFailure(
1365  dpasOp,
1366  "Failed to distribute the A, B or output types in xegpu::Dpas op");
1367 
1368  llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
1369  dpasOp.getRhs()};
1370  llvm::SmallVector<Type, 3> newYieldTypes{
1371  distLhsTypeByWarpOpOrFailure.value(),
1372  distRhsTypeByWarpOpOrFailure.value()};
1373  // Dpas acc operand is optional.
1374  if (dpasOp.getAcc()) {
1375  newYieldValues.push_back(dpasOp.getAcc());
1376  newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
1377  }
1378  // Create a new warp op without the dpas.
1379  SmallVector<size_t> newRetIndices;
1380  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1381  rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1382 
1383  FailureOr<VectorType> expectedDistLhsTyOrFailure =
1384  xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
1385  FailureOr<VectorType> expectedDistRhsTyOrFailure =
1386  xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
1387  FailureOr<VectorType> expectedDistResultTyOrFailure =
1388  xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
1389  if (failed(expectedDistLhsTyOrFailure) ||
1390  failed(expectedDistRhsTyOrFailure) ||
1391  failed(expectedDistResultTyOrFailure))
1392  return rewriter.notifyMatchFailure(
1393  dpasOp,
1394  "Failed to get distributed vector type for the dpas operands.");
1395  // Create a new dpas op outside the warp op.
1396  rewriter.setInsertionPointAfter(newWarpOp);
1397  SmallVector<Value> newDpasOperands;
1398  SmallVector<VectorType> newDpasOperandExpectedTypes;
1399 
1400  // Resolve the distributed types with the original types.
1401  newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
1402  newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
1403  VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
1404  if (dpasOp.getAcc())
1405  newDpasOperandExpectedTypes.push_back(distributedResultTy);
1406 
1407  for (unsigned i = 0; i < newRetIndices.size(); i++) {
1408  newDpasOperands.push_back(
1409  resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
1410  newDpasOperandExpectedTypes[i], rewriter));
1411  }
1412  Value newDpasOp = rewriter.create<xegpu::DpasOp>(
1413  newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
1414  removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
1415  Value distributedVal = newWarpOp.getResult(operandIdx);
1416  // Resolve the output type.
1417  newDpasOp = resolveDistributedTy(
1418  newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
1419  rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
1420  return success();
1421  }
1422 };
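// For illustration, a rough before/after sketch of this rewrite. The shapes and
// distributed types here are assumptions chosen for a 16-lane subgroup; the
// actual types come from the layout propagation analysis, and the cast shown is
// the kind of type resolution inserted by resolveDistributedTy.
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>) {
//   ...
//   %d = xegpu.dpas %a, %b : vector<8x16xf16>, vector<16x16xf16>
//     -> vector<8x16xf32>
//   gpu.yield %d
// }
// ```
// becomes
// ```
// %r:3 = gpu.warp_execute_on_lane_0(%laneid)
//     -> (vector<8x1xf32>, vector<8x1xf16>, vector<16x1xf16>) {
//   ...
//   gpu.yield %d, %a, %b
// }
// %a0 = unrealized_conversion_cast %r#1 : vector<8x1xf16> -> vector<8xf16>
// %b0 = unrealized_conversion_cast %r#2 : vector<16x1xf16> -> vector<16xf16>
// %d0 = xegpu.dpas %a0, %b0 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// ```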
1423 
 1424 /// Sink an update_nd_offset op feeding into the yield op of an enclosing
 1425 /// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
 1426 /// original op, which is no longer used by the yield op (and should be cleaned
 1427 /// up later). The yield op forwards the updateOp's operands instead. The tensor
 1428 /// descriptor type is not distributed. Appropriate cast ops are inserted if
 1429 /// the distributed types do not match the expected xegpu SIMT types.
1430 /// Example:
1431 /// ```
1432 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1433 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1434 /// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
1435 /// ...
1436 /// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
1437 /// !xegpu.tensor_desc<4x8xf32, #layout0>
1438 /// gpu.yield %update
1439 /// }
1440 /// ...
1441 /// ```
1442 /// To
1443 /// ```
 1444 /// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
 1445 /// !xegpu.tensor_desc<4x8xf32, #layout0>,
 1446 /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
 1447 /// ...
 1448 /// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
 1449 /// !xegpu.tensor_desc<4x8xf32, #layout0>
 1450 /// gpu.yield %dead, %arg0, %c32, %c16
 1451 /// }
 1452 /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
 1453 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1454 /// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
1455 /// !xegpu.tensor_desc<4x8xf32>
1456 /// ...
1457 /// ```
1458 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
1459  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1460  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1461  PatternRewriter &rewriter) const override {
1462  OpOperand *operand =
1463  getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
1464  if (!operand)
1465  return rewriter.notifyMatchFailure(
1466  subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
1467  auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
1468  unsigned operandIdx = operand->getOperandNumber();
 1469  // The new update op does not carry a layout attribute.
1470  xegpu::TensorDescType newTensorDescTy =
1471  updateOp.getTensorDescType().dropLayouts();
1472 
1473  SmallVector<Value, 3> newYieldValues;
1474  SmallVector<Type, 3> newYieldTypes;
1475  for (Value operand : updateOp->getOperands()) {
1476  newYieldValues.push_back(operand);
1477  if (isa<xegpu::TensorDescType>(operand.getType())) {
1478  newYieldTypes.push_back(newTensorDescTy);
1479  } else {
1480  newYieldTypes.push_back(operand.getType());
1481  }
1482  }
1483  SmallVector<size_t> newRetIndices;
1484  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1485  rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1486  rewriter.setInsertionPointAfter(newWarpOp);
1487  SmallVector<Value> newUpdateOperands;
1488  for (size_t i : newRetIndices) {
 1489  // For the tensor descriptor operand, the layout attribute is dropped
 1490  // after distribution. Types need to be resolved in this case.
1491  if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
1492  newUpdateOperands.push_back(resolveDistributedTy(
1493  newWarpOp.getResult(i), newTensorDescTy, rewriter));
1494  } else {
1495  newUpdateOperands.push_back(newWarpOp.getResult(i));
1496  }
1497  }
1498  // Create a new update op outside the warp op.
1499  auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
1500  newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
1501  removeTemporaryLayoutAttributes(updateOp->getAttrs()));
1502  Value distributedVal = newWarpOp.getResult(operandIdx);
1503  rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
1504  return success();
1505  }
1506 };
1507 
 1508 /// Distribute a prefetch_nd op at the end of the enclosing
 1509 /// `gpu.warp_execute_on_lane_0` region. Prefetch arguments defined inside the
 1510 /// region are forwarded through the warp op interface as returned values.
 1511 /// The tensor descriptor shape is not distributed because it is a uniform
 1512 /// value across all work items within the subgroup. Appropriate cast ops are
 1513 /// inserted if the distributed types do not match the expected xegpu SIMT types.
1514 ///
1515 /// Example:
1516 ///
1517 /// ```
1518 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
1519 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
1520 /// ...
1521 /// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
1522 /// }
1523 /// ```
1524 /// To
1525 /// ```
1526 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
1527 /// !xegpu.tensor_desc<4x8xf32, #layout0>) {
1528 /// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
1529 /// }
1530 /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
1531 /// #layout0> -> !xegpu.tensor_desc<4x8xf32>
1532 /// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
1533 ///
1534 /// ```
1535 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
1536  using gpu::WarpDistributionPattern::WarpDistributionPattern;
1537  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1538  PatternRewriter &rewriter) const override {
1539  auto yield = cast<gpu::YieldOp>(
1540  subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
1541  Operation *lastNode = yield->getPrevNode();
1542  auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
1543  if (!prefetchOp)
1544  return failure();
1545  xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
1546  if (!layout)
1547  return rewriter.notifyMatchFailure(
1548  prefetchOp, "the source tensor descriptor lacks layout attribute");
1549 
1550  SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
1551  SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
1552  SmallVector<size_t> newRetIndices;
1553  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1554  rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
 1555  // Create a new prefetch op outside the warp op with the updated tensor
 1556  // descriptor type. The source tensor descriptor requires type resolution.
1557  xegpu::TensorDescType newTensorDescTy =
1558  prefetchOp.getTensorDescType().dropLayouts();
1559  rewriter.setInsertionPointAfter(newWarpOp);
1560  SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
1561  newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
1562  rewriter.create<xegpu::PrefetchNdOp>(
1563  newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
1564  removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
1565  rewriter.eraseOp(prefetchOp);
1566  return success();
1567  }
1568 };
1569 
1570 } // namespace
1571 
1572 namespace {
1573 struct XeGPUSubgroupDistributePass final
1574  : public xegpu::impl::XeGPUSubgroupDistributeBase<
1575  XeGPUSubgroupDistributePass> {
1576  XeGPUSubgroupDistributePass() = default;
1577  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
1578  default;
1579  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
1580  : XeGPUSubgroupDistributeBase(options) {}
1581  void runOnOperation() override;
1582 };
1583 } // namespace
1584 
 1585 void xegpu::populateXeGPUSubgroupDistributePatterns(
 1586  RewritePatternSet &patterns) {
 1587  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
1588  LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
1589  UpdateNdOffsetDistribution>(patterns.getContext());
1590 }
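// For illustration, a downstream pipeline could reuse these patterns roughly as
// follows (a sketch; `ctx` and `op` stand for the caller's MLIRContext and root
// op, and it assumes ops already carry layouts and sit inside
// gpu.warp_execute_on_lane_0, as the pass below arranges):
//
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));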
1591 
1592 void XeGPUSubgroupDistributePass::runOnOperation() {
 1593  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
 1594  // Print the analysis result and exit (for testing purposes).
 1595  if (printOnly) {
 1596  auto &os = llvm::outs();
 1597  analysis.printAnalysisResult(os);
 1598  return;
 1599  }
 1600  auto getPropagatedLayout = [&](Value val) {
 1601  return analysis.getLayoutInfo(val);
1602  };
1603 
1604  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
1605  // propagation analysis result.
1606  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
1607  if (failed(layoutAssignment.run())) {
1608  signalPassFailure();
1609  return;
1610  }
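 // After this step, ops that need a layout carry temporary attributes named
 // with the "layout_operand_<i>" / "layout_result_<i>" prefixes defined above;
 // the distribution patterns read these attributes and later strip them via
 // removeTemporaryLayoutAttributes.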
1611 
 1612  // Move all operations of a GPU function inside a gpu.warp_execute_on_lane_0
 1613  // operation.
1614  {
 1615  RewritePatternSet patterns(&getContext());
 1616  patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
1617 
1618  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1619  signalPassFailure();
1620  return;
1621  }
1622  // At this point, we have moved the entire function body inside the warpOp.
1623  // Now move any scalar uniform code outside of the warpOp (like GPU index
1624  // ops, scalar constants, etc.). This will simplify the later lowering and
1625  // avoid custom patterns for these ops.
1626  getOperation()->walk([&](Operation *op) {
1627  if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
1628  vector::moveScalarUniformCode(warpOp);
1629  }
1630  });
1631  }
1632  // Finally, do the SIMD to SIMT distribution.
 1633  RewritePatternSet patterns(&getContext());
 1634  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
 1635  // TODO: distributionFn and shuffleFn are not used at this point.
1636  auto distributionFn = [](Value val) {
1637  VectorType vecType = dyn_cast<VectorType>(val.getType());
1638  int64_t vecRank = vecType ? vecType.getRank() : 0;
1639  OpBuilder builder(val.getContext());
1640  if (vecRank == 0)
1641  return AffineMap::get(val.getContext());
1642  return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
1643  };
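 // For a rank-2 vector, distributionFn thus returns the identity map
 // (d0, d1) -> (d0, d1); rank-0 values get the empty map () -> ().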
1644  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
1645  int64_t warpSz) { return Value(); };
1646  vector::populatePropagateWarpVectorDistributionPatterns(
1647  patterns, distributionFn, shuffleFn);
1648  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
1649  signalPassFailure();
1650  return;
1651  }
1652 }
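// A minimal invocation sketch (the flag follows the pass's usual tablegen
// naming; the exact option spelling is an assumption, not taken from this
// file):
//
//   mlir-opt input.mlir -xegpu-subgroup-distribute
//   mlir-opt input.mlir -xegpu-subgroup-distribute="print-analysis-only=true"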