15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
34 Xe2Plus(StringRef archName, StringRef archDescription,
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
65 if (elemByteSize == 1)
69 else if (elemByteSize == 2 || elemByteSize == 4)
92 bool upConv =
false)
const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
112 {{1,
false,
false,
false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1,
false,
false,
true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount1}},
117 {{1,
true,
false,
false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2,
true,
false,
false}, {kWidth16, kHeightAtLeast16, kCount2}},
120 {{4,
false,
true,
false}, {kWidth8, kHeightAtLeast16, kCount1}},
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
179 return B->getInstructionKind() ==
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape,
Type AType,
197 Type DType)
override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape,
Type AType,
238 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
239 &storeNdInst, &prefetchNdInst,
240 &storeScatterInst, &loadGatherInst};
246 "Ponte Vecchio Architecture",
252 return reinterpret_cast<const uArch *
>(&instance);
264 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
265 &storeNdInst, &prefetchNdInst,
266 &storeScatterInst, &loadGatherInst};
272 "Battlemage Architecture",
278 return reinterpret_cast<const uArch *
>(&instance);
283 if (archName.equals_insensitive(
"pvc"))
285 if (archName.equals_insensitive(
"bmg"))
305 for (
unsigned x : a) {
306 for (
unsigned y :
b) {
307 result.emplace_back(x, y);
318 switch (matrixType) {
320 resultMatrix = combineVectors(M, K);
323 resultMatrix = combineVectors(K, N);
326 resultMatrix = combineVectors(M, N);
329 resultMatrix = combineVectors(M, N);
338 Type bf16Type = BFloat16Type::get(&context);
339 Type f16Type = Float16Type::get(&context);
340 Type tf32Type = FloatTF32Type::get(&context);
341 Type f32Type = Float32Type::get(&context);
343 switch (matrixType) {
345 return {bf16Type, f16Type, tf32Type};
347 return {bf16Type, f16Type, tf32Type};
349 return {bf16Type, f16Type, f32Type};
351 return {bf16Type, f16Type, f32Type};
361 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isF16())) ||
363 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
367 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isBF16())) ||
369 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
373 if (AType != BType || (CType && (!CType.
isF32() && !DType.
isF32())) ||
375 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
382 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
390 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
391 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
397 return llvm::is_contained(supportedAShapes, AShape) &&
398 llvm::is_contained(supportedBShapes, BShape) &&
399 llvm::is_contained(supportedCShapes, CShape) &&
400 llvm::is_contained(supportedDShapes, DShape) &&
405 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
406 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
409 BType, CType, DType);
414 return {1, 2, 3, 4, 5, 6, 7, 8};
420 assert(type.
isIntOrFloat() &&
"Matrix type must be int or float");
440 llvm_unreachable(
"Invalid int or float");
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
@ Subgroup2DBlockPrefetch
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
Instruction(InstructionKind kind, InstructionScope scope)
LoadGatherInstructionInterface()
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
StoreScatterInstructionInterface()
Subgroup2DBlockLoadInstruction()
int32_t getPackedFormatBitSize() const
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
static bool classof(const Instruction *B)
static bool classof(const Instruction *B)
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
Subgroup2DBlockPrefetchInstruction()
static bool classof(const Instruction *B)
Subgroup2DBlockStoreInstruction()
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
const unsigned packedFormatBitSizeA
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
unsigned getPackedFormatBitSizeB() const
unsigned getPackedFormatBitSizeA() const
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
const unsigned packedFormatBitSizeB
unsigned getGeneralPackedFormatBitSize() const override
int getSubgroupSize() const override
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)