MLIR 23.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11// This file defines the uArch details for Xe2 and its derived architectures:
12// Ponte Vecchio (PVC), Battlemage (BMG) and Crescent Island (CRI).
13//
14//===----------------------------------------------------------------------===//
15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
23#include <map>
24#include <string>
25
26using namespace mlir;
27using namespace mlir::xegpu::uArch;
28
29namespace mlir {
30namespace xegpu {
31namespace uArch {
32
33struct Xe2Plus : public uArch {
34 Xe2Plus(StringRef archName, StringRef archDescription,
36 const XeCoreInfo &xeCore)
37 : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
38 int getSubgroupSize() const override { return 16; }
39 unsigned getGeneralPackedFormatBitSize() const override { return 32; }
40
41protected:
43};
44
45//===----------------------------------------------------------------------===//
46// uArch instructions
47//===----------------------------------------------------------------------===//
52 static bool classof(const Instruction *B) {
53 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
54 }
55 // Source :
56 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
57 std::optional<
58 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
64 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
65 if (elemByteSize == 1)
66 return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
67 llvm::ArrayRef<int>(kHeight),
68 llvm::ArrayRef<int>(kCount));
69 else if (elemByteSize == 2 || elemByteSize == 4)
70 return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
71 llvm::ArrayRef<int>(kHeight),
72 llvm::ArrayRef<int>(kCount));
73 return std::nullopt;
74 }
75
76 int32_t getPackedFormatBitSize() const { return 16; }
77};
78
83 static bool classof(const Instruction *B) {
84 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
85 }
86
87 // Source :
88 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
89 std::optional<
90 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
91 getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
92 bool upConv = false) const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
97
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
101
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
106 // (elemBytes, transform, transpose, upConvert)
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
108 // (widths, heights, counts)
109 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
111 static const llvm::DenseMap<Key, Value> kMap = {
112 {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
116 // Block Loads with Transform:
117 {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
119 // Block Loads with Transpose:
120 {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
121 };
122 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
125 return it->second;
126 return std::nullopt;
127 }
128
129 int32_t getPackedFormatBitSize() const { return 16; }
130};
131
136 static bool classof(const Instruction *B) {
137 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
138 }
139 // Source :
140 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
141 std::optional<
142 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
145
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
148
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
151 // elemBytes
152 using Key = int;
153 // (widths, heights, counts)
154 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
156 static const llvm::DenseMap<Key, Value> kMap = {
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
160 };
161 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
164 return it->second;
165 return std::nullopt;
166 }
167 int32_t getPackedFormatBitSize() const { return 16; }
168};
169
178 static bool classof(const Instruction *B) {
179 return B->getInstructionKind() ==
181 }
182 // Source:
183 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
184
185 // Override all virtuals from MatrixOpInterface
187 getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
189 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
190 virtual bool
191 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape, Type AType,
195 Type BType, Type CType, Type DType) override;
196 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
197 Type DType) override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape, Type AType,
202 Type BType, Type CType, Type DType) override;
204 getSupportedM(Type type) const override;
206 getSupportedK(Type type) const override;
208 getSupportedN(Type type) const override;
209
212
213protected:
214 const unsigned packedFormatBitSizeA;
215 const unsigned packedFormatBitSizeB;
216};
217
219 int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
220};
221
223 int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
224};
225
226//===----------------------------------------------------------------------===//
227// uArch instances
228//===----------------------------------------------------------------------===//
229
230struct PVCuArch final : public Xe2Plus {
232 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
233 static const Subgroup2DBlockLoadInstruction loadNdInst;
234 static const Subgroup2DBlockStoreInstruction storeNdInst;
235 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
236 static const SpirvStoreScatterInstruction storeScatterInst;
237 static const SpirvLoadGatherInstruction loadGatherInst;
238 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
239 &storeNdInst, &prefetchNdInst,
240 &storeScatterInst, &loadGatherInst};
241 return arr;
242 }
243
245 : Xe2Plus("pvc", // archName
246 "Ponte Vecchio Architecture", // archDescription
248 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
249 ) {}
250 static const uArch *getInstance() {
251 static const PVCuArch instance;
252 return reinterpret_cast<const uArch *>(&instance);
253 }
254};
255
256struct BMGuArch : public Xe2Plus {
258 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
259 static const Subgroup2DBlockLoadInstruction loadNdInst;
260 static const Subgroup2DBlockStoreInstruction storeNdInst;
261 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
262 static const SpirvStoreScatterInstruction storeScatterInst;
263 static const SpirvLoadGatherInstruction loadGatherInst;
264 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
265 &storeNdInst, &prefetchNdInst,
266 &storeScatterInst, &loadGatherInst};
267 return arr;
268 }
269
271 : Xe2Plus("bmg", // archName
272 "Battlemage Architecture", // archDescription
274 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
275 ) {}
276 static const uArch *getInstance() {
277 static const BMGuArch instance;
278 return reinterpret_cast<const uArch *>(&instance);
279 }
280};
281
282struct CRIuArch : public Xe2Plus {
284 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
285 static const Subgroup2DBlockLoadInstruction loadNdInst;
286 static const Subgroup2DBlockStoreInstruction storeNdInst;
287 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
288 static const SpirvStoreScatterInstruction storeScatterInst;
289 static const SpirvLoadGatherInstruction loadGatherInst;
290 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
291 &storeNdInst, &prefetchNdInst,
292 &storeScatterInst, &loadGatherInst};
293 return arr;
294 }
295
297 : Xe2Plus("cri", // archName
298 "Crescent Island Architecture", // archDescription
300 // Using bmg config as placeholder
301 // TODO: Update to actual XeCore and SharedMemory config
302 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
303 ) {}
304 static const uArch *getInstance() {
305 static const CRIuArch instance;
306 return reinterpret_cast<const uArch *>(&instance);
307 }
308};
309
310inline const uArch *getUArch(llvm::StringRef archName) {
311 if (archName.equals_insensitive("pvc"))
312 return PVCuArch::getInstance();
313 if (archName.equals_insensitive("bmg"))
314 return BMGuArch::getInstance();
315 if (archName.equals_insensitive("cri"))
316 return CRIuArch::getInstance();
317 return nullptr;
318}
319
320} // namespace uArch
321} // namespace xegpu
322} // namespace mlir
323
324//===----------------------------------------------------------------------===//
325// Instruction implementations
326//===----------------------------------------------------------------------===//
327
330 MMAOpndKind matrixType) {
331 auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
333 -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
335 for (unsigned x : a) {
336 for (unsigned y : b) {
337 result.emplace_back(x, y);
338 }
339 }
340 return result;
341 };
342
343 auto M = getSupportedM(dataType);
344 auto K = getSupportedK(dataType);
345 auto N = getSupportedN(dataType);
347
348 switch (matrixType) {
350 resultMatrix = combineVectors(M, K);
351 break;
353 resultMatrix = combineVectors(K, N);
354 break;
356 resultMatrix = combineVectors(M, N);
357 break;
359 resultMatrix = combineVectors(M, N);
360 break;
361 }
362 return resultMatrix;
363}
364
367 MMAOpndKind matrixType) {
368 Type bf16Type = BFloat16Type::get(&context);
369 Type f16Type = Float16Type::get(&context);
370 Type tf32Type = FloatTF32Type::get(&context);
371 Type f32Type = Float32Type::get(&context);
372
373 switch (matrixType) {
375 return {bf16Type, f16Type, tf32Type};
377 return {bf16Type, f16Type, tf32Type};
379 return {bf16Type, f16Type, f32Type};
381 return {bf16Type, f16Type, f32Type};
382 }
383 return {};
384}
385
387 Type BType,
388 Type CType,
389 Type DType) {
390 if (AType.isF16() || BType.isF16()) {
391 if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
392 (!DType.isF32() && !DType.isF16())) {
393 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
394 return false;
395 }
396 } else if (AType.isBF16() || BType.isBF16()) {
397 if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
398 (!DType.isF32() && !DType.isBF16())) {
399 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
400 return false;
401 }
402 } else if (AType.isTF32() || BType.isTF32()) {
403 if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
404 (!DType.isF32())) {
405 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
406 return false;
407 }
408 } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
409 AType.isInteger(8)) &&
410 !(BType.isInteger(2) || BType.isInteger(4) ||
411 BType.isInteger(8))) {
412 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
413 return false;
414 }
415
416 return true;
417}
418
420 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
421 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
422 Type AType, Type BType, Type CType, Type DType) {
423 auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
424 auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
425 auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
426 auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
427 return llvm::is_contained(supportedAShapes, AShape) &&
428 llvm::is_contained(supportedBShapes, BShape) &&
429 llvm::is_contained(supportedCShapes, CShape) &&
430 llvm::is_contained(supportedDShapes, DShape) &&
431 checkSupportedTypes(AType, BType, CType, DType);
432}
433
435 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
436 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
437 Type AType, Type BType, Type CType, Type DType) {
438 return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
439 BType, CType, DType);
440}
441
444 return {1, 2, 3, 4, 5, 6, 7, 8};
445}
446
449 // assert if data type is not int or float type
450 assert(type.isIntOrFloat() && "Matrix type must be int or float");
451 auto bitWidth = type.getIntOrFloatBitWidth();
452 uint32_t kSize = 0;
453 switch (bitWidth) {
454 case 2:
455 kSize = 64;
456 break;
457 case 4:
458 kSize = 64;
459 break;
460 case 8:
461 kSize = 32;
462 break;
463 case 16:
464 kSize = 16;
465 break;
466 case 32:
467 kSize = 8;
468 break;
469 default:
470 llvm_unreachable("Invalid int or float");
471 }
472 return {kSize};
473}
474
477 return {16};
478}
479
480#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isTF32() const
Definition Types.cpp:39
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
static const uArch * getInstance()
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:53
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
Definition IntelGpuXe2.h:91
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:83
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:52
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition IntelGpuXe2.h:59
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
unsigned getGeneralPackedFormatBitSize() const override
Definition IntelGpuXe2.h:39
int getSubgroupSize() const override
Definition IntelGpuXe2.h:38
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
Definition IntelGpuXe2.h:34
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:178
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:151