15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
34 Xe2Plus(StringRef archName, StringRef archDescription,
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
65 if (elemByteSize == 1)
69 else if (elemByteSize == 2 || elemByteSize == 4)
92 bool upConv =
false)
const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
112 {{1,
false,
false,
false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1,
false,
false,
true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount1}},
117 {{1,
true,
false,
false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2,
true,
false,
false}, {kWidth16, kHeightAtLeast16, kCount2}},
120 {{4,
false,
true,
false}, {kWidth8, kHeightAtLeast16, kCount1}},
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
179 return B->getInstructionKind() ==
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape,
Type AType,
197 Type DType)
override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape,
Type AType,
252 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
253 &storeNdInst, &prefetchNdInst,
254 &storeScatterInst, &loadGatherInst};
260 "Ponte Vecchio Architecture",
266 return reinterpret_cast<const uArch *
>(&instance);
278 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
279 &storeNdInst, &prefetchNdInst,
280 &storeScatterInst, &loadGatherInst};
286 "Battlemage Architecture",
292 return reinterpret_cast<const uArch *
>(&instance);
297 if (archName.equals_insensitive(
"pvc"))
299 else if (archName.equals_insensitive(
"bmg"))
302 llvm_unreachable(
"No matching uArch found");
322 for (
unsigned x : a) {
323 for (
unsigned y :
b) {
324 result.emplace_back(x, y);
335 switch (matrixType) {
337 resultMatrix = combineVectors(M, K);
340 resultMatrix = combineVectors(K, N);
343 resultMatrix = combineVectors(M, N);
346 resultMatrix = combineVectors(M, N);
355 Type bf16Type = BFloat16Type::get(&context);
356 Type f16Type = Float16Type::get(&context);
357 Type tf32Type = FloatTF32Type::get(&context);
358 Type f32Type = Float32Type::get(&context);
360 switch (matrixType) {
362 return {bf16Type, f16Type, tf32Type};
364 return {bf16Type, f16Type, tf32Type};
366 return {bf16Type, f16Type, f32Type};
368 return {bf16Type, f16Type, f32Type};
378 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isF16())) ||
380 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
384 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isBF16())) ||
386 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
390 if (AType != BType || (CType && (!CType.
isF32() && !DType.
isF32())) ||
392 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
399 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
407 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
408 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
414 return llvm::is_contained(supportedAShapes, AShape) &&
415 llvm::is_contained(supportedBShapes, BShape) &&
416 llvm::is_contained(supportedCShapes, CShape) &&
417 llvm::is_contained(supportedDShapes, DShape) &&
422 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
423 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
426 BType, CType, DType);
431 return {1, 2, 3, 4, 5, 6, 7, 8};
437 assert(type.
isIntOrFloat() &&
"Matrix type must be int or float");
457 llvm_unreachable(
"Invalid int or float");
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
@ Subgroup2DBlockPrefetch
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
Instruction(InstructionKind kind, InstructionScope scope)
static bool classof(const Instruction *B)
int32_t getMaxLaneLoadStoreSize() const
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
int32_t getMaxLaneLoadStoreSize() const
StoreScatterInstruction()
static bool classof(const Instruction *B)
Subgroup2DBlockLoadInstruction()
int32_t getPackedFormatBitSize() const
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
static bool classof(const Instruction *B)
static bool classof(const Instruction *B)
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
Subgroup2DBlockPrefetchInstruction()
static bool classof(const Instruction *B)
Subgroup2DBlockStoreInstruction()
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
const unsigned packedFormatBitSizeA
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
unsigned getPackedFormatBitSizeB() const
unsigned getPackedFormatBitSizeA() const
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
const unsigned packedFormatBitSizeB
unsigned getGeneralPackedFormatBitSize() const override
int getSubgroupSize() const override
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)