MLIR 23.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11// This file defines the uArch details for Xe2 and its derived architectures.
12// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13//
14//===----------------------------------------------------------------------===//
15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
23#include <map>
24#include <string>
25
26using namespace mlir;
27using namespace mlir::xegpu::uArch;
28
29namespace mlir {
30namespace xegpu {
31namespace uArch {
32
33struct Xe2Plus : public uArch {
34 Xe2Plus(StringRef archName, StringRef archDescription,
36 const XeCoreInfo &xeCore)
37 : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
38 int getSubgroupSize() const override { return 16; }
39 unsigned getGeneralPackedFormatBitSize() const override { return 32; }
40
41protected:
43};
44
45//===----------------------------------------------------------------------===//
46// uArch instructions
47//===----------------------------------------------------------------------===//
// Visible part of the 2D block store instruction descriptor: LLVM-style RTTI
// hook, the table of supported block shapes, and the packed-format bit size.
// classof() matches instructions whose kind is Subgroup2DBlockStore.
52 static bool classof(const Instruction *B) {
53 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
54 }
55 // Source :
56 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
// Returns (widths, heights, counts) supported for the given element type, or
// std::nullopt when the element byte size is not 1, 2 or 4.
57 std::optional<
58 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
// The arrays are function-local statics, so the ArrayRefs returned below stay
// valid after the call returns.
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
// NOTE(review): the name says 32 but the value is 16, which makes the 1-byte
// branch below return the same widths as the 2/4-byte branch. The block-load
// table in this file uses {32} for 1-byte elements. Looks like a copy-paste
// slip -- confirm against the extension spec linked above before changing.
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
// Integer division: element types narrower than 8 bits yield 0 here and fall
// through to std::nullopt.
64 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
65 if (elemByteSize == 1)
66 return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
67 llvm::ArrayRef<int>(kHeight),
68 llvm::ArrayRef<int>(kCount));
69 else if (elemByteSize == 2 || elemByteSize == 4)
70 return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
71 llvm::ArrayRef<int>(kHeight),
72 llvm::ArrayRef<int>(kCount));
73 return std::nullopt;
74 }
75
// Always reports 16 for this instruction.
76 int32_t getPackedFormatBitSize() const { return 16; }
77};
78
83 static bool classof(const Instruction *B) {
84 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
85 }
86
87 // Source :
88 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
89 std::optional<
90 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
91 getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
92 bool upConv = false) const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
97
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
101
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
106 // (elemBytes, transform, transpose, upConvert)
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
108 // (widths, heights, counts)
109 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
111 static const llvm::DenseMap<Key, Value> kMap = {
112 {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
116 // Block Loads with Transform:
117 {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
119 // Block Loads with Transpose:
120 {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
121 };
122 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
125 return it->second;
126 return std::nullopt;
127 }
128
129 int32_t getPackedFormatBitSize() const { return 16; }
130};
131
136 static bool classof(const Instruction *B) {
137 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
138 }
139 // Source :
140 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
141 std::optional<
142 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
145
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
148
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
151 // elemBytes
152 using Key = int;
153 // (widths, heights, counts)
154 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
156 static const llvm::DenseMap<Key, Value> kMap = {
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
160 };
161 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
164 return it->second;
165 return std::nullopt;
166 }
167 int32_t getPackedFormatBitSize() const { return 16; }
168};
169
178 static bool classof(const Instruction *B) {
179 return B->getInstructionKind() ==
181 }
182 // Source:
183 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
184
185 // Override all virtuals from MatrixOpInterface
187 getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
189 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
190 virtual bool
191 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape, Type AType,
195 Type BType, Type CType, Type DType) override;
196 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
197 Type DType) override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape, Type AType,
202 Type BType, Type CType, Type DType) override;
204 getSupportedM(Type type) const override;
206 getSupportedK(Type type) const override;
208 getSupportedN(Type type) const override;
209
212
213protected:
214 const unsigned packedFormatBitSizeA;
215 const unsigned packedFormatBitSizeB;
216};
217
219 int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
220};
221
223 int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
224};
225
227 int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
228};
229
231 int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
232};
233
234//===----------------------------------------------------------------------===//
235// uArch instances
236//===----------------------------------------------------------------------===//
237
238struct PVCuArch final : public Xe2Plus {
240 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
241 static const Subgroup2DBlockLoadInstruction loadNdInst;
242 static const Subgroup2DBlockStoreInstruction storeNdInst;
243 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
244 static const SpirvStoreScatterInstruction storeScatterInst;
245 static const SpirvLoadGatherInstruction loadGatherInst;
246 static const StoreMatrixInstruction storeMatrixInst;
247 static const LoadMatrixInstruction loadMatrixInst;
248 static const Instruction *arr[] = {
249 &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
250 &storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
251 return arr;
252 }
253
255 : Xe2Plus("pvc", // archName
256 "Ponte Vecchio Architecture", // archDescription
258 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
259 ) {}
260 static const uArch *getInstance() {
261 static const PVCuArch instance;
262 return reinterpret_cast<const uArch *>(&instance);
263 }
264};
265
266struct BMGuArch : public Xe2Plus {
268 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
269 static const Subgroup2DBlockLoadInstruction loadNdInst;
270 static const Subgroup2DBlockStoreInstruction storeNdInst;
271 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
272 static const SpirvStoreScatterInstruction storeScatterInst;
273 static const SpirvLoadGatherInstruction loadGatherInst;
274 static const StoreMatrixInstruction storeMatrixInst;
275 static const LoadMatrixInstruction loadMatrixInst;
276 static const Instruction *arr[] = {
277 &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
278 &storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
279 return arr;
280 }
281
283 : Xe2Plus("bmg", // archName
284 "Battlemage Architecture", // archDescription
286 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
287 ) {}
288 static const uArch *getInstance() {
289 static const BMGuArch instance;
290 return reinterpret_cast<const uArch *>(&instance);
291 }
292};
293
294inline const uArch *getUArch(llvm::StringRef archName) {
295 if (archName.equals_insensitive("pvc"))
296 return PVCuArch::getInstance();
297 else if (archName.equals_insensitive("bmg"))
298 return BMGuArch::getInstance();
299 else
300 llvm_unreachable("No matching uArch found");
301
302 return nullptr;
303}
304
305} // namespace uArch
306} // namespace xegpu
307} // namespace mlir
308
309//===----------------------------------------------------------------------===//
310// Instruction implementations
311//===----------------------------------------------------------------------===//
312
315 MMAOpndKind matrixType) {
316 auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
318 -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
320 for (unsigned x : a) {
321 for (unsigned y : b) {
322 result.emplace_back(x, y);
323 }
324 }
325 return result;
326 };
327
328 auto M = getSupportedM(dataType);
329 auto K = getSupportedK(dataType);
330 auto N = getSupportedN(dataType);
332
333 switch (matrixType) {
335 resultMatrix = combineVectors(M, K);
336 break;
338 resultMatrix = combineVectors(K, N);
339 break;
341 resultMatrix = combineVectors(M, N);
342 break;
344 resultMatrix = combineVectors(M, N);
345 break;
346 }
347 return resultMatrix;
348}
349
352 MMAOpndKind matrixType) {
353 Type bf16Type = BFloat16Type::get(&context);
354 Type f16Type = Float16Type::get(&context);
355 Type tf32Type = FloatTF32Type::get(&context);
356 Type f32Type = Float32Type::get(&context);
357
358 switch (matrixType) {
360 return {bf16Type, f16Type, tf32Type};
362 return {bf16Type, f16Type, tf32Type};
364 return {bf16Type, f16Type, f32Type};
366 return {bf16Type, f16Type, f32Type};
367 }
368 return {};
369}
370
372 Type BType,
373 Type CType,
374 Type DType) {
375 if (AType.isF16() || BType.isF16()) {
376 if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
377 (!DType.isF32() && !DType.isF16())) {
378 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
379 return false;
380 }
381 } else if (AType.isBF16() || BType.isBF16()) {
382 if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
383 (!DType.isF32() && !DType.isBF16())) {
384 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
385 return false;
386 }
387 } else if (AType.isTF32() || BType.isTF32()) {
388 if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
389 (!DType.isF32())) {
390 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
391 return false;
392 }
393 } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
394 AType.isInteger(8)) &&
395 !(BType.isInteger(2) || BType.isInteger(4) ||
396 BType.isInteger(8))) {
397 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
398 return false;
399 }
400
401 return true;
402}
403
405 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
406 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
407 Type AType, Type BType, Type CType, Type DType) {
408 auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
409 auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
410 auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
411 auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
412 return llvm::is_contained(supportedAShapes, AShape) &&
413 llvm::is_contained(supportedBShapes, BShape) &&
414 llvm::is_contained(supportedCShapes, CShape) &&
415 llvm::is_contained(supportedDShapes, DShape) &&
416 checkSupportedTypes(AType, BType, CType, DType);
417}
418
420 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
421 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
422 Type AType, Type BType, Type CType, Type DType) {
423 return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
424 BType, CType, DType);
425}
426
429 return {1, 2, 3, 4, 5, 6, 7, 8};
430}
431
434 // assert if data type is not int or float type
435 assert(type.isIntOrFloat() && "Matrix type must be int or float");
436 auto bitWidth = type.getIntOrFloatBitWidth();
437 uint32_t kSize = 0;
438 switch (bitWidth) {
439 case 2:
440 kSize = 64;
441 break;
442 case 4:
443 kSize = 64;
444 break;
445 case 8:
446 kSize = 32;
447 break;
448 case 16:
449 kSize = 16;
450 break;
451 case 32:
452 kSize = 8;
453 break;
454 default:
455 llvm_unreachable("Invalid int or float");
456 }
457 return {kSize};
458}
459
462 return {16};
463}
464
465#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isTF32() const
Definition Types.cpp:39
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:55
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
Definition IntelGpuXe2.h:91
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:83
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:52
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition IntelGpuXe2.h:59
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
unsigned getGeneralPackedFormatBitSize() const override
Definition IntelGpuXe2.h:39
int getSubgroupSize() const override
Definition IntelGpuXe2.h:38
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
Definition IntelGpuXe2.h:34
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:184
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:157