15 #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16 #define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/DebugLog.h"
26 #define DEBUG_TYPE "xegpu-uarch"
37 Xe2Plus(
const std::string &archName,
const std::string &archDescription,
39 const std::map<RegisterFileType, RegisterFileInfo> ®Info = {},
41 const std::map<InstructionKind, std::shared_ptr<Instruction>>
43 :
uArch(archName, archDescription, regInfo, cacheInfo, instrs),
58 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
59 std::pair<uint32_t, uint32_t> BShape,
60 std::pair<uint32_t, uint32_t> CShape,
61 std::pair<uint32_t, uint32_t> DShape,
Type AType,
63 virtual bool checkSupportedTypes(
Type AType,
Type BType,
Type CType,
65 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
66 std::pair<uint32_t, uint32_t> BShape,
67 std::pair<uint32_t, uint32_t> CShape,
68 std::pair<uint32_t, uint32_t> DShape,
Type AType,
80 "Ponte Vecchio Architecture",
88 this->registerFileInfo.emplace(
97 this->cacheInfo.push_back(
100 this->cacheInfo.push_back(
104 auto dpas = std::make_shared<DPASInstruction>();
105 instructions.emplace(dpas->getInstructionKind(), dpas);
106 owned_instructions.push_back(dpas);
115 "Battlemage Architecture",
130 this->cacheInfo.push_back(
133 this->cacheInfo.push_back(
137 auto dpas = std::make_shared<DPASInstruction>();
138 instructions.emplace(dpas->getInstructionKind(), dpas);
139 owned_instructions.push_back(dpas);
152 for (
unsigned x : a) {
153 for (
unsigned y : b) {
154 result.emplace_back(x, y);
165 switch (matrixType) {
167 resultMatrix = combineVectors(M, K);
170 resultMatrix = combineVectors(K, N);
173 resultMatrix = combineVectors(M, N);
176 resultMatrix = combineVectors(M, N);
190 switch (matrixType) {
192 return {bf16Type, f16Type, tf32Type};
194 return {bf16Type, f16Type, tf32Type};
196 return {bf16Type, f16Type, f32Type};
198 return {bf16Type, f16Type, f32Type};
206 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isF16())) ||
208 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
212 if (AType != BType || (CType && (!CType.
isF32() && !CType.
isBF16())) ||
214 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
218 if (AType != BType || (CType && (!CType.
isF32() && !DType.
isF32())) ||
220 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
227 LDBG() <<
"Unsupported dpas combinations of Dst, Acc, A and B matrices.";
235 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
236 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
242 return llvm::is_contained(supportedAShapes, AShape) &&
243 llvm::is_contained(supportedBShapes, BShape) &&
244 llvm::is_contained(supportedCShapes, CShape) &&
245 llvm::is_contained(supportedDShapes, DShape) &&
250 std::pair<uint32_t, uint32_t> BShape,
251 std::pair<uint32_t, uint32_t> CShape,
252 std::pair<uint32_t, uint32_t> DShape,
256 BType, CType, DType);
261 return {1, 2, 3, 4, 5, 6, 7, 8};
267 assert(type.
isIntOrFloat() &&
"Matrix type must be int or float");
287 llvm_unreachable(
"Invalid int or float");
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallVector< std::shared_ptr< Instruction >, 8 > owned_instructions
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) override
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) override
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
llvm::SmallVector< std::shared_ptr< Instruction >, 8 > owned_instructions
Xe2Plus(const std::string &archName, const std::string &archDescription, const XeCoreInfo &xeCore, const std::map< RegisterFileType, RegisterFileInfo > ®Info={}, const llvm::SmallVector< CacheInfo, 4 > &cacheInfo={}, const std::map< InstructionKind, std::shared_ptr< Instruction >> &instrs={})