15 #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16 #define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/DebugLog.h"
34 Xe2Plus(StringRef archName, StringRef archDescription,
37 :
uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
65 if (elemByteSize == 1)
69 else if (elemByteSize == 2 || elemByteSize == 4)
92 bool upConv =
false)
const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
112 {{1,
false,
false,
false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1,
false,
false,
true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4,
false,
false,
false}, {kWidth16, kHeightAtLeast1, kCount1}},
117 {{1,
true,
false,
false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2,
true,
false,
false}, {kWidth16, kHeightAtLeast16, kCount2}},
120 {{4,
false,
true,
false}, {kWidth8, kHeightAtLeast16, kCount1}},
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
173 unsigned packedFormatBitSizeB)
176 packedFormatBitSizeA(packedFormatBitSizeA),
177 packedFormatBitSizeB(packedFormatBitSizeB) {}
179 return B->getInstructionKind() ==
191 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape,
Type AType,
196 virtual bool checkSupportedTypes(
Type AType,
Type BType,
Type CType,
197 Type DType)
override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape,
Type AType,
204 getSupportedM(
Type type)
const override;
206 getSupportedK(
Type type)
const override;
208 getSupportedN(
Type type)
const override;
228 static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
235 "Ponte Vecchio Architecture",
236 getInstructionRegistryArr(),
241 return reinterpret_cast<const uArch *
>(&instance);
251 static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
258 "Battlemage Architecture",
259 getInstructionRegistryArr(),
264 return reinterpret_cast<const uArch *
>(&instance);
269 if (archName.equals_insensitive(
"pvc"))
271 else if (archName.equals_insensitive(
"bmg"))
274 llvm_unreachable(
"No matching uArch found");
294 for (
unsigned x : a) {
295 for (
unsigned y : b) {
296 result.emplace_back(x, y);
307 switch (matrixType) {
309 resultMatrix = combineVectors(M, K);
312 resultMatrix = combineVectors(K, N);
315 resultMatrix = combineVectors(M, N);
318 resultMatrix = combineVectors(M, N);
332 switch (matrixType) {
334 return {bf16Type, f16Type, tf32Type};
336 return {bf16Type, f16Type, tf32Type};
338 return {bf16Type, f16Type, f32Type};
340 return {bf16Type, f16Type, f32Type};
350 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isF16())) ||
352 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
356 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isBF16())) ||
358 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
362 if (AType != BType || (CType && (!CType.
isF32() && !DType.
isF32())) ||
364 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
371 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
379 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
380 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
386 return llvm::is_contained(supportedAShapes, AShape) &&
387 llvm::is_contained(supportedBShapes, BShape) &&
388 llvm::is_contained(supportedCShapes, CShape) &&
389 llvm::is_contained(supportedDShapes, DShape) &&
394 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
395 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
398 BType, CType, DType);
403 return {1, 2, 3, 4, 5, 6, 7, 8};
409 assert(type.
isIntOrFloat() &&
"Matrix type must be int or float");
429 llvm_unreachable(
"Invalid int or float");
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
@ Subgroup2DBlockPrefetch
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static const uArch * getInstance()
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
Subgroup2DBlockLoadInstruction()
int32_t getPackedFormatBitSize() const
static bool classof(const Instruction *B)
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
static bool classof(const Instruction *B)
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
int32_t getPackedFormatBitSize() const
Subgroup2DBlockPrefetchInstruction()
static bool classof(const Instruction *B)
Subgroup2DBlockStoreInstruction()
int32_t getPackedFormatBitSize() const
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
const unsigned packedFormatBitSizeA
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
unsigned getPackedFormatBitSizeB() const
unsigned getPackedFormatBitSizeA() const
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
const unsigned packedFormatBitSizeB
unsigned getGeneralPackedFormatBitSize() const override
int getSubgroupSize() const override
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)