MLIR 22.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11// This file defines the uArch details for Xe2 and its derived architectures.
12// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13//
14//===----------------------------------------------------------------------===//
15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
23#include <map>
24#include <string>
25
26using namespace mlir;
27using namespace mlir::xegpu::uArch;
28
29namespace mlir {
30namespace xegpu {
31namespace uArch {
32
// Common base of the uArch descriptions defined in this file (PVCuArch and
// BMGuArch both derive from it).
// NOTE(review): source lines 35 and 42 are missing from this listing, so one
// constructor parameter and the protected member declaration are not visible;
// the initializer list below refers to `instructionRegistry` and `xeCore`.
33struct Xe2Plus : public uArch {
// Forwards name, description and instruction registry to the uArch base and
// initializes the `xeCore` member from the argument of the same name.
34 Xe2Plus(StringRef archName, StringRef archDescription,
36 const XeCoreInfo &xeCore)
37 : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
// Always returns 16.
38 int getSubgroupSize() const override { return 16; }
// Always returns 32.
39 unsigned getGeneralPackedFormatBitSize() const override { return 32; }
40
41protected:
43};
44
45//===----------------------------------------------------------------------===//
46// uArch instructions
47//===----------------------------------------------------------------------===//
// Returns true when B's instruction kind is Subgroup2DBlockStore.
52 static bool classof(const Instruction *B) {
53 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
54 }
// Supported block shapes for a 2D block store of elements of type elemTy.
// Returns (widths, heights, counts), or std::nullopt when the element size is
// not 1, 2 or 4 bytes.
// NOTE(review): source line 59 (the line carrying the function name and
// parameter list) is missing from this listing.
55 // Source :
56 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
57 std::optional<
58 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
// Width used for 1-byte elements. It must be 32, matching both the array name
// and the 1-byte rows of the load and prefetch tables in this file; with 16
// the 1-byte branch below was indistinguishable from the 2/4-byte branch.
62 const static int kWidth32[] = {32};
63 const static int kCount[] = {1};
// Integer division: element types narrower than 8 bits yield 0 and fall
// through to std::nullopt.
64 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
65 if (elemByteSize == 1)
66 return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
67 llvm::ArrayRef<int>(kHeight),
68 llvm::ArrayRef<int>(kCount));
69 else if (elemByteSize == 2 || elemByteSize == 4)
70 return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
71 llvm::ArrayRef<int>(kHeight),
72 llvm::ArrayRef<int>(kCount));
73 return std::nullopt;
74 }
75
// Always returns 16.
76 int32_t getPackedFormatBitSize() const { return 16; }
77};
78
// Returns true when B's instruction kind is Subgroup2DBlockLoad.
83 static bool classof(const Instruction *B) {
84 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
85 }
86
// Supported block shapes for a 2D block load of elements of type elemTy under
// the given transform / transpose / up-convert flags.
// Returns (widths, heights, counts), or std::nullopt when the combination of
// element byte size and flags has no entry in the table below.
// NOTE(review): source line 110 (the end of the `Value` alias) is missing from
// this listing.
87 // Source :
88 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
89 std::optional<
90 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
91 getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
92 bool upConv = false) const {
// Candidate height lists referenced by the table rows.
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
97
// Candidate width lists (one entry each).
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
101
// Candidate block-count lists.
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
106 // (elemBytes, transform, transpose, upConvert)
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
108 // (widths, heights, counts)
109 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
111 static const llvm::DenseMap<Key, Value> kMap = {
112 {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
116 // Block Loads with Transform:
117 {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
119 // Block Loads with Transpose:
120 {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
121 };
// Integer division: element types narrower than 8 bits yield 0, which has no
// table entry and therefore produces std::nullopt.
122 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
125 return it->second;
126 return std::nullopt;
127 }
128
// Always returns 16.
129 int32_t getPackedFormatBitSize() const { return 16; }
130};
131
// Returns true when B's instruction kind is Subgroup2DBlockPrefetch.
136 static bool classof(const Instruction *B) {
137 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
138 }
// Supported block shapes for a 2D block prefetch of elements of type elemTy.
// Returns (widths, heights, counts), or std::nullopt when the element byte
// size is not a key of the table below (keys: 1, 2, 4).
// NOTE(review): source lines 143 (function name and parameter list) and 155
// (the end of the `Value` alias) are missing from this listing.
139 // Source :
140 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
141 std::optional<
142 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
145
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
148
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
151 // elemBytes
152 using Key = int;
153 // (widths, heights, counts)
154 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
156 static const llvm::DenseMap<Key, Value> kMap = {
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
160 };
// Integer division: element types narrower than 8 bits yield 0, which is not
// a key of the table.
161 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
164 return it->second;
165 return std::nullopt;
166 }
// Always returns 16.
167 int32_t getPackedFormatBitSize() const { return 16; }
168};
169
// Compares B's instruction kind against a fixed InstructionKind value.
// NOTE(review): source line 180 (the right-hand side of the comparison) is
// missing from this listing.
178 static bool classof(const Instruction *B) {
179 return B->getInstructionKind() ==
181 }
// Member declarations of the matrix multiply-accumulate instruction
// description; the definitions appear after the namespaces are closed.
// NOTE(review): the return-type lines of getSupportedShapes,
// getSupportedTypes and getSupportedM/K/N (source lines 186, 188, 203, 205,
// 207) and source lines 210-211 are missing from this listing.
182 // Source:
183 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
184
185 // Override all virtuals from MatrixOpInterface
187 getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
189 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
// Combined shape and element-type check for the A, B, C and D operands.
190 virtual bool
191 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape, Type AType,
195 Type BType, Type CType, Type DType) override;
// Element-type check for the A, B, C and D operands.
196 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
197 Type DType) override;
// Takes the same arguments as checkSupportedShapesAndTypes.
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape, Type AType,
202 Type BType, Type CType, Type DType) override;
204 getSupportedM(Type type) const override;
206 getSupportedK(Type type) const override;
208 getSupportedN(Type type) const override;
209
212
213protected:
// Packed-format bit sizes for the A and B operands; const members.
214 const unsigned packedFormatBitSizeA;
215 const unsigned packedFormatBitSizeB;
216};
217
218//===----------------------------------------------------------------------===//
219// uArch instances
220//===----------------------------------------------------------------------===//
221
// uArch description registered under the name "pvc".
// NOTE(review): source lines 223 (header of the function that returns `arr`),
// 233 (constructor header) and 236 (one constructor argument) are missing from
// this listing.
222struct PVCuArch final : public Xe2Plus {
// Function-local static instruction descriptors and the array of pointers to
// them that is returned to the caller.
224 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
225 static const Subgroup2DBlockLoadInstruction loadNdInst;
226 static const Subgroup2DBlockStoreInstruction storeNdInst;
227 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
228 static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
229 &prefetchNdInst};
230 return arr;
231 }
232
234 : Xe2Plus("pvc", // archName
235 "Ponte Vecchio Architecture", // archDescription
237 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
238 ) {}
// Returns a pointer to a function-local static PVCuArch instance, viewed as
// its uArch base.
239 static const uArch *getInstance() {
240 static const PVCuArch instance;
// Derived-to-base conversion is implicit and applies any required pointer
// adjustment; reinterpret_cast does neither and is not needed here.
241 return &instance;
242 }
243};
244
// uArch description registered under the name "bmg".
// NOTE(review): source lines 246 (header of the function that returns `arr`),
// 256 (constructor header) and 259 (one constructor argument) are missing from
// this listing.
245struct BMGuArch : public Xe2Plus {
// Function-local static instruction descriptors and the array of pointers to
// them that is returned to the caller.
247 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
248 static const Subgroup2DBlockLoadInstruction loadNdInst;
249 static const Subgroup2DBlockStoreInstruction storeNdInst;
250 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
251 static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
252 &prefetchNdInst};
253 return arr;
254 }
255
257 : Xe2Plus("bmg", // archName
258 "Battlemage Architecture", // archDescription
260 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
261 ) {}
// Returns a pointer to a function-local static BMGuArch instance, viewed as
// its uArch base.
262 static const uArch *getInstance() {
263 static const BMGuArch instance;
// Derived-to-base conversion is implicit and applies any required pointer
// adjustment; reinterpret_cast does neither and is not needed here.
264 return &instance;
265 }
266};
267
// Maps an architecture name to its uArch instance. The comparison is
// case-insensitive; "pvc" and "bmg" are the recognized names.
// Returns nullptr when the name matches no known architecture, so callers
// must check the result before dereferencing it.
268inline const uArch *getUArch(llvm::StringRef archName) {
269 if (archName.equals_insensitive("pvc"))
270 return PVCuArch::getInstance();
271 else if (archName.equals_insensitive("bmg"))
272 return BMGuArch::getInstance();
275
// An unrecognized name is bad input rather than a broken invariant, so it is
// reported through the return value. llvm_unreachable here would be undefined
// behaviour in builds without assertions.
276 return nullptr;
277}
278
279} // namespace uArch
280} // namespace xegpu
281} // namespace mlir
282
283//===----------------------------------------------------------------------===//
284// Instruction implementations
285//===----------------------------------------------------------------------===//
286
289 MMAOpndKind matrixType) {
// Returns the list of (dim0, dim1) shape pairs for the requested operand kind,
// built from the lists returned by getSupportedM/K/N for dataType.
// NOTE(review): source lines 287-288 (function header), 291 and 293 (second
// lambda parameter and `result` declaration), 305 (`resultMatrix`
// declaration) and 308/311/314/317 (case labels) are missing from this
// listing.
//
// combineVectors pairs every element of `a` with every element of `b`, with
// `a` as the outer loop.
290 auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
292 -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
294 for (unsigned x : a) {
295 for (unsigned y : b) {
296 result.emplace_back(x, y);
297 }
298 }
299 return result;
300 };
301
302 auto M = getSupportedM(dataType);
303 auto K = getSupportedK(dataType);
304 auto N = getSupportedN(dataType);
306
// The four visible cases pair, in order: (M, K), (K, N), (M, N), (M, N).
307 switch (matrixType) {
309 resultMatrix = combineVectors(M, K);
310 break;
312 resultMatrix = combineVectors(K, N);
313 break;
315 resultMatrix = combineVectors(M, N);
316 break;
318 resultMatrix = combineVectors(M, N);
319 break;
320 }
321 return resultMatrix;
322}
323
326 MMAOpndKind matrixType) {
// Returns the list of element types for the requested operand kind; the types
// are created in `context`.
// NOTE(review): source lines 324-325 (function header) and 333/335/337/339
// (case labels) are missing from this listing, so which operand kind maps to
// which list cannot be read from here.
327 Type bf16Type = BFloat16Type::get(&context);
328 Type f16Type = Float16Type::get(&context);
329 Type tf32Type = FloatTF32Type::get(&context);
330 Type f32Type = Float32Type::get(&context);
331
// Two cases return {bf16, f16, tf32}; the other two return {bf16, f16, f32}.
332 switch (matrixType) {
334 return {bf16Type, f16Type, tf32Type};
336 return {bf16Type, f16Type, tf32Type};
338 return {bf16Type, f16Type, f32Type};
340 return {bf16Type, f16Type, f32Type};
341 }
// Fallback when no case returned: an empty list.
342 return {};
343}
344
346 Type BType,
347 Type CType,
348 Type DType) {
// Checks whether the element types of A, B, C (accumulator, may be null) and
// D (result) form a supported combination:
//   - f16 A or B:  A == B; C (if present) and D must each be f32 or f16.
//   - bf16 A or B: A == B; C (if present) and D must each be f32 or bf16.
//   - tf32 A or B: A == B; C (if present) and D must each be f32.
//   - otherwise:   A and B must each be a 2-, 4- or 8-bit integer.
// Returns false, after emitting a debug log message, for anything else.
// NOTE(review): source line 345 (function header with the AType parameter) is
// missing from this listing. In the integer case C and D are not checked.
349 if (AType.isF16() || BType.isF16()) {
350 if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
351 (!DType.isF32() && !DType.isF16())) {
352 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
353 return false;
354 }
355 } else if (AType.isBF16() || BType.isBF16()) {
356 if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
357 (!DType.isF32() && !DType.isBF16())) {
358 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
359 return false;
360 }
361 } else if (AType.isTF32() || BType.isTF32()) {
// The accumulator test must look at C only. Mixing in `!DType.isF32()` let a
// non-f32 C pass whenever D was f32.
362 if (AType != BType || (CType && !CType.isF32()) ||
363 (!DType.isF32())) {
364 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
365 return false;
366 }
// Reject when either operand is not a supported small integer. With `&&` a
// pair such as (i8, f64) was accepted because only one side had to qualify.
367 } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
368 AType.isInteger(8)) ||
369 !(BType.isInteger(2) || BType.isInteger(4) ||
370 BType.isInteger(8))) {
371 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
372 return false;
373 }
374
375 return true;
376}
377
379 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
380 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
381 Type AType, Type BType, Type CType, Type DType) {
// Returns true only when each operand's shape appears in the list that
// getSupportedShapes returns for that operand's type and kind, and
// checkSupportedTypes accepts the four element types. The `&&` chain
// short-circuits, so the type check runs only after all four shapes matched.
// NOTE(review): source line 378 (start of the function header) is missing
// from this listing. getSupportedShapes is called for CType unconditionally,
// whereas checkSupportedTypes in this file guards CType with a null test;
// confirm a null CType cannot reach this function.
382 auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
383 auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
384 auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
385 auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
386 return llvm::is_contained(supportedAShapes, AShape) &&
387 llvm::is_contained(supportedBShapes, BShape) &&
388 llvm::is_contained(supportedCShapes, CShape) &&
389 llvm::is_contained(supportedDShapes, DShape) &&
390 checkSupportedTypes(AType, BType, CType, DType);
391}
392
394 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
395 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
396 Type AType, Type BType, Type CType, Type DType) {
// Forwards all arguments unchanged to checkSupportedShapesAndTypes and
// returns its result.
// NOTE(review): source line 393 (start of the function header) is missing
// from this listing.
397 return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
398 BType, CType, DType);
399}
400
// Returns the fixed list 1 through 8.
// NOTE(review): the function header (source lines 401-402) is missing from
// this listing; presumably this is SubgroupMatrixMultiplyAcc::getSupportedM
// -- confirm against the original header.
403 return {1, 2, 3, 4, 5, 6, 7, 8};
404}
405
// Returns a one-element list whose value is chosen from the bit width of
// `type`: 2 -> 64, 4 -> 64, 8 -> 32, 16 -> 16, 32 -> 8.
// NOTE(review): the function header (source lines 406-407) is missing from
// this listing; presumably this is SubgroupMatrixMultiplyAcc::getSupportedK
// -- confirm against the original header. Any other bit width reaches
// llvm_unreachable.
408 // assert if data type is not int or float type
409 assert(type.isIntOrFloat() && "Matrix type must be int or float");
410 auto bitWidth = type.getIntOrFloatBitWidth();
411 uint32_t kSize = 0;
412 switch (bitWidth) {
413 case 2:
414 kSize = 64;
415 break;
416 case 4:
417 kSize = 64;
418 break;
419 case 8:
420 kSize = 32;
421 break;
422 case 16:
423 kSize = 16;
424 break;
425 case 32:
426 kSize = 8;
427 break;
428 default:
429 llvm_unreachable("Invalid int or float");
430 }
431 return {kSize};
432}
433
// Returns the fixed one-element list {16}.
// NOTE(review): the function header (source lines 434-435) is missing from
// this listing; presumably this is SubgroupMatrixMultiplyAcc::getSupportedN
// -- confirm against the original header.
436 return {16};
437}
438
439#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isTF32() const
Definition Types.cpp:39
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:51
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
Definition IntelGpuXe2.h:91
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:83
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:52
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition IntelGpuXe2.h:59
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
unsigned getGeneralPackedFormatBitSize() const override
Definition IntelGpuXe2.h:39
int getSubgroupSize() const override
Definition IntelGpuXe2.h:38
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
Definition IntelGpuXe2.h:34
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:172
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:145