MLIR  22.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1 //===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11 // This file defines the uArch details for Xe2 and its derived architectures.
12 // This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13 //
14 //===----------------------------------------------------------------------===//
15 #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16 #define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17 
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/DebugLog.h"
23 #include <map>
24 #include <string>
25 
26 #define DEBUG_TYPE "xegpu-uarch"
27 
28 using namespace mlir;
29 using namespace mlir::xegpu::uArch;
30 
31 namespace mlir {
32 namespace xegpu {
33 namespace uArch {
34 
35 struct Xe2Plus : public uArch {
37  Xe2Plus(const std::string &archName, const std::string &archDescription,
38  const XeCoreInfo &xeCore,
39  const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
40  const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
41  const std::map<InstructionKind, std::shared_ptr<Instruction>>
42  &instrs = {})
43  : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
44  xeCore(xeCore) {}
45 };
46 
47 // struct to represent DPAS instruction
51 
52  // Override all virtuals from MatrixOpInterface
54  getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
56  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
57  virtual bool
58  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
59  std::pair<uint32_t, uint32_t> BShape,
60  std::pair<uint32_t, uint32_t> CShape,
61  std::pair<uint32_t, uint32_t> DShape, Type AType,
62  Type BType, Type CType, Type DType) override;
63  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
64  Type DType) override;
65  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
66  std::pair<uint32_t, uint32_t> BShape,
67  std::pair<uint32_t, uint32_t> CShape,
68  std::pair<uint32_t, uint32_t> DShape, Type AType,
69  Type BType, Type CType, Type DType) override;
70  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
71  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
72  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
73 };
74 
75 struct PVCuArch : public Xe2Plus {
76  // Maintains ownership of the instructions owned by PVCuArch
79  : Xe2Plus("pvc", // archName
80  "Ponte Vecchio Architecture", // archDescription
81  XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
82  {/* registerFileInfo */}, // Optional: empty
83  {/* cacheInfo */}, // Optional: empty
84  {/* instructions */} // Optional: empty
85  ) {
86  // Initialize register file info
87  // GRF
88  this->registerFileInfo.emplace(
91  64 * 1024, // size in bits
93  {128, 256} // registers per thread per mode
94  ));
95  // Initialize cache info
96  // L1 cache, XeCore level
97  this->cacheInfo.push_back(
98  CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
99  // L2 cache, XeStack level
100  this->cacheInfo.push_back(
101  CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
102 
103  // Add the instructions
104  auto dpas = std::make_shared<DPASInstruction>();
105  instructions.emplace(dpas->getInstructionKind(), dpas);
106  owned_instructions.push_back(dpas);
107  }
108 };
109 
110 struct BMGuArch : public Xe2Plus {
111  // Maintains ownership of the instructions owned by BMGuArch
114  : Xe2Plus("bmg", // archName
115  "Battlemage Architecture", // archDescription
116  XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
117  {/* registerFileInfo */}, // Optional: empty
118  {/* cacheInfo */}, // Optional: empty
119  {/* instructions */} // Optional: empty
120  ) {
121  // Initialize register file info
122  // GRF
123  this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
124  64 * 1024, // size in bits
126  {128, 256} // registers per thread per mode
127  );
128  // Initialize cache info
129  // L1 cache, XeCore level
130  this->cacheInfo.push_back(
131  CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
132  // L2 cache, XeStack level
133  this->cacheInfo.push_back(
134  CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
135 
136  // Add the instructions
137  auto dpas = std::make_shared<DPASInstruction>();
138  instructions.emplace(dpas->getInstructionKind(), dpas);
139  owned_instructions.push_back(dpas);
140  }
141 };
142 } // namespace uArch
143 } // namespace xegpu
144 } // namespace mlir
145 
148  auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
150  -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
152  for (unsigned x : a) {
153  for (unsigned y : b) {
154  result.emplace_back(x, y);
155  }
156  }
157  return result;
158  };
159 
160  auto M = getSupportedM(dataType);
161  auto K = getSupportedK(dataType);
162  auto N = getSupportedN(dataType);
164 
165  switch (matrixType) {
167  resultMatrix = combineVectors(M, K);
168  break;
170  resultMatrix = combineVectors(K, N);
171  break;
173  resultMatrix = combineVectors(M, N);
174  break;
176  resultMatrix = combineVectors(M, N);
177  break;
178  }
179  return resultMatrix;
180 }
181 
184  MMAOpndKind matrixType) {
185  Type bf16Type = BFloat16Type::get(&context);
186  Type f16Type = Float16Type::get(&context);
187  Type tf32Type = FloatTF32Type::get(&context);
188  Type f32Type = Float32Type::get(&context);
189 
190  switch (matrixType) {
192  return {bf16Type, f16Type, tf32Type};
194  return {bf16Type, f16Type, tf32Type};
196  return {bf16Type, f16Type, f32Type};
198  return {bf16Type, f16Type, f32Type};
199  }
200  return {};
201 }
202 
204  Type CType, Type DType) {
205  if (AType.isF16() || BType.isF16()) {
206  if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
207  (!DType.isF32() && !DType.isF16())) {
208  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
209  return false;
210  }
211  } else if (AType.isBF16() || BType.isBF16()) {
212  if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
213  (!DType.isF32() && !DType.isBF16())) {
214  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
215  return false;
216  }
217  } else if (AType.isTF32() || BType.isTF32()) {
218  if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
219  (!DType.isF32())) {
220  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
221  return false;
222  }
223  } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
224  AType.isInteger(8)) &&
225  !(BType.isInteger(2) || BType.isInteger(4) ||
226  BType.isInteger(8))) {
227  LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
228  return false;
229  }
230 
231  return true;
232 }
233 
235  std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
236  std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
237  Type AType, Type BType, Type CType, Type DType) {
238  auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
239  auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
240  auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
241  auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
242  return llvm::is_contained(supportedAShapes, AShape) &&
243  llvm::is_contained(supportedBShapes, BShape) &&
244  llvm::is_contained(supportedCShapes, CShape) &&
245  llvm::is_contained(supportedDShapes, DShape) &&
246  checkSupportedTypes(AType, BType, CType, DType);
247 }
248 
249 inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
250  std::pair<uint32_t, uint32_t> BShape,
251  std::pair<uint32_t, uint32_t> CShape,
252  std::pair<uint32_t, uint32_t> DShape,
253  Type AType, Type BType, Type CType,
254  Type DType) {
255  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
256  BType, CType, DType);
257 }
258 
261  return {1, 2, 3, 4, 5, 6, 7, 8};
262 }
263 
266  // assert if data type is not int or float type
267  assert(type.isIntOrFloat() && "Matrix type must be int or float");
268  auto bitWidth = type.getIntOrFloatBitWidth();
269  uint32_t kSize = 0;
270  switch (bitWidth) {
271  case 2:
272  kSize = 64;
273  break;
274  case 4:
275  kSize = 64;
276  break;
277  case 8:
278  kSize = 32;
279  break;
280  case 16:
281  kSize = 16;
282  break;
283  case 32:
284  kSize = 8;
285  break;
286  default:
287  llvm_unreachable("Invalid int or float");
288  }
289  return {kSize};
290 }
291 
294  return {16};
295 }
296 
297 #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isTF32() const
Definition: Types.cpp:39
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallVector< std::shared_ptr< Instruction >, 8 > owned_instructions
Definition: IntelGpuXe2.h:112
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) override
Definition: IntelGpuXe2.h:293
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
Definition: IntelGpuXe2.h:183
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) override
Definition: IntelGpuXe2.h:260
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
Definition: IntelGpuXe2.h:203
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
Definition: IntelGpuXe2.h:234
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
Definition: IntelGpuXe2.h:249
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) override
Definition: IntelGpuXe2.h:265
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
Definition: IntelGpuXe2.h:147
llvm::SmallVector< std::shared_ptr< Instruction >, 8 > owned_instructions
Definition: IntelGpuXe2.h:77
Xe2Plus(const std::string &archName, const std::string &archDescription, const XeCoreInfo &xeCore, const std::map< RegisterFileType, RegisterFileInfo > &regInfo={}, const llvm::SmallVector< CacheInfo, 4 > &cacheInfo={}, const std::map< InstructionKind, std::shared_ptr< Instruction >> &instrs={})
Definition: IntelGpuXe2.h:37