15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
34 Xe2Plus(StringRef archName, StringRef archDescription,
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
65 if (elemByteSize == 1)
69 else if (elemByteSize == 2 || elemByteSize == 4)
92 bool upConv =
false)
const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
112 {{1,
false,
false,
false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1,
false,
false,
true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount1}},
117 {{1,
true,
false,
false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2,
true,
false,
false}, {kWidth16, kHeightAtLeast16, kCount2}},
120 {{4,
false,
true,
false}, {kWidth8, kHeightAtLeast16, kCount1}},
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
179 return B->getInstructionKind() ==
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape,
Type AType,
197 Type DType)
override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape,
Type AType,
249 &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
250 &storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
256 "Ponte Vecchio Architecture",
262 return reinterpret_cast<const uArch *
>(&instance);
277 &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
278 &storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
284 "Battlemage Architecture",
290 return reinterpret_cast<const uArch *
>(&instance);
295 if (archName.equals_insensitive(
"pvc"))
297 else if (archName.equals_insensitive(
"bmg"))
300 llvm_unreachable(
"No matching uArch found");
320 for (
unsigned x : a) {
321 for (
unsigned y :
b) {
322 result.emplace_back(x, y);
333 switch (matrixType) {
335 resultMatrix = combineVectors(M, K);
338 resultMatrix = combineVectors(K, N);
341 resultMatrix = combineVectors(M, N);
344 resultMatrix = combineVectors(M, N);
353 Type bf16Type = BFloat16Type::get(&context);
354 Type f16Type = Float16Type::get(&context);
355 Type tf32Type = FloatTF32Type::get(&context);
356 Type f32Type = Float32Type::get(&context);
358 switch (matrixType) {
360 return {bf16Type, f16Type, tf32Type};
362 return {bf16Type, f16Type, tf32Type};
364 return {bf16Type, f16Type, f32Type};
366 return {bf16Type, f16Type, f32Type};
376 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isF16())) ||
378 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
382 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isBF16())) ||
384 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
388 if (AType != BType || (CType && (!CType.
isF32() && !DType.
isF32())) ||
390 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
397 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
405 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
406 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
412 return llvm::is_contained(supportedAShapes, AShape) &&
413 llvm::is_contained(supportedBShapes, BShape) &&
414 llvm::is_contained(supportedCShapes, CShape) &&
415 llvm::is_contained(supportedDShapes, DShape) &&
420 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
421 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
424 BType, CType, DType);
429 return {1, 2, 3, 4, 5, 6, 7, 8};
435 assert(type.
isIntOrFloat() &&
"Matrix type must be int or float");
455 llvm_unreachable(
"Invalid int or float");
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
@ Subgroup2DBlockPrefetch
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
Instruction(InstructionKind kind, InstructionScope scope)
LoadGatherInstructionInterface()
LoadMatrixInstructionInterface()
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
StoreMatrixInstructionInterface()
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
StoreScatterInstructionInterface()
Subgroup2DBlockLoadInstruction()
int32_t getPackedFormatBitSize() const
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
static bool classof(const Instruction *B)
static bool classof(const Instruction *B)
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
Subgroup2DBlockPrefetchInstruction()
static bool classof(const Instruction *B)
Subgroup2DBlockStoreInstruction()
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
const unsigned packedFormatBitSizeA
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
unsigned getPackedFormatBitSizeB() const
unsigned getPackedFormatBitSizeA() const
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
const unsigned packedFormatBitSizeB
unsigned getGeneralPackedFormatBitSize() const override
int getSubgroupSize() const override
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)