//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

namespace {

//===----------------------------------------------------------------------===//
// Layout
//===----------------------------------------------------------------------===//

/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
struct Layout {
  SmallVector<int64_t, 3> layout;
  Layout() = default;
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  void print(llvm::raw_ostream &os) const;
  size_t size() const { return layout.size(); }
};

void Layout::print(llvm::raw_ostream &os) const {
  os << llvm::interleaved_array(layout);
}

/// LaneLayout represents the logical layout of lanes within a subgroup when it
/// accesses some value. LaneData represents the logical layout of data owned
/// by each work item.
using LaneLayout = Layout;
using LaneData = Layout;

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Helper class for tracking the analysis state of an MLIR value. For layout
/// propagation, the analysis state is simply the lane_layout and lane_data of
/// each value. The purpose of this analysis is to propagate a unique layout
/// for each value in the program, starting from a set of anchor operations
/// (like DPAS, StoreNd, etc.).
///
/// Given this, LayoutInfo satisfies the following properties:
/// 1) A LayoutInfo value can be in one of two states: `assigned` or `not
///    assigned`.
/// 2) Two LayoutInfo values are equal if they are both assigned or both not
///    assigned. The concrete value of the assigned state does not matter.
/// 3) The meet operator works as follows:
///    - If the current state is assigned, return the current state (a unique
///      layout is already assigned; don't change it).
///    - Otherwise, return the other state.

struct LayoutInfo {
private:
  LaneLayout laneLayout;
  LaneData laneData;
  xegpu::LayoutAttr layoutAttr;

public:
  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << "lane_layout: ";
    laneLayout.print(os);
    os << ", lane_data: ";
    laneData.print(os);
  } else {
    os << "Not assigned.";
  }
}

LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}
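
// Illustrative note (not part of the original source): `meet` keeps the first
// assigned state, so propagation is "first anchor wins". For example,
// meet(assigned, unassigned) and meet(unassigned, assigned) both yield the
// assigned state, and a layout fixed by an anchor op is never overwritten.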

/// Since this is a backward analysis, the join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Get the transposed layout according to the given permutation.
LayoutInfo
LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  LaneLayout newLayout;
  LaneData newData;
  for (int64_t idx : permutation) {
    newLayout.layout.push_back(laneLayout.layout[idx]);
    newData.layout.push_back(laneData.layout[idx]);
  }
  return LayoutInfo(newLayout, newData);
}
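
// Illustrative example (not part of the original source): applying the
// permutation [1, 0] to lane_layout [1, 16] / lane_data [1, 2] yields
// lane_layout [16, 1] / lane_data [2, 1].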

//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
  using Lattice::Lattice;
};

/// Helper functions to get default layouts. A `default layout` is a layout
/// that is assigned to a value when the layout is not fixed by some anchor
/// operation (like DPAS).

/// Helper function to get the default layout for uniform values like
/// constants. For a 1D vector, lane_layout is [subgroupSize] and lane_data is
/// [1]. For a 2D vector, lane_layout is [1, subgroupSize] and lane_data is
/// [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
                      LaneData({1}));
  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                    LaneData({1, 1}));
}
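
// Illustrative example (not part of the original source): assuming a subgroup
// size of 16, a 2D uniform value gets lane_layout [1, 16] and lane_data
// [1, 1], i.e. lanes are spread along the innermost dimension and each lane
// owns a single element.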

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(1);
  // Packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                    LaneData({1, packingFactor}));
}
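
// Illustrative example (not part of the original source): assuming a subgroup
// size of 16 and 32-bit default packing, a vector<16x16xf16> gets
// packingFactor = 32 / 16 = 2, i.e. lane_layout [1, 16] and lane_data [1, 2]:
// each lane owns pairs of consecutive f16 elements packed into 32-bit units.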

/// Helper to get the default layout for a tensor descriptor type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
  // Expecting a 1D or 2D tensor descriptor.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return the default layout for a 1D vector.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(1);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();

  if (tdescTy.isScattered()) {
    int packingFactor =
        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
            : 1;
    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
                      LaneData({1, packingFactor}));
  }

  int packingFactor =
      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
          : 1;
  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                    LaneData({1, packingFactor}));
}
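
// Illustrative example (not part of the original source): for a scattered
// TensorDesc with i8 elements, assuming a subgroup size of 16 and 32-bit
// gather/scatter packing, packingFactor = 32 / 8 = 4, giving lane_layout
// [16, 1] and lane_data [1, 4]; lanes are distributed along the outer
// dimension for scattered accesses.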

/// Helper function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For the A operand, the data must be packed in minimum
///   `packedSizeInBitsForDefault`.
/// * For the B operand, the data must be packed in minimum
///   `packedSizeInBitsForDpasB`.
static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
                                                  unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
  // For the B operand, data must be packed in minimum
  // `packedSizeInBitsForDpasB` and must have the VNNI format.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
                             xegpu::targetinfo::packedSizeInBitsForDpasB) {
    LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
                       elementTy.getIntOrFloatBitWidth(),
                   1});
    return LayoutInfo(layout, data);
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultSIMTLayoutInfo(vectorTy);
}
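
// Illustrative example (not part of the original source): assuming 32-bit
// DPAS-B packing, a bf16 B operand (bitwidth 16 < 32) gets lane_data
// [32 / 16, 1] = [2, 1]: each lane owns a 2x1 chunk (two consecutive rows of
// one column), which is the VNNI arrangement.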

//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for the operands of
/// DPAS, StoreNd, and StoreScatter are fixed (known before propagation). The
/// purpose of this analysis is to propagate those known layouts to all their
/// producers and (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace

LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand type is not a vector or tensor descriptor, skip
            // it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });

  return success();
}

void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Here we assign the default layout to the tensor descriptor operand of
  // prefetch.
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}

void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  // Given that the result is 1D, the layout of the operand should be 2D with
  // the default layout.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}

/// Propagate the layout of the result tensor to the source tensor descriptor
/// in UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(
      operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(
      operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(
        operands[2],
        operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
  }
}
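
// Illustrative example (not part of the original source): for a DPAS with f16
// A/B operands and an f32 accumulator, assuming a subgroup size of 16 and
// 32-bit packing, this fixes lane_layout [1, 16] for all three operands, with
// lane_data [1, 2] for A, [2, 1] for B (VNNI), and [1, 1] for C.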

/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}

/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has the transpose effect. However, at the stage of this analysis
  // this effect is not expected and should be abstracted away. Emit a
  // warning.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}

/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout =
      resultLayout.getTransposedLayout(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}

/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit width of the source and result types.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();

  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
  // a warning and return.
  if (inElemTyBitWidth != outElemTyBitWidth) {
    bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
                        "layout propagation stage.");
    return;
  }

  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Propagate the layout of the result to the tensor descriptor and mask
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout is strictly determined by the tensor descriptor type.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());

  // Mask operand should have the 1D default layout.
  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);

  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(layout));
  // Propagate the new layout to the mask operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}

/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // For the offset operand, propagate the 1D default layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Currently, for 2D StoreScatterOp we expect that the height dimension of
  // the tensor descriptor is equal to the subgroup size. This is ensured by
  // the op verifier.
  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
  if (tdescShape.size() > 1)
    assert(
        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
        "Expected the first dimension of 2D tensor descriptor to be equal to "
        "subgroup size.");

  LayoutInfo layout =
      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());

  // Propagate the value layout.
  propagateIfChanged(operands[0], operands[0]->meet(layout));
  // Propagate the tensor descriptor layout.
  propagateIfChanged(operands[1], operands[1]->meet(layout));
  // Use the default 1D layout for the mask operand.
  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}

namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}

// Print the analysis result for debugging purposes.
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops.
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results.
      if (op->getResults().empty())
        return;
      os << "op    : ";
      // For control-flow ops, print the op name only.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);

    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}

using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
/// the result type is a tensor descriptor type, the type is updated with the
/// layout attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::LayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with the layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to
    // the op.
    xegpu::setLayoutAttr(result, layout);
  }
  return success();
}

/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// their types must be updated. Also, a region op can have results that may
/// not have any users (e.g. A and B tiles). They are not assigned a layout by
/// the layout analysis because they have no users. However, inside the region
/// op the corresponding block arguments for these results do have layouts.
/// Therefore, in this case we still need to update the result types with the
/// layout attribute. This function updates the internal block arguments and
/// the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
///   ^bb0(block types):
///     ...
///     scf.yield ... : (yield types)
/// }
/// clang-format on
/// In this example, at scf.yield, control flow can transfer to two successor
/// regions. One is ^bb0 (the loop body) and the other is the scf.for op
/// itself (yielding the results). So we update both the block arguments of
/// the successor region (i.e. block types) and the result types of the
/// scf.for op (i.e. out types). Note that yield types are updated by the
/// respective producers inside bb0.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only process if the terminator is inside a region branch op.
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();

  mlir::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
                                              nullptr);
  terminator.getSuccessorRegions(operands, successors);

  for (mlir::RegionSuccessor &successor : successors) {
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
      xegpu::LayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);

      // If the operand layout is not assigned, we cannot proceed.
      if (!successorOperandLayout) {
        LLVM_DEBUG(
            DBGS()
            << "No layout assigned for forwarded operand in branch terminator: "
            << successorOperand << "\n");
        return failure();
      }
      // We expect the layouts to match.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // Get the tensor descriptor type with the layout.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // If the type is a vector type and this region argument is an OpResult,
      // set the layout attribute on the OpResult.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}

/// Update the function arguments and results with the layouts.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
                                               mlir::FunctionOpInterface funcOp,
                                               GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::LayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  // NOTE: We assume that function results are not expected to have layouts.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}

namespace {
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};

} // namespace

void XeGPUPropagateLayoutPass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and exit (for debugging purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert LayoutInfo to xegpu::LayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    return xegpu::LayoutAttr::get(
        val.getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
        llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}
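
// Illustrative usage (not part of the original source; the flag name is an
// assumption inferred from the pass definition): the pass would typically be
// exercised from mlir-opt, e.g.
//   mlir-opt --xegpu-propagate-layout input.mlir
// and its print-only mode dumps the per-value layout analysis results instead
// of rewriting types and attributes.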