//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

namespace {

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Helper class for tracking the analysis state of an MLIR value. For layout
/// propagation, the analysis state is simply the distribution layout of each
/// value. The distribution layout information is encapsulated in the
/// xegpu::DistributeLayoutAttr class, which can hold information about any
/// type of distribution layout that the XeGPU dialect supports. The purpose of
/// this analysis is to propagate a unique distribution layout to each value in
/// the program, starting from a set of anchor operations (like DPAS, StoreNd,
/// etc.). Note that the analysis reaches a fixed point once every value has
/// been assigned some layout; the analysis does not try to modify any layout
/// that has already been assigned.
///
/// Given this, LayoutInfo satisfies the following properties:
/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
///    assigned`.
/// 2) Two LayoutInfo values are equal if they are both assigned or both not
///    assigned. The concrete value of the assigned state does not matter.
/// 3) The meet operator works as follows:
///    - If the current state is assigned, return the current state (a unique
///      layout is already assigned; do not change it).
///    - Otherwise, return the other state.

struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const { return storage != nullptr; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  SmallVector<int> getLaneLayout() const;

  SmallVector<int> getLaneData() const;

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    if (!isAssigned())
      return -1;
    return storage.getRank();
  }

  Attribute get() { return storage; }
};

SmallVector<int> LayoutInfo::getLaneLayout() const {
  if (!isAssigned())
    return {};
  assert(storage.getEffectiveLaneLayoutAsInt().size() &&
         "Expected lane layout to be assigned");
  return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getLaneData() const {
  if (!isAssigned())
    return {};
  assert(storage.getEffectiveLaneDataAsInt().size() &&
         "Expected lane data to be assigned");
  return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << storage;
  } else {
    os << "Not assigned.";
  }
}

LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
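
// For illustration: with this meet, a layout assigned by an anchor operation
// is sticky. If `lhs` already holds, say, #xegpu.layout<lane_layout = [1, 16],
// lane_data = [1, 1]> and `rhs` holds any other assigned layout, the meet
// returns `lhs` unchanged; only an unassigned `lhs` adopts `rhs`.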

/// Since this is a backward analysis, the join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Construct a new layout with the transposed lane layout and lane data.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  // Check if the permutation is valid.
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return {};
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  for (int64_t idx : permutation) {
    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
  }
  return LayoutInfo(
      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
}
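
// For illustration: transposing a layout with lane_layout = [1, 16] and
// lane_data = [1, 1] by the permutation {1, 0} yields lane_layout = [16, 1]
// and lane_data = [1, 1]. This is how a transposed result layout is mapped
// back onto the operand of a vector.transpose.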

//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};

/// Helper functions to get default layouts. A `default layout` is a layout
/// that is assigned to a value when the layout is not fixed by some anchor
/// operation (like DPAS).

/// Helper function to get the default layout for uniform values like
/// constants. For a 1D vector, lane_layout is [subgroupSize] and lane_data is
/// [1]. For a 2D vector, lane_layout is [1, subgroupSize] and lane_data is
/// [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1) {
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
  }
  return LayoutInfo(xegpu::LayoutAttr::get(
      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
}
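
// Example (values are illustrative and depend on XeGPUTargetInfo.h; assuming
// subgroupSize = 16): rank 1 yields
//   #xegpu.layout<lane_layout = [16], lane_data = [1]>
// and rank 2 yields
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>.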

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
  // Packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (isScattered) {
    packingFactor =
        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
            : 1;
    return LayoutInfo(xegpu::LayoutAttr::get(
        vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
        {1, packingFactor}));
  }
  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, xegpu::targetinfo::subgroupSize},
                                           {1, packingFactor}));
}
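
// Example (illustrative, assuming subgroupSize = 16 and
// packedSizeInBitsForDefault = 16): a vector<8x16xf32> or vector<8x16xf16>
// gets lane_layout = [1, 16] with lane_data = [1, 1], while an 8-bit element
// type gets lane_data = [1, 2] so that each lane still packs 16 bits.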

/// Helper to get the default layout for a tensor descriptor type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D tensor descriptor.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return the default layout for a 1D vector.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();

  if (isScattered) {
    int packingFactor =
        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
            : 1;
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
        {1, packingFactor}));
  }

  int packingFactor =
      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
          : 1;
  return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(),
                                           {1, xegpu::targetinfo::subgroupSize},
                                           {1, packingFactor}));
}
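
// Example for the scattered case (illustrative, assuming subgroupSize = 16
// and packedSizeInBitsForGatherScatter = 32): a scattered tensor_desc with
// 16-bit elements gets lane_layout = [16, 1] and lane_data = [1, 2], i.e.
// each lane owns two consecutive 16-bit elements forming one 32-bit chunk.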

/// Helper function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For the A operand, the data must be packed in minimum
///   `packedSizeInBitsForDefault`.
/// * For the B operand, the data must be packed in minimum
///   `packedSizeInBitsForDpasB`.
static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
                                                  unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize});
  // For the B operand, data must be packed in minimum
  // `packedSizeInBitsForDpasB` and must have the VNNI format.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
                             xegpu::targetinfo::packedSizeInBitsForDpasB) {
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB /
                              elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultSIMTLayoutInfo(vectorTy);
}
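
// Example (illustrative, assuming subgroupSize = 16 and
// packedSizeInBitsForDpasB = 32): for a B operand of type vector<16x16xf16>,
// 16 < 32 holds, so lane_data becomes [2, 1]; each lane owns two vertically
// adjacent f16 elements, which is the VNNI packing DPAS expects for B. The A
// and C operands fall through to the default layout.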

//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for the operands of DPAS,
/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
/// of this analysis is to propagate those known layouts to all their producers
/// and (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace

LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand type is not a vector or tensor descriptor, skip
            // it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });

  return success();
}

void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Here we assign the default layout to the tensor descriptor operand of
  // prefetch.
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}

void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  // Given that the result is 1D, the layout of the operand should be 2D with
  // the default layout.
  LayoutInfo operandLayout =
      getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
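
// Example (illustrative): for
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<16x32xf32> to vector<16xf32>
// the 1D result and accumulator keep the layout already propagated to them,
// while %src is assigned the default 2D layout (lane_layout =
// [1, subgroupSize], lane_data = [1, 1]).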

void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only consider vector-to-vector broadcasts for now.
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy) {
    broadcast.emitWarning("Expecting source type to be a vector type.");
    return;
  }

  // Only consider nD -> nD broadcasts.
  if (sourceTy.getRank() != resultTy.getRank()) {
    broadcast.emitWarning("Expecting source and result to have same rank.");
    return;
  }
  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
  if (broadcastUnitDims.size() != 1) {
    broadcast.emitWarning("Expecting source type to be nD vector only with "
                          "one broadcasted dimension.");
    return;
  }
  // Propagate the result layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Shape cast layout propagation only supports 1D -> 2D shape casts.
  // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
  // unit dimensions and non-unit dims match.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
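
// Example (illustrative; the slice attribute syntax is abbreviated): for
//   vector.shape_cast %v : vector<16xf32> to vector<1x16xf32>
// with a result layout of lane_layout = [1, 16], lane_data = [1, 1], the unit
// dimension is dim 0, so the 1D source receives that layout sliced along
// dims = [0].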

/// Propagate the layout of the result tensor to the source tensor descriptor
/// in UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(
      operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(
      operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(
        operands[2],
        operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
  }
}

/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}

/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has the transpose effect. However, at the stage of this analysis
  // this effect is not expected and should be abstracted away. Emit a
  // warning.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.transpose(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}

/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}

/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit widths of the source and result types.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bit widths are the same, then the layout does not change.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  // Check if the result layout is valid, i.e. the result vector can be
  // distributed.
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  // A bitcast is `narrowing` if the input element type bit width is larger
  // than the output element type bit width, e.g. f32 -> f16 is a narrowing
  // bitcast.
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout =
      resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
  SmallVector<int> outData = resultLayout.getLaneData();

  // TODO: Currently we assume that bitcasts do not require cross-lane
  // communication. So each lane must own the required number of elements to
  // perform the bitcast locally without cross-lane communication.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // Check if each lane owns a single element in all dimensions except the
  // innermost dimension.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than the innermost dimension.");
    return;
  }
  // Decide lane data based on whether the bitcast is narrowing or widening.
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
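
// Example (illustrative): for a widening bitcast vector<1x32xf16> ->
// vector<1x16xf32> whose result layout is lane_layout = [1, 16], lane_data =
// [1, 1], the ratio is 2, so the source gets lane_data = [1, 2]: each lane
// owns the two adjacent f16 values that fuse into one f32. A narrowing
// bitcast divides the innermost lane_data by the ratio instead.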

/// Propagate the layout of the result to the tensor descriptor, mask and
/// offset operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout is strictly determined by the payload type.
  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
  if (!payloadTy) {
    load.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*isScattered=*/true);

  // The mask operand should have the default 1D layout.
  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);

  // Propagate the new layout to the tensor descriptor operand.
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(layout));
  // Propagate the new layout to the mask and optional offset operands.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}

/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // For the offset operand, propagate the default 1D layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, offset and mask operands
/// in StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Currently, for 2D StoreScatterOp we expect that the height dimension of
  // the tensor descriptor is equal to the subgroup size. This is ensured by
  // the op verifier.
  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
  if (!payloadTy) {
    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  auto payloadShape = payloadTy.getShape();
  if (payloadShape.size() > 1)
    assert(
        payloadShape[0] == xegpu::targetinfo::subgroupSize &&
        "Expected the first dimension of 2D tensor descriptor to be equal to "
        "subgroup size.");

  LayoutInfo payloadLayout =
      getDefaultSIMTLayoutInfo(payloadTy, /*isScattered=*/true);

  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
  // Propagate the payload operand layout.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the destination (if tdesc) operand layout.
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // Propagate the new layout to the mask and optional offset operands.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}

namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}

// Print the analysis result for debugging purposes.
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results
      if (op->getResults().empty())
        return;
      os << "op : ";
      // For control-flow ops, print the op name only.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);

    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}

using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
/// the result type is a tensor descriptor type, the type is updated with the
/// layout attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by
  // updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with the layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to the
    // op.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
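
// Example of the update (illustrative): a result of type
// !xegpu.tensor_desc<8x16xf32> is rewritten in place to
// !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16],
// lane_data = [1, 1]>>, while a vector-typed result keeps its type and the
// defining op instead carries the layout as a temporary attribute.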

/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// their types must be updated. Also, a region op can have results that may
/// not have any users (e.g. A and B tiles). They are not assigned a layout by
/// the layout analysis because they have no users. However, inside the region
/// op the corresponding block arguments for these results do have layouts.
/// Therefore, in this case we still need to update the result types with the
/// layout attribute. This function updates the internal block arguments and
/// the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
/// ^bb0(block types):
/// ...
/// scf.yield ... : (yield types)
/// }
/// clang-format on
/// In this example, at scf.yield, control flow can transfer to two successor
/// regions. One is ^bb0 (the loop body) and the other is the scf.for op
/// itself (yielding the results). So we update both the block arguments of
/// the successor region (i.e. block types) and the result types of the
/// scf.for op (i.e. out types). Note that yield types are updated by the
/// respective producers inside bb0.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only process if the terminator is inside a region branch op.
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();

  llvm::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
                                              nullptr);
  terminator.getSuccessorRegions(operands, successors);

  for (mlir::RegionSuccessor &successor : successors) {
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);

      // If the operand layout is not assigned, we cannot proceed.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
        return failure();
      }
      // We expect the layouts to match.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // Get the tensor descriptor type with the layout.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // If the type is a vector type and this region argument is an OpResult,
      // set the layout attribute on the OpResult.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}

/// Update the function arguments and results with the layouts.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
                                               mlir::FunctionOpInterface funcOp,
                                               GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  // NOTE: We assume that function results are not expected to have layouts.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}

namespace {
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};

} // namespace

void XeGPUPropagateLayoutPass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and exit (for debugging purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert LayoutInfo to xegpu::DistributeLayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layout.get());
    return cast<xegpu::LayoutAttr>(layout.get());
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}