MLIR  22.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1 //===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11 // This file defines the uArch details for Xe2 and its derived architectures.
12 // This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13 //
14 //===----------------------------------------------------------------------===//
15 #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16 #define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/DebugLog.h"
23 #include <map>
24 #include <string>
25 
26 using namespace mlir;
27 using namespace mlir::xegpu::uArch;
28 
29 namespace mlir {
30 namespace xegpu {
31 namespace uArch {
32 
// Common uArch base for Xe2 and the architectures derived from it.
// The constructor forwards the name, description and instruction registry to
// the uArch base-class constructor and initializes the xeCore member from the
// XeCoreInfo argument.
33 struct Xe2Plus : public uArch {
34  Xe2Plus(StringRef archName, StringRef archDescription,
35  llvm::ArrayRef<const Instruction *> instructionRegistry,
36  const XeCoreInfo &xeCore)
37  : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
// Reports a fixed subgroup size of 16.
38  int getSubgroupSize() const override { return 16; }
// Reports a fixed general packed-format size of 32 bits.
39  unsigned getGeneralPackedFormatBitSize() const override { return 32; }
40 
41 protected:
// NOTE(review): the constructor above initializes a member named xeCore, but
// its declaration is not visible in this listing; presumably it is declared
// here. Confirm against the original header.
43 };
44 
45 //===----------------------------------------------------------------------===//
46 // uArch instructions
47 //===----------------------------------------------------------------------===//
52  static bool classof(const Instruction *B) {
53  return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
54  }
55  // Source :
56  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
// Returns the (widths, heights, counts) arrays offered for a 2D block store of
// the given element type, or std::nullopt when the element byte size is not
// 1, 2 or 4.
// NOTE(review): the line carrying the function name and parameter list is not
// visible in this listing; the body reads a parameter named elemTy.
57  std::optional<
58  std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
60  const static int kHeight[] = {1, 2, 4, 8};
61  const static int kWidth16[] = {16};
// NOTE(review): kWidth32 holds {16}, the same value as kWidth16, although its
// name suggests 32. Confirm against the extension spec cited above.
62  const static int kWidth32[] = {16};
63  const static int kCount[] = {1};
// Integer division: bit widths below 8 give 0 and fall through to nullopt.
64  const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
65  if (elemByteSize == 1)
66  return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
67  llvm::ArrayRef<int>(kHeight),
68  llvm::ArrayRef<int>(kCount));
69  else if (elemByteSize == 2 || elemByteSize == 4)
70  return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
71  llvm::ArrayRef<int>(kHeight),
72  llvm::ArrayRef<int>(kCount));
73  return std::nullopt;
74  }
75 
76  int32_t getPackedFormatBitSize() const { return 16; }
77 };
78 
83  static bool classof(const Instruction *B) {
84  return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
85  }
86 
87  // Source :
88  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
// Returns the (widths, heights, counts) arrays offered for a 2D block load,
// selected by element byte size and the transform / transpose / up-convert
// flags. Combinations that are not listed in the table below yield
// std::nullopt.
89  std::optional<
90  std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
91  getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
92  bool upConv = false) const {
// Candidate block heights, named by their smallest entry.
93  static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94  static const int kHeightAtLeast8[] = {8, 16, 32};
95  static const int kHeightAtLeast16[] = {16, 32};
96  static const int kHeightAtLeast32[] = {32};
97 
// Candidate block widths.
98  static const int kWidth32[] = {32};
99  static const int kWidth16[] = {16};
100  static const int kWidth8[] = {8};
101 
// Candidate block counts.
102  static const int32_t kCount1[] = {1};
103  static const int32_t kCount2[] = {1, 2};
104  static const int32_t kCount4[] = {1, 2, 4};
105  static const int32_t kCount4Only[] = {4};
106  // (elemBytes, transform, transpose, upConvert)
107  using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
108  // (widths, heights, counts)
// NOTE(review): the closing line of this alias is not visible in this listing.
109  using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
// Function-local static lookup table; the bool flags in each key convert to
// the uint8_t slots of Key.
111  static const llvm::DenseMap<Key, Value> kMap = {
112  {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
113  {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114  {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
115  {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
116  // Block Loads with Transform:
117  {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
118  {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
119  // Block Loads with Transpose:
120  {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
121  };
// Integer division: bit widths below 8 give 0, which has no table entry.
122  const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
123  auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124  if (it != kMap.end())
125  return it->second;
126  return std::nullopt;
127  }
128 
129  int32_t getPackedFormatBitSize() const { return 16; }
130 };
131 
136  static bool classof(const Instruction *B) {
137  return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
138  }
139  // Source :
140  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
// Returns the (widths, heights, counts) arrays offered for a 2D block
// prefetch, selected by element byte size alone. Byte sizes other than 1, 2
// and 4 yield std::nullopt.
// NOTE(review): the line carrying the function name and parameter list is not
// visible in this listing; the body reads a parameter named elemTy.
141  std::optional<
142  std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
// Candidate block heights.
144  static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
145 
// Candidate block widths.
146  static const int kWidth32[] = {32};
147  static const int kWidth16[] = {16};
148 
// Candidate block counts.
149  static const int32_t kCount1[] = {1};
150  static const int32_t kCount2[] = {1, 2};
151  // elemBytes
152  using Key = int;
153  // (widths, heights, counts)
// NOTE(review): the closing line of this alias is not visible in this listing.
154  using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
// Function-local static lookup table keyed by element byte size.
156  static const llvm::DenseMap<Key, Value> kMap = {
157  {1, {kWidth32, kHeightAtLeast1, kCount2}},
158  {2, {kWidth16, kHeightAtLeast1, kCount2}},
159  {4, {kWidth16, kHeightAtLeast1, kCount1}},
160  };
// Integer division: bit widths below 8 give 0, which has no table entry.
161  const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
162  auto it = kMap.find(elemByteSize);
163  if (it != kMap.end())
164  return it->second;
165  return std::nullopt;
166  }
167  int32_t getPackedFormatBitSize() const { return 16; }
168 };
169 
171  public MMAInstructionInterface {
172  SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
173  unsigned packedFormatBitSizeB)
176  packedFormatBitSizeA(packedFormatBitSizeA),
177  packedFormatBitSizeB(packedFormatBitSizeB) {}
178  static bool classof(const Instruction *B) {
179  return B->getInstructionKind() ==
181  }
182  // Source:
183  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
184 
185  // Override all virtuals from MMAInstructionInterface
187  getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
189  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
190  virtual bool
191  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192  std::pair<uint32_t, uint32_t> BShape,
193  std::pair<uint32_t, uint32_t> CShape,
194  std::pair<uint32_t, uint32_t> DShape, Type AType,
195  Type BType, Type CType, Type DType) override;
196  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
197  Type DType) override;
198  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199  std::pair<uint32_t, uint32_t> BShape,
200  std::pair<uint32_t, uint32_t> CShape,
201  std::pair<uint32_t, uint32_t> DShape, Type AType,
202  Type BType, Type CType, Type DType) override;
204  getSupportedM(Type type) const override;
206  getSupportedK(Type type) const override;
208  getSupportedN(Type type) const override;
209 
210  unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
211  unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
212 
213 protected:
214  const unsigned packedFormatBitSizeA;
215  const unsigned packedFormatBitSizeB;
216 };
217 
218 //===----------------------------------------------------------------------===//
219 // uArch instances
220 //===----------------------------------------------------------------------===//
221 
222 struct PVCuArch final : public Xe2Plus {
224  static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
225  static const Subgroup2DBlockLoadInstruction loadNdInst;
226  static const Subgroup2DBlockStoreInstruction storeNdInst;
227  static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
228  static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
229  &prefetchNdInst};
230  return arr;
231  }
232 
234  : Xe2Plus("pvc", // archName
235  "Ponte Vecchio Architecture", // archDescription
236  getInstructionRegistryArr(),
237  XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
238  ) {}
239  static const uArch *getInstance() {
240  static const PVCuArch instance;
241  return reinterpret_cast<const uArch *>(&instance);
242  }
243 };
244 
245 struct BMGuArch : public Xe2Plus {
247  static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
248  static const Subgroup2DBlockLoadInstruction loadNdInst;
249  static const Subgroup2DBlockStoreInstruction storeNdInst;
250  static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
251  static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
252  &prefetchNdInst};
253  return arr;
254  }
255 
257  : Xe2Plus("bmg", // archName
258  "Battlemage Architecture", // archDescription
259  getInstructionRegistryArr(),
260  XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
261  ) {}
262  static const uArch *getInstance() {
263  static const BMGuArch instance;
264  return reinterpret_cast<const uArch *>(&instance);
265  }
266 };
267 
268 inline const uArch *getUArch(llvm::StringRef archName) {
269  if (archName.equals_insensitive("pvc"))
270  return PVCuArch::getInstance();
271  else if (archName.equals_insensitive("bmg"))
272  return BMGuArch::getInstance();
273  else
274  llvm_unreachable("No matching uArch found");
275 
276  return nullptr;
277 }
278 
279 } // namespace uArch
280 } // namespace xegpu
281 } // namespace mlir
282 
283 //===----------------------------------------------------------------------===//
284 // Instruction implementations
285 //===----------------------------------------------------------------------===//
286 
289  MMAOpndKind matrixType) {
290  auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
292  -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
294  for (unsigned x : a) {
295  for (unsigned y : b) {
296  result.emplace_back(x, y);
297  }
298  }
299  return result;
300  };
301 
302  auto M = getSupportedM(dataType);
303  auto K = getSupportedK(dataType);
304  auto N = getSupportedN(dataType);
306 
307  switch (matrixType) {
309  resultMatrix = combineVectors(M, K);
310  break;
312  resultMatrix = combineVectors(K, N);
313  break;
315  resultMatrix = combineVectors(M, N);
316  break;
318  resultMatrix = combineVectors(M, N);
319  break;
320  }
321  return resultMatrix;
322 }
323 
326  MMAOpndKind matrixType) {
327  Type bf16Type = BFloat16Type::get(&context);
328  Type f16Type = Float16Type::get(&context);
329  Type tf32Type = FloatTF32Type::get(&context);
330  Type f32Type = Float32Type::get(&context);
331 
332  switch (matrixType) {
334  return {bf16Type, f16Type, tf32Type};
336  return {bf16Type, f16Type, tf32Type};
338  return {bf16Type, f16Type, f32Type};
340  return {bf16Type, f16Type, f32Type};
341  }
342  return {};
343 }
344 
346  Type BType,
347  Type CType,
348  Type DType) {
349  if (AType.isF16() || BType.isF16()) {
350  if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
351  (!DType.isF32() && !DType.isF16())) {
352  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
353  return false;
354  }
355  } else if (AType.isBF16() || BType.isBF16()) {
356  if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
357  (!DType.isF32() && !DType.isBF16())) {
358  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
359  return false;
360  }
361  } else if (AType.isTF32() || BType.isTF32()) {
362  if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
363  (!DType.isF32())) {
364  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
365  return false;
366  }
367  } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
368  AType.isInteger(8)) &&
369  !(BType.isInteger(2) || BType.isInteger(4) ||
370  BType.isInteger(8))) {
371  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
372  return false;
373  }
374 
375  return true;
376 }
377 
379  std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
380  std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
381  Type AType, Type BType, Type CType, Type DType) {
382  auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
383  auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
384  auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
385  auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
386  return llvm::is_contained(supportedAShapes, AShape) &&
387  llvm::is_contained(supportedBShapes, BShape) &&
388  llvm::is_contained(supportedCShapes, CShape) &&
389  llvm::is_contained(supportedDShapes, DShape) &&
390  checkSupportedTypes(AType, BType, CType, DType);
391 }
392 
394  std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
395  std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
396  Type AType, Type BType, Type CType, Type DType) {
397  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
398  BType, CType, DType);
399 }
400 
403  return {1, 2, 3, 4, 5, 6, 7, 8};
404 }
405 
408  // assert if data type is not int or float type
409  assert(type.isIntOrFloat() && "Matrix type must be int or float");
410  auto bitWidth = type.getIntOrFloatBitWidth();
411  uint32_t kSize = 0;
412  switch (bitWidth) {
413  case 2:
414  kSize = 64;
415  break;
416  case 4:
417  kSize = 64;
418  break;
419  case 8:
420  kSize = 32;
421  break;
422  case 16:
423  kSize = 16;
424  break;
425  case 32:
426  kSize = 8;
427  break;
428  default:
429  llvm_unreachable("Invalid int or float");
430  }
431  return {kSize};
432 }
433 
436  return {16};
437 }
438 
439 #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isTF32() const
Definition: Types.cpp:39
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
const uArch * getUArch(llvm::StringRef archName)
Definition: IntelGpuXe2.h:268
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
static const uArch * getInstance()
Definition: IntelGpuXe2.h:262
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
Definition: IntelGpuXe2.h:246
static const uArch * getInstance()
Definition: IntelGpuXe2.h:239
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
Definition: IntelGpuXe2.h:223
static bool classof(const Instruction *B)
Definition: IntelGpuXe2.h:83
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
Definition: IntelGpuXe2.h:91
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition: IntelGpuXe2.h:143
static bool classof(const Instruction *B)
Definition: IntelGpuXe2.h:52
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition: IntelGpuXe2.h:59
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
Definition: IntelGpuXe2.h:288
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
Definition: IntelGpuXe2.h:435
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
Definition: IntelGpuXe2.h:393
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
Definition: IntelGpuXe2.h:402
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
Definition: IntelGpuXe2.h:407
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
Definition: IntelGpuXe2.h:172
static bool classof(const Instruction *B)
Definition: IntelGpuXe2.h:178
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
Definition: IntelGpuXe2.h:325
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
Definition: IntelGpuXe2.h:378
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
Definition: IntelGpuXe2.h:345
unsigned getGeneralPackedFormatBitSize() const override
Definition: IntelGpuXe2.h:39
int getSubgroupSize() const override
Definition: IntelGpuXe2.h:38
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
Definition: IntelGpuXe2.h:34