1 //===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/Value.h"
25 #include "mlir/IR/Visitors.h"
28 #include "mlir/Support/LLVM.h"
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallSet.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/LogicalResult.h"
37 #include "llvm/Support/raw_ostream.h"
38 
40 
41 namespace mlir {
42 namespace xegpu {
43 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45 } // namespace xegpu
46 } // namespace mlir
47 
48 #define DEBUG_TYPE "xegpu-propagate-layout"
49 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50 
51 using namespace mlir;
52 using namespace mlir::dataflow;
53 
54 namespace {
55 
56 //===----------------------------------------------------------------------===//
57 // LayoutInfo
58 //===----------------------------------------------------------------------===//
59 
60 /// Helper class for tracking the analysis state of an MLIR value. For layout
61 /// propagation, the analysis state is simply the distribution layout of
62 /// each value. The distribution layout information is encapsulated in the
63 /// xegpu::DistributeLayoutAttr class, which can describe any kind of
64 /// distribution layout that the XeGPU dialect supports. The purpose of this
65 /// analysis is to propagate a unique distribution layout to each value in the
66 /// program, starting from a set of anchor operations (like DPAS, StoreNd,
67 /// etc.). Note that the analysis reaches a fixed point once every value has
68 /// been assigned some layout; it never modifies an already assigned layout.
69 ///
70 /// Given this, LayoutInfo satisfies the following properties:
71 /// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
72 /// assigned`.
73 /// 2) Two LayoutInfo values are equal if they are both assigned or
74 /// both not assigned. The concrete value of the assigned state does not matter.
75 /// 3) The meet operator works as follows:
76 /// - If the current state is assigned, return the current state (a unique
77 /// layout is already assigned; don't change it).
78 /// - Otherwise, return the other state.
79 
80 struct LayoutInfo {
81 private:
82  xegpu::DistributeLayoutAttr storage = nullptr;
83 
84 public:
85  LayoutInfo() = default;
86  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
87 
88  // Two lattice values are equal if they have `some` layout. The actual
89  // content of the layout does not matter.
90  bool operator==(const LayoutInfo &other) const {
91  return this->isAssigned() == other.isAssigned();
92  }
93 
94  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
95 
96  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
97 
98  void print(raw_ostream &os) const;
99 
100  bool isAssigned() const { return storage != nullptr; }
101 
102  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
103 
104  SmallVector<int> getLaneLayout() const;
105 
106  SmallVector<int> getLaneData() const;
107 
108  SmallVector<int> getInstData() const;
109 
110  bool isSliceLayout() const {
111  if (!isAssigned())
112  return false;
113  return isa<xegpu::SliceAttr>(storage);
114  }
115 
116  int64_t getRank() const {
117  if (!isAssigned())
118  return -1;
119  return storage.getRank();
120  }
121 
122  Attribute get() { return storage; }
123 };
124 
125 SmallVector<int> LayoutInfo::getLaneLayout() const {
126  if (!isAssigned())
127  return {};
128  assert(storage.getEffectiveLaneLayoutAsInt().size() &&
129  "Expected lane layout to be assigned");
130  return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
131  [](int64_t val) { return static_cast<int>(val); });
132 }
133 
134 SmallVector<int> LayoutInfo::getLaneData() const {
135  if (!isAssigned())
136  return {};
137  assert(storage.getEffectiveLaneDataAsInt().size() &&
138  "Expected lane data to be assigned");
139  return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
140  [](int64_t val) { return static_cast<int>(val); });
141 }
142 
143 SmallVector<int> LayoutInfo::getInstData() const {
144  if (!isAssigned())
145  return {};
146  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
147  [](int64_t val) { return static_cast<int>(val); });
148 }
149 
150 void LayoutInfo::print(raw_ostream &os) const {
151  if (isAssigned()) {
152  os << storage;
153  } else {
154  os << "Not assigned.";
155  }
156 }
157 
158 LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
159  if (!lhs.isAssigned())
160  return rhs;
161  return lhs;
162 }
163 
164 /// Since this is a backward analysis, the join method is not used.
165 LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
166  llvm_unreachable("Join should not be triggered by layout propagation.");
167 }
168 
169 /// Construct a new layout with the inst data, lane layout, and lane data
/// permuted according to `permutation`.
170 LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
171  if (!isAssigned())
172  return {};
173  // Check if the permutation is valid.
174  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
175  bool hasDuplicates = seen.size() != permutation.size();
176  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
177  return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
178  });
179 
180  if (!withinRange || hasDuplicates) {
181  assert(false && "Invalid permutation for transpose.");
182  return {};
183  }
184 
185  SmallVector<int32_t> laneLayout;
186  SmallVector<int32_t> laneData;
187  SmallVector<int32_t> instData;
188  for (int64_t idx : permutation) {
189  laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
190  laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
191  instData.push_back(static_cast<int32_t>(getInstData()[idx]));
192  }
193  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
194  laneLayout, laneData));
195 }
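
// Illustrative example (not part of the original source): transposing a
// layout with inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]
// by the permutation [1, 0] yields inst_data = [16, 8], lane_layout = [16, 1],
// lane_data = [1, 1].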
196 
197 //===----------------------------------------------------------------------===//
198 // LayoutInfoLattice
199 //===----------------------------------------------------------------------===//
200 
201 /// Lattice holding the LayoutInfo for each value.
202 struct LayoutInfoLattice : public Lattice<LayoutInfo> {
204  using Lattice::Lattice;
205 };
206 
207 /// Helper Function to find the largest instruction multiple that divides the
208 /// user-supplied sg-level data shape `dim`. `candidates` are uArch-allowed
209 /// shapes; `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
210 template <typename T>
211 int getLargestDivisor(T dim, ArrayRef<T> candidates,
212  ArrayRef<T> candidateMultiples = {}) {
213  static_assert(std::is_integral<T>::value, "T must be an integer type");
214  int largest = -1;
215  SmallVector<T> multiples = {1};
216  if (!candidateMultiples.empty())
217  multiples =
218  SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
219  for (T candidate : candidates) {
220  for (T multiple : multiples) {
221  int value = static_cast<int>(candidate * multiple);
222  if (value != 0 && dim % value == 0 && value > largest)
223  largest = value;
224  }
225  }
226  return largest;
227 }
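
// Illustrative example with made-up values (not taken from any uArch table):
// for dim = 32, candidates = {8, 16} and candidateMultiples = {1, 2, 4}, the
// products considered are {8, 16, 32, 16, 32, 64}; the largest product that
// divides 32 is 32, so getLargestDivisor returns 32. For dim = 24 with the
// same inputs it returns 8, since the larger products do not divide 24.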
228 
229 /// Helper Functions to get default layouts. A `default layout` is a layout that
230 /// is assigned to a value when the layout is not fixed by some anchor operation
231 /// (like DPAS).
232 
233 /// Helper Function to get the default layout for uniform values like constants.
234 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
235 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
236 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
237  unsigned rank,
238  const xegpu::uArch::uArch *uArch,
239  ArrayRef<int> instData) {
240  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
241  if (rank == 1) {
242  return LayoutInfo(
243  xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
244  }
245  return LayoutInfo(xegpu::LayoutAttr::get(
246  ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
247 }
248 
249 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
250  unsigned rank, int subgroupSize) {
251  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
252  if (rank == 1) {
253  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
254  }
255  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
256 }
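
// Illustrative example (attribute syntax shown informally; the exact printed
// form is defined by the XeGPU dialect): with a subgroup size of 16, the
// default layout for a 1D value is
//   #xegpu.layout<lane_layout = [16], lane_data = [1]>
// and for a 2D value it is
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>.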
257 
258 /// Helper to get the default layout for a vector type.
259 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
260  const xegpu::uArch::uArch *uArch,
261  ArrayRef<int> instData,
262  unsigned packingSize,
263  bool isScattered = false) {
264  // Expecting a 1D or 2D vector.
265  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
266  "Expected 1D or 2D vector.");
267  // Expecting int or float element type.
268  assert(vectorTy.getElementType().isIntOrFloat() &&
269  "Expected int or float element type.");
270  // If the rank is 1, then return default layout for 1D vector.
271  if (vectorTy.getRank() == 1)
272  return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
273  // Packing factor is determined by the element type bitwidth.
274  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
275  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
276  if (isScattered) {
277  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
278  {uArch->getSubgroupSize(), 1},
279  {1, packingFactor}));
280  }
281  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
282  {1, uArch->getSubgroupSize()},
283  {1, packingFactor}));
284 }
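
// Illustrative example (assumed values): for vector<16x16xi8> with
// packingSize = 16, the packing factor is 16 / 8 = 2, so lane_data becomes
// [1, 2]. The block (non-scattered) case distributes columns across lanes
// with lane_layout [1, 16], while the scattered case distributes rows with
// lane_layout [16, 1].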
285 
286 /// Helper to get the default layout for a tensor descriptor type.
287 static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
288  const xegpu::uArch::uArch *uArch,
289  ArrayRef<int> instData,
290  unsigned packingSize,
291  bool isScattered = false) {
292  // Expecting a 1D or 2D tensor descriptor.
293  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
294  "Expected 1D or 2D TensorDesc.");
295  // Expecting int or float element type.
296  assert(tdescTy.getElementType().isIntOrFloat() &&
297  "Expected int or float element type.");
298  // If the rank is 1, then return default layout for 1D vector.
299  if (tdescTy.getRank() == 1)
300  return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
301  // Packing factor is determined by the element type bitwidth.
302  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
303  int subgroupSize = uArch->getSubgroupSize();
304  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
305  if (isScattered) {
306  return LayoutInfo(xegpu::LayoutAttr::get(
307  tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));
308  }
309 
310  return LayoutInfo(xegpu::LayoutAttr::get(
311  tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));
312 }
313 
314 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
315 /// is set according to the following criteria:
316 /// * For the A operand, the data must be packed in at least the default
317 /// packed size for the target (passed in as `packingSize`).
318 /// * For the B operand, the data must be packed in at least the DPAS-B
319 /// packed size and laid out in VNNI format.
320 static LayoutInfo
321 getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
322  const xegpu::uArch::uArch *uArch,
323  ArrayRef<int> instData, unsigned packingSize) {
324  Type elementTy = vectorTy.getElementType();
325  assert(elementTy.isIntOrFloat() &&
326  "Expected int or float type in DPAS operands");
328  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
329  // must have the VNNI format.
330  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
332  {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
333  1});
334  return LayoutInfo(
335  xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
336  }
337  // Otherwise, return the default layout for the vector type.
338  return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
339 }
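
// Illustrative example (assumed values): for an f16 B operand with
// packingSize = 32, lane_data becomes [32 / 16, 1] = [2, 1], i.e. each lane
// owns two consecutive rows of one column (VNNI packing) under
// lane_layout [1, 16]. An f32 B operand (bit width equal to packingSize)
// falls through to the default layout instead.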
340 
341 //===----------------------------------------------------------------------===//
342 // LayoutInfoPropagation
343 //===----------------------------------------------------------------------===//
344 
345 /// Backward data flow analysis to propagate the lane_layout and lane_data of
346 /// each value in the program. Currently, the layouts for the operands of DPAS,
347 /// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
348 /// of this analysis is to propagate those known layouts to all their producers
349 /// and (other) consumers.
350 class LayoutInfoPropagation
351  : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
352 private:
353  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
354  ArrayRef<const LayoutInfoLattice *> results);
355 
356  void visitStoreNdOp(xegpu::StoreNdOp store,
357  ArrayRef<LayoutInfoLattice *> operands,
358  ArrayRef<const LayoutInfoLattice *> results);
359 
360  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
361  ArrayRef<LayoutInfoLattice *> operands,
362  ArrayRef<const LayoutInfoLattice *> results);
363 
364  void visitLoadNdOp(xegpu::LoadNdOp load,
365  ArrayRef<LayoutInfoLattice *> operands,
366  ArrayRef<const LayoutInfoLattice *> results);
367 
368  void visitLoadGatherOp(xegpu::LoadGatherOp load,
369  ArrayRef<LayoutInfoLattice *> operands,
370  ArrayRef<const LayoutInfoLattice *> results);
371 
372  void visitTransposeOp(vector::TransposeOp transpose,
373  ArrayRef<LayoutInfoLattice *> operands,
374  ArrayRef<const LayoutInfoLattice *> results);
375 
376  void visitVectorBitcastOp(vector::BitCastOp bitcast,
377  ArrayRef<LayoutInfoLattice *> operands,
378  ArrayRef<const LayoutInfoLattice *> results);
379 
380  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
381  ArrayRef<LayoutInfoLattice *> operands,
382  ArrayRef<const LayoutInfoLattice *> results);
383 
384  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
385  ArrayRef<LayoutInfoLattice *> operands,
386  ArrayRef<const LayoutInfoLattice *> results);
387 
388  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
389  ArrayRef<LayoutInfoLattice *> operands,
390  ArrayRef<const LayoutInfoLattice *> results);
391 
392  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
393  ArrayRef<LayoutInfoLattice *> operands,
394  ArrayRef<const LayoutInfoLattice *> results);
395 
396  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
397  ArrayRef<LayoutInfoLattice *> operands,
398  ArrayRef<const LayoutInfoLattice *> results);
399  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
400  ArrayRef<LayoutInfoLattice *> operands,
401  ArrayRef<const LayoutInfoLattice *> results);
402 
403 public:
404  LayoutInfoPropagation(DataFlowSolver &solver,
405  SymbolTableCollection &symbolTable)
406  : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
408 
409  LogicalResult
410  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
411  ArrayRef<const LayoutInfoLattice *> results) override;
412 
413  void visitBranchOperand(OpOperand &operand) override {};
414 
415  void visitCallOperand(OpOperand &operand) override {};
416 
417  void visitExternalCall(CallOpInterface call,
418  ArrayRef<LayoutInfoLattice *> operands,
419  ArrayRef<const LayoutInfoLattice *> results) override {
420  };
421 
422  void setToExitState(LayoutInfoLattice *lattice) override {
423  (void)lattice->meet(LayoutInfo());
424  }
425 };
426 } // namespace
427 
428 LogicalResult LayoutInfoPropagation::visitOperation(
429  Operation *op, ArrayRef<LayoutInfoLattice *> operands,
430  ArrayRef<const LayoutInfoLattice *> results) {
431  TypeSwitch<Operation *>(op)
432  .Case<xegpu::DpasOp>(
433  [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
434  .Case<xegpu::StoreNdOp>(
435  [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
436  .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
437  visitStoreScatterOp(storeScatterOp, operands, results);
438  })
439  .Case<xegpu::LoadNdOp>(
440  [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
441  .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
442  visitLoadGatherOp(loadGatherOp, operands, results);
443  })
444  .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
445  visitCreateDescOp(createDescOp, operands, results);
446  })
447  .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
448  visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
449  })
450  .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
451  visitPrefetchNdOp(prefetchNdOp, operands, results);
452  })
453  .Case<vector::TransposeOp>([&](auto transposeOp) {
454  visitTransposeOp(transposeOp, operands, results);
455  })
456  .Case<vector::BitCastOp>([&](auto bitcastOp) {
457  visitVectorBitcastOp(bitcastOp, operands, results);
458  })
459  .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
460  visitVectorMultiReductionOp(reductionOp, operands, results);
461  })
462  .Case<vector::BroadcastOp>([&](auto broadcastOp) {
463  visitVectorBroadCastOp(broadcastOp, operands, results);
464  })
465  .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
466  visitShapeCastOp(shapeCastOp, operands, results);
467  })
468  // All other ops.
469  .Default([&](Operation *op) {
470  for (const LayoutInfoLattice *resultInfo : results) {
471  if (!resultInfo->getValue().isAssigned())
472  continue;
473  for (auto [operandInfo, operand] :
474  llvm::zip(operands, op->getOpOperands())) {
475  // If the operand type is not a vector or tensor descriptor, skip
476  // it.
477  if (!isa<xegpu::TensorDescType, VectorType>(
478  operand.get().getType()))
479  continue;
480  // Propagate the result layout to the operand.
481  meet(operandInfo, *resultInfo);
482  }
483  }
484  });
485 
486  return success();
487 }
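
// Illustrative note (not in the original source): the Default case above
// covers, e.g., elementwise ops such as arith.addf; an assigned layout on a
// result is propagated via meet to every vector or tensor descriptor
// operand, so producers inherit the consumer's layout unchanged.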
488 
489 void LayoutInfoPropagation::visitPrefetchNdOp(
490  xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
491  ArrayRef<const LayoutInfoLattice *> results) {
492  // Here we assign the default layout to the tensor descriptor operand of
493  // prefetch.
494  auto tdescTy = prefetch.getTensorDescType();
495 
496  auto uArch = getUArch(getChipStr(prefetch).value_or(""));
497  const auto *uArchInstruction =
498  dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
499  uArch->getInstruction(
500  xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
501 
502  auto blockWHC =
503  uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
504  if (!blockWHC)
505  prefetch.emitWarning("No known block params found for the element type.");
506  auto [bWidth, bHeight, bCount] = blockWHC.value();
507  SmallVector<int> instData;
508  int instWidth = getLargestDivisor(
509  static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
510  bCount);
511  if (instWidth == -1)
512  prefetch.emitWarning(
513  "No suitable instruction multiple found for the given shape.");
514  if (tdescTy.getRank() == 1)
515  instData = {instWidth};
516  else {
517  int instHeight = getLargestDivisor(
518  static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
519  if (instHeight == -1)
520  prefetch.emitWarning(
521  "No suitable instruction multiple found for the given shape.");
522  instData = {instHeight, instWidth};
523  }
524  auto prefetchLayout = getDefaultSIMTLayoutInfo(
525  tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
526  // Propagate the layout to the source tensor descriptor.
527  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
528 }
529 
530 void LayoutInfoPropagation::visitVectorMultiReductionOp(
531  vector::MultiDimReductionOp reduction,
532  ArrayRef<LayoutInfoLattice *> operands,
533  ArrayRef<const LayoutInfoLattice *> results) {
534  // The layout of the result must be present.
535  LayoutInfo resultLayout = results[0]->getValue();
536  if (!resultLayout.isAssigned())
537  return;
538  // We only consider 2D -> 1D reductions at this point.
539  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
540  if (!resultTy || resultTy.getRank() != 1) {
541  reduction.emitWarning("Expecting output type to be 1D vector.");
542  return;
543  }
544  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
545  // Given that the result is 1D, the layout of the operand should be 2D with
546  // default layout.
547  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
548  reduction->getContext(), 2, uArch->getSubgroupSize());
549  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
550  // Accumulator should have the same layout as the result.
551  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
552 }
553 
554 void LayoutInfoPropagation::visitVectorBroadCastOp(
555  vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
556  ArrayRef<const LayoutInfoLattice *> results) {
557  // The layout of the result must be present.
558  LayoutInfo resultLayout = results[0]->getValue();
559  if (!resultLayout.isAssigned())
560  return;
561  // Only consider vector to vector broadcasts for now.
562  VectorType resultTy = broadcast.getResultVectorType();
563  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
564  if (!sourceTy) {
565  broadcast.emitWarning("Expecting source type to be a vector type.");
566  return;
567  }
568 
569  // Only consider nD -> nD broadcast.
570  if (sourceTy.getRank() != resultTy.getRank()) {
571  broadcast.emitWarning("Expecting source and result to have same rank.");
572  return;
573  }
574  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
575  if (broadcastUnitDims.size() != 1) {
576  broadcast.emitWarning("Expecting source type to be nD vector only with "
577  "one broadcasted dimension.");
578  return;
579  }
580  // Propagate the result layout to the source operand.
581  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
582 }
583 
584 void LayoutInfoPropagation::visitShapeCastOp(
585  vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
586  ArrayRef<const LayoutInfoLattice *> results) {
587  // The layout of the result must be present.
588  LayoutInfo resultLayout = results[0]->getValue();
589  if (!resultLayout.isAssigned())
590  return;
591  VectorType sourceTy = shapeCast.getSourceVectorType();
592  VectorType resultTy = shapeCast.getResultVectorType();
593  // Shape cast layout propagation only supports 1D -> 2D shape casts.
594  // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
595  // unit dimensions and non-unit dims match.
596  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
597  shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
598  return;
599  }
600  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
601  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
602  shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
603  DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
604  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
605 }
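
// Illustrative example (slice attribute syntax shown informally): for a
// shape cast vector<16xf32> -> vector<1x16xf32> whose result layout is
// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, slicedDim is 0
// and the source operand receives
//   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>.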
606 
607 /// Propagate the layout of the result tensor to the source tensor descriptor
608 /// in UpdateNdOffsetOp.
609 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
610  xegpu::UpdateNdOffsetOp updateNdOffset,
611  ArrayRef<LayoutInfoLattice *> operands,
612  ArrayRef<const LayoutInfoLattice *> results) {
613  // The layout of the result must be present.
614  LayoutInfo resultLayout = results[0]->getValue();
615  if (!resultLayout.isAssigned())
616  return;
617  // Propagate the layout to the source operand.
618  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
619 }
620 
621 /// Set the layouts for DPAS A, B, and C operands.
622 void LayoutInfoPropagation::visitDpasOp(
623  xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
624  ArrayRef<const LayoutInfoLattice *> results) {
625  VectorType aTy = dpas.getLhsType();
626  VectorType bTy = dpas.getRhsType();
627 
628  auto uArch = getUArch(getChipStr(dpas).value_or(""));
629  const int subgroupSize = uArch->getSubgroupSize();
630  const auto *uArchInstruction =
631  dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
632  xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
633 
634  const unsigned dataALen = aTy.getShape().front();
635  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
636  const int maxALen =
637  getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
638  if (maxALen == -1)
639  dpas.emitWarning(
640  "No suitable instruction multiple found for the given shape.");
641 
642  const unsigned dataBLen = bTy.getShape().back();
643  auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
644  const int maxBLen =
645  getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
646  if (maxBLen == -1)
647  dpas.emitWarning(
648  "No suitable instruction multiple found for the given shape.");
649  SmallVector<int> instDataA = {maxALen, subgroupSize};
650  SmallVector<int> instDataB = {subgroupSize, maxBLen};
651 
652  propagateIfChanged(operands[0],
653  operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
654  aTy, 0, uArch, instDataA,
655  uArchInstruction->getPackedFormatBitSizeA())));
656  propagateIfChanged(operands[1],
657  operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
658  bTy, 1, uArch, instDataB,
659  uArchInstruction->getPackedFormatBitSizeB())));
660  if (operands.size() > 2) {
661  VectorType cTy = dpas.getAccType();
662  const unsigned dataCLen = bTy.getShape().back();
663  auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
664  const int maxCLen =
665  getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
666  if (maxCLen == -1)
667  dpas.emitWarning(
668  "No suitable instruction multiple found for the given shape.");
669  SmallVector<int> instDataC = {maxALen, maxCLen};
670  propagateIfChanged(operands[2],
671  operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
672  cTy, 2, uArch, instDataC,
673  uArchInstruction->getPackedFormatBitSizeB())));
674  }
675 }
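
// Illustrative example (assumed uArch parameters): with a subgroup size of
// 16 and supported M/K/N sizes that include 8 and 16, a DPAS computing
// 8x16 (A, f16) times 16x16 (B, f16) gets instDataA = {8, 16},
// instDataB = {16, 16}, and, if an accumulator is present,
// instDataC = {8, 16}.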
676 
677 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
678 void LayoutInfoPropagation::visitStoreNdOp(
679  xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
680  ArrayRef<const LayoutInfoLattice *> results) {
681 
682  auto uArch = getUArch(getChipStr(store).value_or(""));
683  const auto *uArchInstruction =
684  dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
685  uArch->getInstruction(
686  xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
687  VectorType dataTy = store.getValueType();
688  auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
689  store.getValueType().getElementType());
690  if (!blockWHC)
691  store.emitWarning("No known block params found for the element type.");
692  auto [bWidth, bHeight, bCount] = blockWHC.value();
693  SmallVector<int> instData;
694  int instWidth = getLargestDivisor(
695  static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
696  bCount);
697  if (instWidth == -1)
698  store.emitWarning(
699  "No suitable instruction multiple found for the given shape.");
700  if (dataTy.getRank() == 1)
701  instData = {instWidth};
702  else {
703  int instHeight = getLargestDivisor(
704  static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
705  if (instHeight == -1)
706  store.emitWarning(
707  "No suitable instruction multiple found for the given shape.");
708  instData = {instHeight, instWidth};
709  }
710  LayoutInfo storeLayout =
711  getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
712  uArchInstruction->getPackedFormatBitSize());
713  // Both operands should have the same layout
714  for (LayoutInfoLattice *operand : operands)
715  propagateIfChanged(operand, operand->meet(storeLayout));
716 }
717 
718 /// Propagate the layout of the value to the tensor descriptor operand in
719 /// LoadNdOp.
720 void LayoutInfoPropagation::visitLoadNdOp(
721  xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
722  ArrayRef<const LayoutInfoLattice *> results) {
723  LayoutInfo valueLayout = results[0]->getValue();
724  // Need the layout of the value to propagate to the tensor descriptor.
725  if (!valueLayout.isAssigned())
726  return;
727  LayoutInfo tensorDescLayout = valueLayout;
728  // LoadNdOp has the transpose effect. However, at the stage of this analysis
729  // this effect is not expected and should be abstracted away. Emit a
730  // warning.
731  if (auto transpose = load.getTranspose()) {
732  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
733  "LayoutInfoPropagation stage.");
734  tensorDescLayout = valueLayout.transpose(transpose.value());
735  }
736  // Propagate the new layout to the tensor descriptor operand.
737  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
738 }
739 
740 /// For vector::TransposeOp, the layout of the result is transposed and
741 /// propagated to the operand.
742 void LayoutInfoPropagation::visitTransposeOp(
743  vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
744  ArrayRef<const LayoutInfoLattice *> results) {
745  // Need the layout of transpose result to propagate to the operands.
746  LayoutInfo resultLayout = results[0]->getValue();
747  if (!resultLayout.isAssigned())
748  return;
749  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
750  // Propagate the new layout to the vector operand.
751  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
752 }
753 
754 /// For vector::BitCastOp, the lane_data of the source layout is changed based
755 /// on the bit width of the source and result types.
756 void LayoutInfoPropagation::visitVectorBitcastOp(
757  vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
758  ArrayRef<const LayoutInfoLattice *> results) {
759  // Need the layout of bitcast result to propagate to the operands.
760  LayoutInfo resultLayout = results[0]->getValue();
761  if (!resultLayout.isAssigned())
762  return;
763  int inElemTyBitWidth =
764  bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
765  int outElemTyBitWidth =
766  bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
767  // If the element bit widths are the same, then the layout does not change.
768  if (inElemTyBitWidth == outElemTyBitWidth) {
769  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
770  return;
771  }
772  // Check if the result layout is valid. i.e. result vector can be distributed.
773  auto resultLaneLayout = resultLayout.getLaneLayout();
774  auto resultLaneData = resultLayout.getLaneData();
775  if (failed(xegpu::getDistributedVectorType(
776  bitcast.getResultVectorType(),
777  xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
778  resultLaneData)))) {
779  bitcast.emitWarning(
780  "Result vector type can not be evenly distributed across lanes.");
781  return;
782  }
783  int64_t rank = bitcast.getSourceVectorType().getRank();
784  // Bitcast is `narrowing` if the input element type bit width is larger than
785  // the output element type bit width, e.g., f32 -> f16 is a narrowing bitcast.
786  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
787  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
788  : outElemTyBitWidth / inElemTyBitWidth;
789  SmallVector<int> sourceLaneLayout =
790  resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
791  SmallVector<int> outData = resultLayout.getLaneData();
792 
793  // TODO: Currently we assume that a bitcast does not require cross-lane
794  // communication, so each lane must own the required number of elements to
795  // perform the bitcast locally.
796  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
797  if (outInnerBitsPerLane < inElemTyBitWidth) {
798  bitcast.emitWarning(
799  "Narrowing bitcast with cross lane communication is not supported.");
800  return;
801  }
802  // Check if each lane owns a single element in all dimensions except the
803  // innermost dimension.
804  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
805  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
806  bitcast.emitWarning("Each lane must not own multiple elements in any "
807  "dimension other than "
808  "the innermost dimension.");
809  return;
810  }
811  // Decide lane data based on whether the bitcast is narrowing or widening.
812  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
813  : outData[rank - 1] * bitCastRatio;
814  sourceLaneData.push_back(innerMostLaneData);
815 
816  propagateIfChanged(
817  operands[0],
818  operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
819  bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
820 }
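
// Illustrative example: for a widening bitcast vector<8x32xi16> ->
// vector<8x16xi32> with result lane_data [1, 1], the ratio is 2 and the
// source lane_data becomes [1, 2] (each lane owns the two i16 values that
// form one i32). The narrowing direction divides the innermost lane_data by
// the ratio instead.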
821 
822 /// Propagate the layout of the result to the tensor descriptor, mask and offset
823 /// operands in LoadGatherOp.
824 void LayoutInfoPropagation::visitLoadGatherOp(
825  xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
826  ArrayRef<const LayoutInfoLattice *> results) {
827  // The layout is strictly determined by the payload type.
828  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
829  if (!payloadTy) {
830  load.emitWarning("Not propagating, non-vector payload supplied.");
831  return;
832  }
833  auto uArch = getUArch(getChipStr(load).value_or(""));
834  const int subgroupSize = uArch->getSubgroupSize();
835  SmallVector<int> instData{subgroupSize};
836  if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
837  instData.push_back(chunkSize);
838  else if (auto srcTdescTy =
839  dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
840  if (srcTdescTy.getChunkSizeAsInt() > 1)
841  instData.push_back(srcTdescTy.getChunkSizeAsInt());
842  }
843  LayoutInfo layout = getDefaultSIMTLayoutInfo(
844  payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
845  /*scattered*/ true);
846 
847  // Mask operand should have 1D default layout.
848  LayoutInfo maskLayout =
849  getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
850 
851  // Propagate the new layout to the tensor descriptor operand.
852  if (isa<xegpu::TensorDescType>(load.getSourceType()))
853  propagateIfChanged(operands[0], operands[0]->meet(layout));
854  // Propagate the new layout to the mask and optional offset operand.
855  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
856  if (load.getOffsets())
857  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
858 }
859 
860 /// Propagate the layout of the descriptor to the vector offset operand in
861 /// CreateDescOp.
862 void LayoutInfoPropagation::visitCreateDescOp(
863  xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
864  ArrayRef<const LayoutInfoLattice *> results) {
865  LayoutInfo descLayout = results[0]->getValue();
866  // Need the layout of the descriptor to propagate to the operands.
867  if (!descLayout.isAssigned())
868  return;
869  auto uArch = getUArch(getChipStr(createDesc).value_or(""));
870  // For offset operand propagate 1D default layout.
871  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
872  uArch->getSubgroupSize());
873  propagateIfChanged(operands[1], operands[1]->meet(layout));
874 }
875 
876 /// Set the layout for the value, tensor descriptor, offset and mask operands in
877 /// the StoreScatterOp.
878 void LayoutInfoPropagation::visitStoreScatterOp(
879  xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
880  ArrayRef<const LayoutInfoLattice *> results) {
881  // Currently, for 2D StoreScatterOp we expect that the height dimension of
882  // the tensor descriptor is equal to the subgroup size. This is ensured by
883  // the op verifier.
884  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
885  if (!payloadTy) {
886  storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
887  return;
888  }
889  auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
890  const int subgroupSize = uArch->getSubgroupSize();
891 
892  auto payloadShape = payloadTy.getShape();
893  if (payloadShape.size() > 1)
894  assert(
895  payloadShape[0] == subgroupSize &&
896  "Expected the first dimension of 2D tensor descriptor to be equal to "
897  "subgroup size.");
898 
899  SmallVector<int> instData{subgroupSize};
900  if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
901  instData.push_back(chunkSize);
902  else if (auto dstTdescTy =
903  dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
904  if (dstTdescTy.getChunkSizeAsInt() > 1)
905  instData.push_back(dstTdescTy.getChunkSizeAsInt());
906  }
907  LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo(
908  payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
909  /*scattered=*/true);
910 
911  LayoutInfo maskLayout =
912  getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
913  // Propagate the payload operand layout
914  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
915  // Propagate the destination (if tdesc) operand layout
916  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
917  propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
918  // Propagate the new layout to the mask and optional offset operand.
919  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
920  if (storeScatter.getOffsets())
921  propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
922 }
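
// Illustrative example (assumed values): a scattered store of
// vector<16x8xf16> with chunk size 8 and subgroup size 16 gets
// instData = {16, 8} and a payload layout with lane_layout [16, 1] and
// lane_data [1, 2] (assuming a 32-bit general packed format), while the
// mask and offsets keep the 1D default layout lane_layout [16],
// lane_data [1].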
923 
924 namespace {
925 //===----------------------------------------------------------------------===//
926 // RunLayoutInfoPropagation
927 //===----------------------------------------------------------------------===//
928 
929 /// Driver class for running the LayoutInfoPropagation analysis.
930 class RunLayoutInfoPropagation {
931 public:
932  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
933 
934  RunLayoutInfoPropagation(Operation *op) : target(op) {
935  SymbolTableCollection symbolTable;
936  loadBaselineAnalyses(solver);
937  solver.load<LayoutInfoPropagation>(symbolTable);
938  (void)solver.initializeAndRun(op);
939  }
940 
941  LayoutInfo getLayoutInfo(Value val);
942 
943  void printAnalysisResult(llvm::raw_ostream &os);
944 
945 private:
946  DataFlowSolver solver;
947  const Operation *target;
948 };
949 } // namespace
950 
951 LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
952  auto *state = solver.lookupState<LayoutInfoLattice>(val);
953  if (!state)
954  return {};
955  return state->getValue();
956 }
957 
958 // Print the analysis result for debugging purposes.
959 void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
960  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
961  os << "function: " << funcOp.getName() << ":\n";
962  // Function arguments
963  for (BlockArgument arg : funcOp.getArguments()) {
964  LayoutInfo layout = getLayoutInfo(arg);
965  os << "argument: " << arg << "\n";
966  os << "layout : ";
967  layout.print(os);
968  os << "\n";
969  }
970  // Function ops
971  funcOp.walk([&](Operation *op) {
972  // Skip ops that do not have results
973  if (op->getResults().empty())
974  return;
975  os << "op : ";
976  // For control-flow ops, print the op name only.
977  if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
978  os << op->getName();
979  else
980  op->print(os);
981  os << "\n";
982  // Print the layout for each result.
983  for (auto [i, r] : llvm::enumerate(op->getResults())) {
984  LayoutInfo layout = getLayoutInfo(r);
985  os << "layout for result #" << i << ": ";
986  layout.print(os);
987  os << "\n";
988  }
989  });
990  };
991 
993  if (auto modOp = dyn_cast<ModuleOp>(target)) {
994  for (auto funcOp : modOp.getOps<FunctionOpInterface>())
995  funcOps.push_back(funcOp);
996 
997  // Collect all GpuFuncOps in the module.
998  for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
999  for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1000  funcOps.push_back(gpuFuncOp);
1001  }
1002  }
1003  // Print the analysis result for each function.
1004  for (FunctionOpInterface funcOp : funcOps)
1005  printFunctionResult(funcOp);
1006 }
1007 
1008 using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1009 /// Update an operation with the layout of its results. If the result type is
1010 /// a vector type, a temporary layout attribute is added to the operation. If
1011 /// the result type is a tensor descriptor type, the type is updated with the
1012 /// layout attribute. The users of the result are also updated with the layout
1013 /// attribute.
1014 static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1015  GetLayoutFnTy getLayoutOfValue) {
1016  // Region ops (like scf.for) are already handled by the
1017  // updateControlFlowOps.
1018  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1019  return success();
1020 
1021  // Iterate over all the results.
1022  for (OpResult result : op->getResults()) {
1023  Type resultType = result.getType();
1024  // Layouts are needed only for vector and tensor descriptor types.
1025  if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1026  continue;
1027  // If the result has no layout but has users, emit a warning and continue.
1028  xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1029  if (!layout && result.getNumUses() > 0) {
1030  op->emitWarning("op has users but no layout assigned for its result");
1031  continue;
1032  }
1033  // If the result is a tensor descriptor type, update the tensor desc type
1034  // with layout.
1035  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1036  auto typeWithLayout = xegpu::TensorDescType::get(
1037  tensorDescTy.getContext(), tensorDescTy.getShape(),
1038  tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1039  result.setType(typeWithLayout);
1040  continue;
1041  }
1042  // If the result is a vector type, add a temporary layout attribute to the
1043  // op.
1044  xegpu::setDistributeLayoutAttr(result, layout);
1045  }
1046  return success();
1047 }
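
// Illustrative note (the attribute name is an implementation detail of
// xegpu::setDistributeLayoutAttr and may differ): a vector-typed result
// typically ends up carrying a temporary named attribute, e.g.
//   %0 = ... {layout_result_0 = #xegpu.layout<lane_layout = [1, 16],
//                                             lane_data = [1, 1]>}
// whereas a tensor descriptor result gets the layout folded into its type.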
1048 
1049 /// Region ops like scf.for need special handling because they have blocks
1050 /// inside. If the blocks have tensor descriptor type as block arguments,
1051 /// their types must be updated. Also, a region op can have results that may
1052 /// not have any users (e.g. A and B tiles). They are not assigned a layout by
1053 /// the layout analysis because they have no users. However, inside the region
1054 /// op the corresponding block arguments for these results do have layouts.
1055 /// Therefore, in this case we still need to update the result types with the
1056 /// layout attribute. This function updates the internal block arguments and
1057 /// the result types of the region op with the assigned layouts.
1058 /// clang-format off
1059 /// Example: scf.for ... iter_args(...) -> (out types) {
1060 /// ^bb0(block types):
1061 /// ...
1062 /// scf.yield ... : (yield types)
1063 /// }
1064 /// clang-format on
1065 /// In this example, at scf.yield, control-flow can transfer to two successor
1066 /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1067 /// itself (yield the results). So we update both the block arguments of the
1068 /// successor region (i.e. block types) and the result types of the scf.for op
1069 /// (i.e. out types). Note that yield types are updated by respective
1070 /// producers inside bb0.
1071 static LogicalResult
1072 updateControlFlowOps(mlir::OpBuilder &builder,
1073  mlir::RegionBranchTerminatorOpInterface terminator,
1074  GetLayoutFnTy getLayoutOfValue) {
1075  // Only process if the terminator is inside a region branch op.
1076  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
1077  return success();
1078 
1079  llvm::SmallVector<mlir::RegionSuccessor> successors;
1080  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
1081  nullptr);
1082  terminator.getSuccessorRegions(operands, successors);
1083 
1084  for (mlir::RegionSuccessor &successor : successors) {
1085  mlir::OperandRange successorOperands =
1086  terminator.getSuccessorOperands(successor);
1087  mlir::ValueRange successorInputs = successor.getSuccessorInputs();
1088  for (auto [successorOperand, successorInput] :
1089  llvm::zip(successorOperands, successorInputs)) {
1090  Type inputType = successorInput.getType();
1091  // We only need to operate on tensor descriptor or vector types.
1092  if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1093  continue;
1094  xegpu::DistributeLayoutAttr successorInputLayout =
1095  getLayoutOfValue(successorInput);
1096  xegpu::DistributeLayoutAttr successorOperandLayout =
1097  getLayoutOfValue(successorOperand);
1098 
1099  // If either of the layouts is not assigned, we cannot proceed.
1100  if (!successorOperandLayout) {
1101  LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1102  "branch terminator: "
1103  << successorOperand << "\n");
1104  return failure();
1105  }
1106  // We expect the layouts to match.
1107  if (successorInputLayout &&
1108  successorInputLayout != successorOperandLayout) {
1109  LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1110  "operand forwarded as the argument: "
1111  << successorInputLayout << " vs "
1112  << successorOperandLayout << "\n");
1113  return failure();
1114  }
1115  // Get tensor descriptor type with the layout.
1116  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1117  auto newTdescTy = xegpu::TensorDescType::get(
1118  tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1119  tdescTy.getEncoding(), successorOperandLayout);
1120  successorInput.setType(newTdescTy);
1121  continue;
1122  }
1123  // If the type is a vector type and this region argument is an OpResult,
1124  // set the layout attribute on the OpResult.
1125  if (auto result = dyn_cast<OpResult>(successorInput))
1126  xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1127  }
1128  }
1129  return success();
1130 }
1131 
1132 /// Update the function arguments and results with the layouts.
1133 static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1134  mlir::FunctionOpInterface funcOp,
1135  GetLayoutFnTy getLayoutOfValue) {
1136  SmallVector<Type> newArgTypes;
1137  // Update the function arguments.
1138  for (BlockArgument arg : funcOp.getArguments()) {
1139  Type argType = arg.getType();
1140  newArgTypes.push_back(argType);
1141  if (!isa<VectorType, xegpu::TensorDescType>(argType))
1142  continue;
1143  xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1144  if (!layout) {
1145  LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1146  << " but got none.\n");
1147  return failure();
1148  }
1149  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1150  auto newTdescTy = xegpu::TensorDescType::get(
1151  tensorDescTy.getContext(), tensorDescTy.getShape(),
1152  tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1153  arg.setType(newTdescTy);
1154  newArgTypes.back() = newTdescTy;
1155  }
1156  }
1157  // Update the function type with the new argument types.
1158  // NOTE: We assume that function results are not expected to have layouts.
1159  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1160  funcOp.getResultTypes()));
1161  return success();
1162 }
1163 
1164 namespace {
1165 struct XeGPUPropagateLayoutPass final
1166  : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1167  XeGPUPropagateLayoutPass() = default;
1168  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1169  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1170  : XeGPUPropagateLayoutBase(options) {}
1171  void runOnOperation() override;
1172 };
1173 
1174 } // namespace
1175 
1176 void XeGPUPropagateLayoutPass::runOnOperation() {
1177  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
1178  // Print the analysis result and exit. (for debugging purposes)
1179  if (printOnly) {
1180  auto &os = llvm::outs();
1181  analysis.printAnalysisResult(os);
1182  return;
1183  }
1184  // Helper to convert LayoutInfo to an xegpu::DistributeLayoutAttr (LayoutAttr or SliceAttr).
1185  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1186  LayoutInfo layout = analysis.getLayoutInfo(val);
1187  if (!layout.isAssigned())
1188  return {};
1189  xegpu::DistributeLayoutAttr layoutAttr =
1190  cast<xegpu::DistributeLayoutAttr>(layout.get());
1191  if (this->layoutKind == "lane")
1192  layoutAttr = layoutAttr.dropInstData();
1193  if (layout.isSliceLayout())
1194  return cast<xegpu::SliceAttr>(layoutAttr);
1195  return cast<xegpu::LayoutAttr>(layoutAttr);
1196  };
1197 
1198  mlir::OpBuilder builder(&getContext());
1199  Operation *op = getOperation();
1200  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1201  for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1202  LogicalResult r = success();
1203  TypeSwitch<Operation *>(&op)
1204  .Case<mlir::RegionBranchTerminatorOpInterface>(
1205  [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1206  r = updateControlFlowOps(builder, branchTermOp,
1207  getXeGPULayoutForValue);
1208  })
1209  .Case<mlir::FunctionOpInterface>(
1210  [&](mlir::FunctionOpInterface funcOp) {
1211  r = updateFunctionOpInterface(builder, funcOp,
1212  getXeGPULayoutForValue);
1213  })
1214  .Default([&](Operation *op) {
1215  r = updateOp(builder, op, getXeGPULayoutForValue);
1216  });
1217  if (failed(r)) {
1218  op.emitError("Failed to update operation with the layout.");
1219  return WalkResult::interrupt();
1220  }
1221  }
1222  return WalkResult::advance();
1223  });
1224  if (walkResult.wasInterrupted()) {
1225  signalPassFailure();
1226  return;
1227  }
1228 }