MLIR 23.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11// This file defines the uArch details for Xe2 and its derived architectures.
12// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13//
14//===----------------------------------------------------------------------===//
15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
23#include <map>
24#include <string>
25
26using namespace mlir;
27using namespace mlir::xegpu::uArch;
28
29namespace mlir {
30namespace xegpu {
31namespace uArch {
32
33struct Xe2Plus : public uArch {
34 Xe2Plus(StringRef archName, StringRef archDescription,
36 const XeCoreInfo &xeCore)
37 : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
38 int getSubgroupSize() const override { return 16; }
39 unsigned getGeneralPackedFormatBitSize() const override { return 32; }
40
41protected:
43};
44
45//===----------------------------------------------------------------------===//
46// uArch instructions
47//===----------------------------------------------------------------------===//
52 static bool classof(const Instruction *B) {
53 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
54 }
55 // Source :
56 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
57 std::optional<
58 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
64 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
65 if (elemByteSize == 1)
66 return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
67 llvm::ArrayRef<int>(kHeight),
68 llvm::ArrayRef<int>(kCount));
69 else if (elemByteSize == 2 || elemByteSize == 4)
70 return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
71 llvm::ArrayRef<int>(kHeight),
72 llvm::ArrayRef<int>(kCount));
73 return std::nullopt;
74 }
75
76 int32_t getPackedFormatBitSize() const { return 16; }
77};
78
83 static bool classof(const Instruction *B) {
84 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
85 }
86
87 // Source :
88 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
89 std::optional<
90 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
91 getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
92 bool upConv = false) const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
97
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
101
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
106 // (elemBytes, transform, transpose, upConvert)
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
108 // (widths, heights, counts)
109 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
111 static const llvm::DenseMap<Key, Value> kMap = {
112 {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
116 // Block Loads with Transform:
117 {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
119 // Block Loads with Transpose:
120 {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
121 };
122 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
125 return it->second;
126 return std::nullopt;
127 }
128
129 int32_t getPackedFormatBitSize() const { return 16; }
130};
131
136 static bool classof(const Instruction *B) {
137 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
138 }
139 // Source :
140 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
141 std::optional<
142 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
145
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
148
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
151 // elemBytes
152 using Key = int;
153 // (widths, heights, counts)
154 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
156 static const llvm::DenseMap<Key, Value> kMap = {
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
160 };
161 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
164 return it->second;
165 return std::nullopt;
166 }
167 int32_t getPackedFormatBitSize() const { return 16; }
168};
169
178 static bool classof(const Instruction *B) {
179 return B->getInstructionKind() ==
181 }
182 // Source:
183 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
184
185 // Override all virtuals from MatrixOpInterface
187 getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
189 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
190 virtual bool
191 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape, Type AType,
195 Type BType, Type CType, Type DType) override;
196 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
197 Type DType) override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape, Type AType,
202 Type BType, Type CType, Type DType) override;
204 getSupportedM(Type type) const override;
206 getSupportedK(Type type) const override;
208 getSupportedN(Type type) const override;
209
212
213protected:
214 const unsigned packedFormatBitSizeA;
215 const unsigned packedFormatBitSizeB;
216};
217
221 static bool classof(const Instruction *B) {
222 return B->getInstructionKind() == InstructionKind::StoreScatter;
223 }
224
225 // SPIRV restricts vector size
226 int32_t getMaxLaneLoadStoreSize() const { return 16; }
227};
228
232 static bool classof(const Instruction *B) {
233 return B->getInstructionKind() == InstructionKind::LoadGather;
234 }
235
236 // SPIRV restricts vector size
237 int32_t getMaxLaneLoadStoreSize() const { return 16; }
238};
239
240//===----------------------------------------------------------------------===//
241// uArch instances
242//===----------------------------------------------------------------------===//
243
244struct PVCuArch final : public Xe2Plus {
246 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
247 static const Subgroup2DBlockLoadInstruction loadNdInst;
248 static const Subgroup2DBlockStoreInstruction storeNdInst;
249 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
250 static const StoreScatterInstruction storeScatterInst;
251 static const LoadGatherInstruction loadGatherInst;
252 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
253 &storeNdInst, &prefetchNdInst,
254 &storeScatterInst, &loadGatherInst};
255 return arr;
256 }
257
259 : Xe2Plus("pvc", // archName
260 "Ponte Vecchio Architecture", // archDescription
262 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
263 ) {}
264 static const uArch *getInstance() {
265 static const PVCuArch instance;
266 return reinterpret_cast<const uArch *>(&instance);
267 }
268};
269
270struct BMGuArch : public Xe2Plus {
272 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
273 static const Subgroup2DBlockLoadInstruction loadNdInst;
274 static const Subgroup2DBlockStoreInstruction storeNdInst;
275 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
276 static const StoreScatterInstruction storeScatterInst;
277 static const LoadGatherInstruction loadGatherInst;
278 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
279 &storeNdInst, &prefetchNdInst,
280 &storeScatterInst, &loadGatherInst};
281 return arr;
282 }
283
285 : Xe2Plus("bmg", // archName
286 "Battlemage Architecture", // archDescription
288 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
289 ) {}
290 static const uArch *getInstance() {
291 static const BMGuArch instance;
292 return reinterpret_cast<const uArch *>(&instance);
293 }
294};
295
296inline const uArch *getUArch(llvm::StringRef archName) {
297 if (archName.equals_insensitive("pvc"))
298 return PVCuArch::getInstance();
299 else if (archName.equals_insensitive("bmg"))
300 return BMGuArch::getInstance();
301 else
302 llvm_unreachable("No matching uArch found");
303
304 return nullptr;
305}
306
307} // namespace uArch
308} // namespace xegpu
309} // namespace mlir
310
311//===----------------------------------------------------------------------===//
312// Instruction implementations
313//===----------------------------------------------------------------------===//
314
317 MMAOpndKind matrixType) {
318 auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
320 -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
322 for (unsigned x : a) {
323 for (unsigned y : b) {
324 result.emplace_back(x, y);
325 }
326 }
327 return result;
328 };
329
330 auto M = getSupportedM(dataType);
331 auto K = getSupportedK(dataType);
332 auto N = getSupportedN(dataType);
334
335 switch (matrixType) {
337 resultMatrix = combineVectors(M, K);
338 break;
340 resultMatrix = combineVectors(K, N);
341 break;
343 resultMatrix = combineVectors(M, N);
344 break;
346 resultMatrix = combineVectors(M, N);
347 break;
348 }
349 return resultMatrix;
350}
351
354 MMAOpndKind matrixType) {
355 Type bf16Type = BFloat16Type::get(&context);
356 Type f16Type = Float16Type::get(&context);
357 Type tf32Type = FloatTF32Type::get(&context);
358 Type f32Type = Float32Type::get(&context);
359
360 switch (matrixType) {
362 return {bf16Type, f16Type, tf32Type};
364 return {bf16Type, f16Type, tf32Type};
366 return {bf16Type, f16Type, f32Type};
368 return {bf16Type, f16Type, f32Type};
369 }
370 return {};
371}
372
374 Type BType,
375 Type CType,
376 Type DType) {
377 if (AType.isF16() || BType.isF16()) {
378 if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
379 (!DType.isF32() && !DType.isF16())) {
380 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
381 return false;
382 }
383 } else if (AType.isBF16() || BType.isBF16()) {
384 if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
385 (!DType.isF32() && !DType.isBF16())) {
386 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
387 return false;
388 }
389 } else if (AType.isTF32() || BType.isTF32()) {
390 if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
391 (!DType.isF32())) {
392 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
393 return false;
394 }
395 } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
396 AType.isInteger(8)) &&
397 !(BType.isInteger(2) || BType.isInteger(4) ||
398 BType.isInteger(8))) {
399 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
400 return false;
401 }
402
403 return true;
404}
405
407 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
408 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
409 Type AType, Type BType, Type CType, Type DType) {
410 auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
411 auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
412 auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
413 auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
414 return llvm::is_contained(supportedAShapes, AShape) &&
415 llvm::is_contained(supportedBShapes, BShape) &&
416 llvm::is_contained(supportedCShapes, CShape) &&
417 llvm::is_contained(supportedDShapes, DShape) &&
418 checkSupportedTypes(AType, BType, CType, DType);
419}
420
422 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
423 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
424 Type AType, Type BType, Type CType, Type DType) {
425 return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
426 BType, CType, DType);
427}
428
431 return {1, 2, 3, 4, 5, 6, 7, 8};
432}
433
436 // assert if data type is not int or float type
437 assert(type.isIntOrFloat() && "Matrix type must be int or float");
438 auto bitWidth = type.getIntOrFloatBitWidth();
439 uint32_t kSize = 0;
440 switch (bitWidth) {
441 case 2:
442 kSize = 64;
443 break;
444 case 4:
445 kSize = 64;
446 break;
447 case 8:
448 kSize = 32;
449 break;
450 case 16:
451 kSize = 16;
452 break;
453 case 32:
454 kSize = 8;
455 break;
456 default:
457 llvm_unreachable("Invalid int or float");
458 }
459 return {kSize};
460}
461
464 return {16};
465}
466
467#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isTF32() const
Definition Types.cpp:39
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:53
static bool classof(const Instruction *B)
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
static bool classof(const Instruction *B)
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
Definition IntelGpuXe2.h:91
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:83
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:52
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition IntelGpuXe2.h:59
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
unsigned getGeneralPackedFormatBitSize() const override
Definition IntelGpuXe2.h:39
int getSubgroupSize() const override
Definition IntelGpuXe2.h:38
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
Definition IntelGpuXe2.h:34
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:178
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:151