MLIR  21.0.0git
XeGPUSubgroupDistribute.cpp
Go to the documentation of this file.
1 //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
18 #include "mlir/IR/Builders.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 namespace mlir {
24 namespace xegpu {
25 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
26 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
27 } // namespace xegpu
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::dataflow;
32 
33 /// HW dependent constants.
34 /// TODO: These constants should be queried from the target information.
35 constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
36 /// If DPAS A or B operands have low precision element types they must be packed
37 /// according to the following sizes.
38 constexpr unsigned packedSizeInBitsForDefault =
39  16; // Minimum packing size per register for DPAS A.
40 constexpr unsigned packedSizeInBitsForDpasB =
41  32; // Minimum packing size per register for DPAS B.
42 
43 namespace {
44 
45 ///===----------------------------------------------------------------------===///
46 /// Layout
47 ///===----------------------------------------------------------------------===///
48 
49 /// Helper class to store the ND layout of work items within a subgroup and data
50 /// owned by each work item.
51 struct Layout {
53  Layout() = default;
54  Layout(std::initializer_list<int64_t> list) : layout(list) {}
55  void print(llvm::raw_ostream &os) const;
56  size_t size() const { return layout.size(); }
57  int64_t operator[](size_t idx) const;
58 };
59 
60 void Layout::print(llvm::raw_ostream &os) const {
61  os << "[";
62  llvm::interleaveComma(layout, os);
63  os << "]";
64 }
65 
/// Bounds-checked element access for a single layout dimension.
int64_t Layout::operator[](size_t idx) const {
  assert(idx < layout.size() && "Index out of bounds.");
  return layout[idx];
}
70 
71 /// WiLayout represents the layout of work items within a subgroup when it
72 /// accesses some value. WiData represents the layout of data owned by each work
73 /// item.
74 using WiLayout = Layout;
75 using WiData = Layout;
76 
77 ///===----------------------------------------------------------------------===///
78 /// SGMap
79 ///===----------------------------------------------------------------------===///
80 
81 /// Helper class for tracking the analysis state of a value. For SGPropagation,
82 /// the analysis state is simply the wi_layout and wi_data of each value.
/// Purpose of this analysis is to propagate some unique layout for each value in
84 /// the program starting from some known values (like DPAS, StoreNd, etc.).
85 ///
/// Given this, SGMap satisfies the following properties:
87 /// 1) SGMap is a lattice with two states - assigned and not assigned.
88 /// 2) Two SGMap values are equal if they are both assigned or both not
89 /// assigned. The concrete value of assigned state does not matter.
90 /// 3) The meet operator works as follows:
91 /// - If current state is assigned, return the current state. (already
92 /// a unique layout is assigned. don't change it)
93 /// - Otherwise, return the other state.
94 
struct SGMap {
private:
  /// Layout of work items within the subgroup for this value.
  WiLayout wiLayout;
  /// Layout of the data elements owned by each work item.
  WiData wiData;

public:
  SGMap() = default;
  SGMap(const WiLayout &layout, const WiData &data)
      : wiLayout(layout), wiData(data) {}

  /// Two lattice values are equal if they have `some` layout. The actual
  /// content of the layout does not matter.
  bool operator==(const SGMap &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  /// Lattice meet: keeps `lhs` if it is already assigned, otherwise `rhs`.
  static SGMap meet(const SGMap &lhs, const SGMap &rhs);

  /// Never used by this backward analysis; calling it is a programmer error.
  static SGMap join(const SGMap &lhs, const SGMap &rhs);

  void print(raw_ostream &os) const;

  /// A map counts as assigned once both the WI layout and WI data are
  /// non-empty.
  bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }

  /// Returns a copy of this map with both layouts permuted; an unassigned map
  /// yields an unassigned result.
  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const WiLayout &getLayout() const { return wiLayout; }
  const WiData &getData() const { return wiData; }
};
124 
125 void SGMap::print(raw_ostream &os) const {
126  if (isAssigned()) {
127  os << "wi_layout: ";
128  wiLayout.print(os);
129  os << ", wi_data: ";
130  wiData.print(os);
131  } else
132  os << "Not assigned.";
133 }
134 
135 SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
136  if (!lhs.isAssigned())
137  return rhs;
138  return lhs;
139 }
140 
/// Since this is a backward analysis, the join method is never used; reaching
/// it indicates a bug in the solver driver.
SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
}
145 
146 /// Get the transposed layout according to the given permutation.
147 SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
148  if (!isAssigned())
149  return {};
150  WiLayout newLayout;
151  WiData newData;
152  for (auto idx : permutation) {
153  newLayout.layout.push_back(wiLayout.layout[idx]);
154  newData.layout.push_back(wiData.layout[idx]);
155  }
156  return SGMap(newLayout, newData);
157 }
158 
159 ///===----------------------------------------------------------------------===///
160 /// SGMapLattice
161 ///===----------------------------------------------------------------------===///
162 
163 /// Lattice holding the SGMap for each value.
164 struct SGMapLattice : public Lattice<SGMap> {
166  using Lattice::Lattice;
167 };
168 
169 /// Helper Functions to get default layouts. A `default layout` is a layout that
170 /// is assigned to a value when the layout is not fixed by some anchor operation
171 /// (like DPAS). This is the natural layout work items are arranged in a
172 /// subgroup.
173 
174 /// Helper Function to get the default layout for uniform values like constants.
175 /// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
176 /// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
177 static SGMap getDefaultSgMap(unsigned rank) {
178  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
179  if (rank == 1)
180  return SGMap(WiLayout({subgroupSize}), WiData({1}));
181  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
182 }
183 
184 /// Helper to get the default layout for a vector type.
185 static SGMap getDefaultSgMap(VectorType vectorTy) {
186  /// Expecting a 1D or 2D vector.
187  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
188  "Expected 1D or 2D vector.");
189  /// Expecting int or float element type.
190  assert(vectorTy.getElementType().isIntOrFloat() &&
191  "Expected int or float element type.");
192  /// If the rank is 1, then return default layout for 1D vector.
193  if (vectorTy.getRank() == 1)
194  return getDefaultSgMap(1);
195  /// Packing factor is determined by the element type bitwidth.
196  int packingFactor = 1;
197  auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
198  if (bitwidth < packedSizeInBitsForDefault)
199  packingFactor = packedSizeInBitsForDefault / bitwidth;
200  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
201 }
202 
203 /// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
204 /// set according to the following criteria:
205 /// * For A operand, the data must be packed in minimum
206 /// `packedSizeInBitsForDefault`
207 /// * For B operand, the data must be packed in minimum
208 /// `packedSizeInBitsForDpasB`
209 static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
210  auto elementTy = vectorTy.getElementType();
211  assert(elementTy.isIntOrFloat() &&
212  "Expected int or float type in DPAS operands");
213  WiLayout layout({1, subgroupSize});
214  /// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
215  /// must have the VNNI format.
216  if (operandNum == 1 &&
217  elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
218  WiData data(
219  {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
220  return SGMap(layout, data);
221  }
222  /// Otherwise, return the default layout for the vector type.
223  return getDefaultSgMap(vectorTy);
224 }
225 
226 ///===----------------------------------------------------------------------===///
227 /// SGMapPropagation
228 ///===----------------------------------------------------------------------===///
229 
230 /// Backward data flow analysis to propagate the wi_layout and wi_data of each
231 /// value in the program. Currently, the layouts for operands DPAS, StoreNd, and
232 /// StoreScatter are fixed (known before propagation). Purpose of this analysis
233 /// is to propagate those known layouts to all their producers and (other)
234 /// consumers.
235 class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
236 private:
237  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
239 
240  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
242 
243  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
244  ArrayRef<SGMapLattice *> operands,
246 
247  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
249 
250  void visitLoadGatherOp(xegpu::LoadGatherOp load,
251  ArrayRef<SGMapLattice *> operands,
253 
254  void visitTransposeOp(vector::TransposeOp transpose,
255  ArrayRef<SGMapLattice *> operands,
257 
258  void visitVectorBitcastOp(vector::BitCastOp bitcast,
259  ArrayRef<SGMapLattice *> operands,
261 
262  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
263  ArrayRef<SGMapLattice *> operands,
265 
266  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
267  ArrayRef<SGMapLattice *> operands,
269 
270  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
271  ArrayRef<SGMapLattice *> operands,
273 
274 public:
275  SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
276  : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
278 
279  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
280  ArrayRef<const SGMapLattice *> results) override;
281 
282  void visitBranchOperand(OpOperand &operand) override {};
283 
284  void visitCallOperand(OpOperand &operand) override {};
285 
286  void visitExternalCall(CallOpInterface call,
287  ArrayRef<SGMapLattice *> operands,
288  ArrayRef<const SGMapLattice *> results) override {};
289 
290  void setToExitState(SGMapLattice *lattice) override {
291  (void)lattice->meet(SGMap());
292  }
293 };
294 } // namespace
295 
296 LogicalResult
297 SGMapPropagation::visitOperation(Operation *op,
298  ArrayRef<SGMapLattice *> operands,
301  .Case<xegpu::DpasOp>(
302  [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
303  .Case<xegpu::StoreNdOp>(
304  [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
305  .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
306  visitStoreScatterOp(storeScatterOp, operands, results);
307  })
308  .Case<xegpu::LoadNdOp>(
309  [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
310  .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
311  visitLoadGatherOp(loadGatherOp, operands, results);
312  })
313  .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
314  visitCreateDescOp(createDescOp, operands, results);
315  })
316  .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
317  visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
318  })
319  /// No need to propagate the layout to operands in CreateNdDescOp because
320  /// they are scalars (offsets, sizes, etc.).
321  .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
322  .Case<vector::TransposeOp>([&](auto transposeOp) {
323  visitTransposeOp(transposeOp, operands, results);
324  })
325  .Case<vector::BitCastOp>([&](auto bitcastOp) {
326  visitVectorBitcastOp(bitcastOp, operands, results);
327  })
328  .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
329  visitVectorMultiReductionOp(reductionOp, operands, results);
330  })
331  /// All other ops.
332  .Default([&](Operation *op) {
333  for (const SGMapLattice *r : results) {
334  for (SGMapLattice *operand : operands) {
335  /// Propagate the layout of the result to the operand.
336  if (r->getValue().isAssigned())
337  meet(operand, *r);
338  }
339  }
340  });
341  /// Add a dependency from each result to program point after the operation.
342  for (const SGMapLattice *r : results) {
343  addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
344  }
345  return success();
346 }
347 
348 void SGMapPropagation::visitVectorMultiReductionOp(
349  vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
351  /// The layout of the result must be present.
352  auto resultLayout = results[0]->getValue();
353  if (!resultLayout.isAssigned())
354  return;
355  /// We only consider 2D -> 1D reductions at this point.
356  assert(resultLayout.getLayout().size() == 1 &&
357  "Expected 1D layout for reduction result.");
358  /// Given that the result is 1D, the layout of the operand should be 2D with
359  /// default layout.
360  auto operandLayout = getDefaultSgMap(2);
361  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
362  /// Accumulator should have the same layout as the result.
363  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
364 }
365 
366 /// Propagate the layout of the result tensor to the source tensor descriptor in
367 /// UpdateNdOffsetOp.
368 void SGMapPropagation::visitUpdateNdOffsetOp(
369  xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<SGMapLattice *> operands,
371  /// The layout of the result must be present.
372  auto resultLayout = results[0]->getValue();
373  if (!resultLayout.isAssigned())
374  return;
375  /// Propagate the layout to the source operand.
376  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
377 }
378 
379 /// Set the layouts for DPAS A, B, and C operands.
380 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
381  ArrayRef<SGMapLattice *> operands,
383  auto aTy = dpas.getLhsType();
384  auto bTy = dpas.getRhsType();
385  propagateIfChanged(operands[0],
386  operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
387  propagateIfChanged(operands[1],
388  operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
389  if (operands.size() > 2) {
390  auto cTy = dpas.getAccType();
391  propagateIfChanged(operands[2],
392  operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
393  }
394 }
395 
396 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
397 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
398  ArrayRef<SGMapLattice *> operands,
400  auto storeLayout = getDefaultSgMap(store.getValueType());
401  /// Both operands should have the same layout
402  for (SGMapLattice *operand : operands) {
403  propagateIfChanged(operand, operand->meet(storeLayout));
404  }
405 }
406 
407 /// Propagate the layout of the value to the tensor descriptor operand in
408 /// LoadNdOp.
409 void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
410  ArrayRef<SGMapLattice *> operands,
412  auto valueLayout = results[0]->getValue();
413  /// Need the layout of the value to propagate to the tensor descriptor.
414  if (!valueLayout.isAssigned())
415  return;
416  SGMap tensorDescLayout = valueLayout;
417  /// LoadNdOp has the transpose effect. However, at the stage of this analysis
418  /// this effect is not expected and should be abstracted away. Emit a warning.
419  if (auto transpose = load.getTranspose()) {
420  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
421  "SGMapPropagation stage.");
422  tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
423  }
424  /// Propagate the new layout to the tensor descriptor operand.
425  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
426 }
427 
428 /// For vector::TransposeOp, the layout of the result is transposed and
429 /// propagated to the operand.
430 void SGMapPropagation::visitTransposeOp(
431  vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
433  /// Need the layout of transpose result to propagate to the operands.
434  auto resultLayout = results[0]->getValue();
435  if (!resultLayout.isAssigned())
436  return;
437  auto newLayout = resultLayout.getTransposedLayout(transpose.getPermutation());
438  /// Propagate the new layout to the vector operand.
439  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
440 }
441 
442 /// For vector::BitCastOp, the wi_data of the source layout is changed based on
443 /// the bit width of the source and result types.
444 void SGMapPropagation::visitVectorBitcastOp(
445  vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
447  /// Need the layout of bitcast result to propagate to the operands.
448  auto resultLayout = results[0]->getValue();
449  if (!resultLayout.isAssigned())
450  return;
451  auto inElemTyBitWidth =
452  bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
453  auto outElemTyBitWidth =
454  bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
455 
456  /// WiLayout does not change.
457  const WiLayout &newWiLayout = resultLayout.getLayout();
458  const WiData &currData = resultLayout.getData();
459  WiData newWiData;
460  /// It's a widening bitcast
461  if (inElemTyBitWidth < outElemTyBitWidth) {
462  auto ratio = outElemTyBitWidth / inElemTyBitWidth;
463  newWiData = resultLayout.getData()[0] == 1
464  ? WiData({1, currData[1] * ratio})
465  : WiData({currData[0] * ratio, 1});
466  } else {
467  /// It's a narrowing bitcast
468  auto ratio = inElemTyBitWidth / outElemTyBitWidth;
469  newWiData = resultLayout.getData()[0] == 1
470  ? WiData({1, currData[1] / ratio})
471  : WiData({currData[0] / ratio, 1});
472  }
473 
474  propagateIfChanged(operands[0],
475  operands[0]->meet(SGMap(newWiLayout, newWiData)));
476 }
477 
478 /// Propagate the layout of the result to the tensor descriptor and mask
479 /// operands in LoadGatherOp.
480 void SGMapPropagation::visitLoadGatherOp(
481  xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
483  auto valueLayout = results[0]->getValue();
484  /// Need the layout of the value to propagate to the tensor descriptor.
485  if (!valueLayout.isAssigned())
486  return;
487 
488  SGMap tensorDescLayout = valueLayout;
489  if (load.getTranspose()) {
490  /// LoadGatherOp has the transpose effect. However, at the stage of this
491  /// analyis this effect is not expected and should be abstracted away. Emit
492  /// a warning.
493  load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
494  "SGMapPropagation stage.");
495  tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
496  }
497  /// Mask operand should have 1D default layout.
498  auto maskLayout = getDefaultSgMap(1);
499  /// Propagate the new layout to the tensor descriptor operand.
500  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
501  /// Propagate the new layout to the mask operand.
502  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
503 }
504 
505 /// Propagate the layout of the descriptor to the vector offset operand in
506 /// CreateDescOp.
507 void SGMapPropagation::visitCreateDescOp(
508  xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
510  auto descLayout = results[0]->getValue();
511  /// Need the layout of the descriptor to propagate to the operands.
512  if (!descLayout.isAssigned())
513  return;
514  /// For offset operand propagate 1D default layout.
515  SGMap layout = getDefaultSgMap(1);
516  propagateIfChanged(operands[1], operands[1]->meet(layout));
517 }
518 
519 /// Set the layout for the value, tensor descriptor, and mask operands in the
520 /// StoreScatterOp.
521 void SGMapPropagation::visitStoreScatterOp(
522  xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
524  /// Currently, for 2D StoreScatterOp we expect that the height dimension of
525  /// the tensor descriptor is evenly divisible by the subgroup size.
526  /// TODO: Add support for other 2D shapes.
527  auto tdescShape = storeScatter.getTensorDescType().getShape();
528  if (tdescShape.size() > 1 && tdescShape[0] % subgroupSize != 0) {
529  storeScatter.emitError("Height dimension of the tensor descriptor should "
530  "be evenly divisible by the subgroup size.");
531  return;
532  }
533  auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
534  SGMap storeScatterLayout = valueLayout;
535  if (storeScatter.getTranspose()) {
536  /// StoreScatteOp allows transpose effect. However, at the stage of this
537  /// analyis this effect is not expected and should be abstracted away. Emit
538  /// a warning.
539  storeScatter.emitWarning("Transpose effect is not expected for "
540  "StoreScatterOp at SGMapPropagation stage.");
541  storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
542  }
543  /// Propagate the value layout.
544  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
545  /// Propagate the tensor descriptor layout.
546  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
547  /// Use default 1D layout for mask operand.
548  auto maskLayout = getDefaultSgMap(1);
549  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
550 }
551 
552 namespace {
553 
554 ///===----------------------------------------------------------------------===///
555 /// RunSGMapPropagation
556 ///===----------------------------------------------------------------------===///
557 
558 /// Driver class for running the SGMapPropagation analysis.
class RunSGMapPropagation {
public:
  /// Loads the required analyses (dead-code, sparse constant propagation, and
  /// the SGMap propagation itself) and runs the solver over `op` immediately;
  /// the initializeAndRun result is intentionally discarded.
  RunSGMapPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<SGMapPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  /// Returns the computed SGMap for `val`, or a default (unassigned) map if
  /// the solver holds no state for it.
  SGMap getSGMap(Value val);

  /// Dumps per-function results (argument and op-result layouts) to `os`.
  void printAnalysisResult(llvm::raw_ostream &os);

private:
  /// Solver instance; owns all lattice state queried by getSGMap.
  DataFlowSolver solver;
  /// Root operation the analysis was run on.
  const Operation *target;
};
577 } // namespace
578 
579 SGMap RunSGMapPropagation::getSGMap(Value val) {
580  auto *state = solver.lookupState<SGMapLattice>(val);
581  if (!state)
582  return {};
583  return state->getValue();
584 }
585 
586 void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
587  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
588  os << "function: " << funcOp.getName() << ":\n";
589  // Function arguments
590  for (auto arg : funcOp.getArguments()) {
591  auto layout = getSGMap(arg);
592  os << "argument: " << arg << "\n";
593  os << "sg_map : ";
594  layout.print(os);
595  os << "\n";
596  }
597  // Function ops
598  funcOp.walk([&](Operation *op) {
599  // Skip ops that do not have results
600  if (op->getResults().empty())
601  return;
602  os << "op : ";
603  /// For control-flow ops, print the op name only.
604  if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
605  os << op->getName();
606  else
607  op->print(os);
608  os << "\n";
609  /// Print the sg_map for each result.
610  for (auto [i, r] : llvm::enumerate(op->getResults())) {
611  auto layout = getSGMap(r);
612  os << "sg_map for result #" << i << ": ";
613  layout.print(os);
614  os << "\n";
615  }
616  });
617  };
618 
620  if (auto modOp = dyn_cast<ModuleOp>(target)) {
621  for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
622  funcOps.push_back(funcOp);
623  }
624  /// Collect all GpuFuncOps in the module.
625  for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
626  for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
627  funcOps.push_back(gpuFuncOp);
628  }
629  }
630  }
631  /// Print the analysis result for each function.
632  for (auto funcOp : funcOps) {
633  printFunctionResult(funcOp);
634  }
635 }
636 
637 namespace {
/// Pass entry point. Currently runs the SGMap propagation analysis and, when
/// the `printOnly` option is set, prints its results (see runOnOperation).
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  /// Construct with pass options (e.g. print-only mode).
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};
648 } // namespace
649 
650 void XeGPUSubgroupDistributePass::runOnOperation() {
651  Operation *op = getOperation();
652  RunSGMapPropagation solver(op);
653 
654  // Print the analysis result and exit.
655  if (printOnly) {
656  auto &os = llvm::outs();
657  solver.printAnalysisResult(os);
658  return;
659  }
660 }
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
constexpr unsigned subgroupSize
HW dependent constants.
constexpr unsigned packedSizeInBitsForDpasB
constexpr unsigned packedSizeInBitsForDefault
If DPAS A or B operands have low precision element types they must be packed according to the followi...
The general data-flow analysis solver.
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_range getResults()
Definition: Operation.h:415
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...