// Extraction note: this text was recovered from a Doxygen page
// (MLIR 21.0.0git) for XeGPUPropagateLayout.cpp.
1 //===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Operation.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/IR/Visitors.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/InterleavedRange.h"
38 #include "llvm/Support/LogicalResult.h"
39 #include "llvm/Support/raw_ostream.h"
40 
41 namespace mlir {
42 namespace xegpu {
43 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45 } // namespace xegpu
46 } // namespace mlir
47 
48 #define DEBUG_TYPE "xegpu-propagate-layout"
49 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50 
51 using namespace mlir;
52 using namespace mlir::dataflow;
53 
54 namespace {
55 
56 //===----------------------------------------------------------------------===//
57 // Layout
58 //===----------------------------------------------------------------------===//
59 
60 /// Helper class to store the ND layout of lanes within a subgroup and data
61 /// owned by each lane.
62 struct Layout {
64  Layout() = default;
65  Layout(std::initializer_list<int64_t> list) : layout(list) {}
66  void print(llvm::raw_ostream &os) const;
67  size_t size() const { return layout.size(); }
68 };
69 
/// Print the layout dimensions as `[d0, d1, ...]`.
void Layout::print(llvm::raw_ostream &os) const {
  os << llvm::interleaved_array(layout);
}
73 
74 /// LaneLayout represents the logical layout of lanes within a subgroup when it
75 /// accesses some value. LaneData represents the logical layout of data owned by
76 /// each work item.
77 using LaneLayout = Layout;
78 using LaneData = Layout;
79 
80 //===----------------------------------------------------------------------===//
81 // LayoutInfo
82 //===----------------------------------------------------------------------===//
83 
84 /// Helper class for tracking the analysis state of an mlir value. For layout
85 /// propagation, the analysis state is simply the lane_layout and lane_data of
86 /// each value. Purpose of this analysis to propagate some unique layout for
87 /// each value in the program starting from a set of anchor operations (like
88 /// DPAS, StoreNd, etc.).
89 ///
90 /// Given this, LayoutInfo satisifies the following properties:
91 /// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
92 /// assigned`.
93 /// 2) Two LayoutInfo values are equal if they are both assigned or
94 /// both not assigned. The concrete value of assigned state does not matter.
95 /// 3) The meet operator works as follows:
96 /// - If current state is assigned, return the current state. (already
97 /// a unique layout is assigned. don't change it)
98 /// - Otherwise, return the other state.
99 
struct LayoutInfo {
private:
  LaneLayout laneLayout;
  LaneData laneData;
  // NOTE(review): this member appears to be neither read nor written anywhere
  // in this file — confirm whether it is dead and can be removed.
  xegpu::LayoutAttr layoutAttr;

public:
  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  /// Lattice meet: keep lhs if it is already assigned, otherwise take rhs.
  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  /// Unused for this backward analysis (calling it is a bug).
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  /// A value is "assigned" once both layout and data have dimensions.
  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  /// Return a copy with both lane_layout and lane_data permuted.
  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};
134 
135 void LayoutInfo::print(raw_ostream &os) const {
136  if (isAssigned()) {
137  os << "lane_layout: ";
138  laneLayout.print(os);
139  os << ", lane_data: ";
140  laneData.print(os);
141  } else {
142  os << "Not assigned.";
143  }
144 }
145 
146 LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
147  if (!lhs.isAssigned())
148  return rhs;
149  return lhs;
150 }
151 
152 /// Since this is a backward analysis, join method is not used.
153 LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
154  llvm_unreachable("Join should not be triggered by layout propagation.");
155 }
156 
157 /// Get the transposed layout according to the given permutation.
158 LayoutInfo
159 LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
160  if (!isAssigned())
161  return {};
162  LaneLayout newLayout;
163  LaneData newData;
164  for (int64_t idx : permutation) {
165  newLayout.layout.push_back(laneLayout.layout[idx]);
166  newData.layout.push_back(laneData.layout[idx]);
167  }
168  return LayoutInfo(newLayout, newData);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // LayoutInfoLattice
173 //===----------------------------------------------------------------------===//
174 
175 /// Lattice holding the LayoutInfo for each value.
/// Lattice holding the LayoutInfo for each value. Thin wrapper so the sparse
/// dataflow framework can attach a LayoutInfo to every SSA value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
180 
181 /// Helper Functions to get default layouts. A `default layout` is a layout that
182 /// is assigned to a value when the layout is not fixed by some anchor operation
183 /// (like DPAS).
184 
185 /// Helper Function to get the default layout for uniform values like constants.
186 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
187 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
188 static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
189  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
190  if (rank == 1)
191  return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
192  LaneData({1}));
193  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
194  LaneData({1, 1}));
195 }
196 
197 /// Helper to get the default layout for a vector type.
198 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
199  // Expecting a 1D or 2D vector.
200  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
201  "Expected 1D or 2D vector.");
202  // Expecting int or float element type.
203  assert(vectorTy.getElementType().isIntOrFloat() &&
204  "Expected int or float element type.");
205  // If the rank is 1, then return default layout for 1D vector.
206  if (vectorTy.getRank() == 1)
207  return getDefaultSIMTLayoutInfo(1);
208  // Packing factor is determined by the element type bitwidth.
209  int packingFactor = 1;
210  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
212  packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
213  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
214  LaneData({1, packingFactor}));
215 }
216 
217 /// Helper to get the default layout for a vector type.
218 static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
219  // Expecting a 1D or 2D vector.
220  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
221  "Expected 1D or 2D TensorDesc.");
222  // Expecting int or float element type.
223  assert(tdescTy.getElementType().isIntOrFloat() &&
224  "Expected int or float element type.");
225  // If the rank is 1, then return default layout for 1D vector.
226  if (tdescTy.getRank() == 1)
227  return getDefaultSIMTLayoutInfo(1);
228  // Packing factor is determined by the element type bitwidth.
229  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
230 
231  if (tdescTy.isScattered()) {
232  int packingFactor =
235  : 1;
236  return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
237  LaneData({1, packingFactor}));
238  }
239 
240  int packingFactor =
243  : 1;
244  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
245  LaneData({1, packingFactor}));
246 }
247 
248 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
249 /// is set according to the following criteria:
250 /// * For A operand, the data must be packed in minimum
251 /// `packedSizeInBitsForDefault`
252 /// * For B operand, the data must be packed in minimum
253 /// `packedSizeInBitsForDpasB`
254 static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
255  unsigned operandNum) {
256  Type elementTy = vectorTy.getElementType();
257  assert(elementTy.isIntOrFloat() &&
258  "Expected int or float type in DPAS operands");
259  LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
260  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
261  // must have the VNNI format.
262  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
265  elementTy.getIntOrFloatBitWidth(),
266  1});
267  return LayoutInfo(layout, data);
268  }
269  // Otherwise, return the default layout for the vector type.
270  return getDefaultSIMTLayoutInfo(vectorTy);
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // LayoutInfoPropagation
275 //===----------------------------------------------------------------------===//
276 
277 /// Backward data flow analysis to propagate the lane_layout and lane_data of
278 /// each value in the program. Currently, the layouts for operands DPAS,
279 /// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
280 /// this analysis is to propagate those known layouts to all their producers and
281 /// (other) consumers.
282 class LayoutInfoPropagation
283  : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
284 private:
285  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
287 
288  void visitStoreNdOp(xegpu::StoreNdOp store,
291 
292  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
295 
296  void visitLoadNdOp(xegpu::LoadNdOp load,
299 
300  void visitLoadGatherOp(xegpu::LoadGatherOp load,
303 
304  void visitTransposeOp(vector::TransposeOp transpose,
307 
308  void visitVectorBitcastOp(vector::BitCastOp bitcast,
311 
312  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
315 
316  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
319 
320  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
323 
324  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
327 
328 public:
329  LayoutInfoPropagation(DataFlowSolver &solver,
330  SymbolTableCollection &symbolTable)
331  : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
333 
334  LogicalResult
335  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
336  ArrayRef<const LayoutInfoLattice *> results) override;
337 
338  void visitBranchOperand(OpOperand &operand) override {};
339 
340  void visitCallOperand(OpOperand &operand) override {};
341 
342  void visitExternalCall(CallOpInterface call,
344  ArrayRef<const LayoutInfoLattice *> results) override {
345  };
346 
347  void setToExitState(LayoutInfoLattice *lattice) override {
348  (void)lattice->meet(LayoutInfo());
349  }
350 };
351 } // namespace
352 
353 LogicalResult LayoutInfoPropagation::visitOperation(
357  .Case<xegpu::DpasOp>(
358  [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
359  .Case<xegpu::StoreNdOp>(
360  [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
361  .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
362  visitStoreScatterOp(storeScatterOp, operands, results);
363  })
364  .Case<xegpu::LoadNdOp>(
365  [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
366  .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
367  visitLoadGatherOp(loadGatherOp, operands, results);
368  })
369  .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
370  visitCreateDescOp(createDescOp, operands, results);
371  })
372  .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
373  visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
374  })
375  .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
376  visitPrefetchNdOp(prefetchNdOp, operands, results);
377  })
378  .Case<vector::TransposeOp>([&](auto transposeOp) {
379  visitTransposeOp(transposeOp, operands, results);
380  })
381  .Case<vector::BitCastOp>([&](auto bitcastOp) {
382  visitVectorBitcastOp(bitcastOp, operands, results);
383  })
384  .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
385  visitVectorMultiReductionOp(reductionOp, operands, results);
386  })
387  // All other ops.
388  .Default([&](Operation *op) {
389  for (const LayoutInfoLattice *resultInfo : results) {
390  if (!resultInfo->getValue().isAssigned())
391  continue;
392  for (auto [operandInfo, operand] :
393  llvm::zip(operands, op->getOpOperands())) {
394  // If the operand type is not a vector or tensor descriptor, skip
395  // it.
396  if (!isa<xegpu::TensorDescType, VectorType>(
397  operand.get().getType()))
398  continue;
399  // Propagate the result layout to the operand.
400  meet(operandInfo, *resultInfo);
401  }
402  }
403  });
404 
405  return success();
406 }
407 
408 void LayoutInfoPropagation::visitPrefetchNdOp(
409  xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
411  // Here we assign the default layout to the tensor descriptor operand of
412  // prefetch.
413  auto tdescTy = prefetch.getTensorDescType();
414  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
415  // Propagate the layout to the source tensor descriptor.
416  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
417 }
418 
419 void LayoutInfoPropagation::visitVectorMultiReductionOp(
420  vector::MultiDimReductionOp reduction,
423  // The layout of the result must be present.
424  LayoutInfo resultLayout = results[0]->getValue();
425  if (!resultLayout.isAssigned())
426  return;
427  // We only consider 2D -> 1D reductions at this point.
428  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
429  if (!resultTy || resultTy.getRank() != 1) {
430  reduction.emitWarning("Expecting output type to be 1D vector.");
431  return;
432  }
433  // Given that the result is 1D, the layout of the operand should be 2D with
434  // default layout.
435  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
436  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
437  // Accumulator should have the same layout as the result.
438  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
439 }
440 
441 /// Propagate the layout of the result tensor to the source tensor descriptor in
442 /// UpdateNdOffsetOp.
443 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
444  xegpu::UpdateNdOffsetOp updateNdOffset,
447  // The layout of the result must be present.
448  LayoutInfo resultLayout = results[0]->getValue();
449  if (!resultLayout.isAssigned())
450  return;
451  // Propagate the layout to the source operand.
452  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
453 }
454 
455 /// Set the layouts for DPAS A, B, and C operands.
456 void LayoutInfoPropagation::visitDpasOp(
457  xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
459  VectorType aTy = dpas.getLhsType();
460  VectorType bTy = dpas.getRhsType();
461  propagateIfChanged(
462  operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
463  propagateIfChanged(
464  operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
465  if (operands.size() > 2) {
466  VectorType cTy = dpas.getAccType();
467  propagateIfChanged(
468  operands[2],
469  operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
470  }
471 }
472 
473 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
474 void LayoutInfoPropagation::visitStoreNdOp(
475  xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
477  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
478  // Both operands should have the same layout
479  for (LayoutInfoLattice *operand : operands)
480  propagateIfChanged(operand, operand->meet(storeLayout));
481 }
482 
483 /// Propagate the layout of the value to the tensor descriptor operand in
484 /// LoadNdOp.
485 void LayoutInfoPropagation::visitLoadNdOp(
486  xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
488  LayoutInfo valueLayout = results[0]->getValue();
489  // Need the layout of the value to propagate to the tensor descriptor.
490  if (!valueLayout.isAssigned())
491  return;
492  LayoutInfo tensorDescLayout = valueLayout;
493  // LoadNdOp has the transpose effect. However, at the stage of this analysis
494  // this effect is not expected and should be abstracted away. Emit a
495  // warning.
496  if (auto transpose = load.getTranspose()) {
497  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
498  "LayoutInfoPropagation stage.");
499  tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
500  }
501  // Propagate the new layout to the tensor descriptor operand.
502  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
503 }
504 
505 /// For vector::TransposeOp, the layout of the result is transposed and
506 /// propagated to the operand.
507 void LayoutInfoPropagation::visitTransposeOp(
508  vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
510  // Need the layout of transpose result to propagate to the operands.
511  LayoutInfo resultLayout = results[0]->getValue();
512  if (!resultLayout.isAssigned())
513  return;
514  LayoutInfo newLayout =
515  resultLayout.getTransposedLayout(transpose.getPermutation());
516  // Propagate the new layout to the vector operand.
517  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
518 }
519 
520 /// For vector::BitCastOp, the lane_data of the source layout is changed based
521 /// on the bit width of the source and result types.
522 void LayoutInfoPropagation::visitVectorBitcastOp(
523  vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
525  // Need the layout of bitcast result to propagate to the operands.
526  LayoutInfo resultLayout = results[0]->getValue();
527  if (!resultLayout.isAssigned())
528  return;
529  int inElemTyBitWidth =
530  bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
531  int outElemTyBitWidth =
532  bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
533 
534  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
535  // a warning and return.
536  if (inElemTyBitWidth != outElemTyBitWidth) {
537  bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
538  "layout propagation stage.");
539  return;
540  }
541 
542  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
543 }
544 
545 /// Propagate the layout of the result to the tensor descriptor and mask
546 /// operands in LoadGatherOp.
547 void LayoutInfoPropagation::visitLoadGatherOp(
548  xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
550  // The layout is strictly determined by the tensor descriptor type.
551  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
552 
553  // Mask operand should have 1D default layout.
554  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
555 
556  // Propagate the new layout to the tensor descriptor operand.
557  propagateIfChanged(operands[0], operands[0]->meet(layout));
558  // Propagate the new layout to the mask operand.
559  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
560 }
561 
562 /// Propagate the layout of the descriptor to the vector offset operand in
563 /// CreateDescOp.
564 void LayoutInfoPropagation::visitCreateDescOp(
565  xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
567  LayoutInfo descLayout = results[0]->getValue();
568  // Need the layout of the descriptor to propagate to the operands.
569  if (!descLayout.isAssigned())
570  return;
571  // For offset operand propagate 1D default layout.
572  LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
573  propagateIfChanged(operands[1], operands[1]->meet(layout));
574 }
575 
576 /// Set the layout for the value, tensor descriptor, and mask operands in the
577 /// StoreScatterOp.
578 void LayoutInfoPropagation::visitStoreScatterOp(
579  xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
581  // Currently, for 2D StoreScatterOp we expect that the height dimension of
582  // the tensor descriptor is equal to the subgroup size. This is ensured by
583  // the op verifier.
584  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
585  if (tdescShape.size() > 1)
586  assert(
587  tdescShape[0] == xegpu::targetinfo::subgroupSize &&
588  "Expected the first dimension of 2D tensor descriptor to be equal to "
589  "subgroup size.");
590 
591  LayoutInfo layout =
592  getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
593 
594  // Propagate the value layout.
595  propagateIfChanged(operands[0], operands[0]->meet(layout));
596  // Propagate the tensor descriptor layout.
597  propagateIfChanged(operands[1], operands[1]->meet(layout));
598  // Use default 1D layout for mask operand.
599  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
600  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
601 }
602 
603 namespace {
604 //===----------------------------------------------------------------------===//
605 // RunLayoutInfoPropagation
606 //===----------------------------------------------------------------------===//
607 
608 /// Driver class for running the LayoutInfoPropagation analysis.
/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  /// Construct the driver and immediately run the analysis over `op`.
  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    // Load the prerequisite baseline analyses the solver needs before the
    // layout propagation analysis itself.
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable);
    // Failure is intentionally ignored: values the solver could not resolve
    // simply remain in the "not assigned" state.
    (void)solver.initializeAndRun(op);
  }

  /// Return the layout computed for `val`, or an unassigned LayoutInfo.
  LayoutInfo getLayoutInfo(Value val);

  /// Dump the per-function analysis results (for debugging).
  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
628 } // namespace
629 
630 LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
631  auto *state = solver.lookupState<LayoutInfoLattice>(val);
632  if (!state)
633  return {};
634  return state->getValue();
635 }
636 
637 // Print the analysis result for debugging purposes.
638 void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
639  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
640  os << "function: " << funcOp.getName() << ":\n";
641  // Function arguments
642  for (BlockArgument arg : funcOp.getArguments()) {
643  LayoutInfo layout = getLayoutInfo(arg);
644  os << "argument: " << arg << "\n";
645  os << "layout : ";
646  layout.print(os);
647  os << "\n";
648  }
649  // Function ops
650  funcOp.walk([&](Operation *op) {
651  // Skip ops that do not have results
652  if (op->getResults().empty())
653  return;
654  os << "op : ";
655  // For control-flow ops, print the op name only.
656  if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
657  os << op->getName();
658  else
659  op->print(os);
660  os << "\n";
661  // Print the layout for each result.
662  for (auto [i, r] : llvm::enumerate(op->getResults())) {
663  LayoutInfo layout = getLayoutInfo(r);
664  os << "layout for result #" << i << ": ";
665  layout.print(os);
666  os << "\n";
667  }
668  });
669  };
670 
672  if (auto modOp = dyn_cast<ModuleOp>(target)) {
673  for (auto funcOp : modOp.getOps<FunctionOpInterface>())
674  funcOps.push_back(funcOp);
675 
676  // Collect all GpuFuncOps in the module.
677  for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
678  for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
679  funcOps.push_back(gpuFuncOp);
680  }
681  }
682  // Print the analysis result for each function.
683  for (FunctionOpInterface funcOp : funcOps)
684  printFunctionResult(funcOp);
685 }
686 
687 using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
688 /// Update an operation with the layout of its results. If the result type is a
689 /// vector type, a temporary layout attribute is added to the operation. If the
690 /// result type is a tensor descriptor type, the type is updated with the layout
691 /// attribute. The users of the result are also updated with the layout
692 /// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by the updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::LayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with layout. Note: the encoding is preserved; only the layout slot of
    // the type is (re)populated.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to the
    // op.
    xegpu::setLayoutAttr(result, layout);
  }
  return success();
}
726 
727 /// Region ops like scf.for need special handling because they have blocks
728 /// inside. If the blocks have tensor descriptor type as block arguments, thier
729 /// types must be updated. Also region op can have results that may not have any
730 /// users (e.g. A and B tiles). They are not assigned a layout by layout
731 /// analysis because they have no users. However inside the region op
732 /// corresponding block arguments for these results do have layouts. Therefore,
733 /// in this case we still need to update the result types with the layout
734 /// attribute. This function function updates the internal block arguments and
735 /// the result types of the region op with the assigned layouts.
736 /// clang-format off
737 /// Example: scf.for ... iter_args(...) -> (out types) {
738 /// ^bb0(block types):
739 /// ...
740 /// scf.yield ... : (yield types)
741 /// }
742 /// clang-format on
743 /// In this example, at scf.yield, control-flow can transfer to two successor
744 /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
745 /// itself (yield the results). So we update both the block arguments of the
746 /// successor region (i.e. block types) and the result types of the scf.for op
747 /// (i.e. out types). Note that yield types are updated by respective producers
748 /// inside bb0.
749 static LogicalResult
751  mlir::RegionBranchTerminatorOpInterface terminator,
752  GetLayoutFnTy getLayoutOfValue) {
753  // Only process if the terminator is inside a region branch op.
754  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
755  return success();
756 
758  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
759  nullptr);
760  terminator.getSuccessorRegions(operands, successors);
761 
762  for (mlir::RegionSuccessor &successor : successors) {
763  mlir::OperandRange successorOperands =
764  terminator.getSuccessorOperands(successor);
765  mlir::ValueRange successorInputs = successor.getSuccessorInputs();
766  for (auto [successorOperand, successorInput] :
767  llvm::zip(successorOperands, successorInputs)) {
768  Type inputType = successorInput.getType();
769  // We only need to operate on tensor descriptor or vector types.
770  if (!isa<xegpu::TensorDescType, VectorType>(inputType))
771  continue;
772  xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
773  xegpu::LayoutAttr successorOperandLayout =
774  getLayoutOfValue(successorOperand);
775 
776  // If either of the layouts is not assigned, we cannot proceed.
777  if (!successorOperandLayout) {
778  LLVM_DEBUG(
779  DBGS()
780  << "No layout assigned for forwarded operand in branch terminator: "
781  << successorOperand << "\n");
782  return failure();
783  }
784  // We expect the layouts to match.
785  if (successorInputLayout &&
786  successorInputLayout != successorOperandLayout) {
787  LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
788  "operand forwarded as the argument: "
789  << successorInputLayout << " vs "
790  << successorOperandLayout << "\n");
791  return failure();
792  }
793  // Get tensor descriptor type with the layout.
794  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
795  auto newTdescTy = xegpu::TensorDescType::get(
796  tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
797  tdescTy.getEncoding(), successorOperandLayout);
798  successorInput.setType(newTdescTy);
799  continue;
800  }
801  // If the type is a vector type and this region argument is an OpResult,
802  // set the layout attribute on the OpResult.
803  if (auto result = dyn_cast<OpResult>(successorInput))
804  xegpu::setLayoutAttr(result, successorOperandLayout);
805  }
806  }
807  return success();
808 }
809 
810 /// Update the function arguments and results with the layouts.
/// Update the function arguments and results with the layouts.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
                                               mlir::FunctionOpInterface funcOp,
                                               GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    // Only vector and tensor-descriptor arguments carry layouts.
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    // A missing layout on a relevant argument is a hard failure here
    // (unlike op results, which only warn in updateOp).
    xegpu::LayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    // Only tensor descriptor arguments get their type rewritten; vector
    // arguments keep their type (the layout lives on producing ops instead).
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  // NOTE: We assume that function results are not expected to have layouts.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
841 
842 namespace {
/// Pass wrapper: runs the layout analysis and writes the computed layouts
/// back into the IR (see runOnOperation below).
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  /// Construct from generated pass options (e.g. the `printOnly` flag).
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
851 
852 } // namespace
853 
void XeGPUPropagateLayoutPass::runOnOperation() {
  // Running the analysis is done by constructing the driver (its constructor
  // runs the dataflow solver).
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and exit. (for debugging purposes)
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert LayoutInfo to xegpu::LayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    return xegpu::LayoutAttr::get(
        val.getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
        llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  // Walk blocks and visit each block's ops in reverse order; dispatch to the
  // appropriate IR-update helper per op kind.
  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      // Any failed update aborts the whole walk and fails the pass.
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
#define DBGS()
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
The general data-flow analysis solver.
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:279
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition: Utils.h:29
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
constexpr unsigned packedSizeInBitsForDpasB
constexpr unsigned subgroupSize
constexpr unsigned packedSizeInBitsForGatherScatter
constexpr unsigned packedSizeInBitsForDefault
If DPAS A or B operands have low precision element types they must be packed according to the followi...
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
Definition: XeGPUUtils.cpp:156
Include the generated interface declarations.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...