MLIR 23.0.0git
uArchBase.h
Go to the documentation of this file.
1//===- uArchBase.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Base uArch definition for different architectures.
11//
12//
13//===----------------------------------------------------------------------===//
14#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
15#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
16
17#include <any>
18#include <functional>
19#include <iostream>
20#include <map>
21#include <mutex>
22#include <shared_mutex>
23#include <tuple>
24
25#include "mlir/IR/Types.h"
26#include "llvm/ADT/SmallVector.h"
27
28namespace mlir {
29namespace xegpu {
30namespace uArch {
31
/// Bit size (32) of the general packed format.
///
/// Declared `inline constexpr` (C++17 inline variable) so that every
/// translation unit including this header refers to the same single object,
/// rather than each TU receiving its own internal-linkage copy.
/// NOTE(review): presumably the shared default behind the per-architecture
/// uArch::getGeneralPackedFormatBitSize(); confirm against implementations.
inline constexpr unsigned generalPackedFormatBitSize{32};
33
34// An enum class to represent the scope of an instruction
/// Identifies the kind of an instruction. uArch keeps its instruction
/// registry keyed by this enum (one entry per supported kind).
///
/// Enumerators keep their original declaration order, so their underlying
/// values are 0..6 in the order listed below.
enum class InstructionKind {
  /// Dot Product Accumulate Systolic (DPAS): a matrix multiply-add operation.
  SubgroupMatrixMultiplyAcc,
  /// Scaled Matrix Multiply Accumulate: a DPAS with a scaling factor applied
  /// to operand A or B before the multiplication.
  SubgroupScaledMatrixMultiplyAcc,
  /// Subgroup-level 2D block write (store) instruction.
  Subgroup2DBlockStore,
  /// Subgroup-level 2D block load instruction.
  Subgroup2DBlockLoad,
  /// Subgroup-level 2D block prefetch instruction.
  Subgroup2DBlockPrefetch,
  /// Lane-level store (scalar, vector).
  StoreScatter,
  /// Lane-level load (scalar, vector).
  LoadGather,
  // @TODO: Add more instructions as needed
};
49
50// A struct to represent basic information about an instruction.
51// The primary purpose of the Instruction struct is to provide a generic way to
52// represent information about an instruction and to use this information to
53// generate the uArch. Specific instructions in a uArch can inherit from this
54// struct and add more fields as needed.
58
59 ~Instruction() = default;
60 // Get methods
62 InstructionScope getScope() const { return scope; }
63 static llvm::StringRef toString(InstructionKind instKind) {
64 switch (instKind) {
66 return "dpas";
68 return "dpas_mx";
70 return "store_nd";
72 return "load_nd";
74 return "prefetch_nd";
76 return "store";
78 return "load";
79 }
80 llvm_unreachable("Unknown InstructionKind");
81 }
82
83 static std::optional<InstructionKind>
84 parseInstructionKind(llvm::StringRef str) {
85 if (str.equals_insensitive("dpas"))
87 return std::nullopt;
88 }
89
90protected:
91 const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
92 const InstructionScope scope; // scope of the instruction (e.g., lane,
93 // subgroup, workgroup, cluster)
94 // @TODO: Add more fields as needed
95};
96
/// Register file modes (e.g. the "small" and "large" GRF modes).
/// Stored in a single byte.
enum class RegisterFileMode : uint8_t {
  Small,
  Large,
};

/// Register file types. NOTE(review): presumably GRF = general register file
/// and ARF = architecture register file (Intel GPU terminology); confirm.
/// Stored in a single byte.
enum class RegisterFileType : uint8_t {
  GRF,
  ARF,
};
99
100// A struct to represent register file information
102 // Constructor
103 RegisterFileInfo() = default;
108
109 // Get methods
110 uint32_t getSize() const { return size; }
111
113 return mode;
114 }
115
119
120protected:
121 uint32_t size; // size per register in bits
123 mode; // e.g., "small", "large" GRF modes
125 numRegsPerThreadPerMode; // number of registers per thread per mode
126};
127
/// Cache hierarchy levels. Each enumerator's underlying value equals its
/// level number (L1 -> 1, L2 -> 2, L3 -> 3).
enum class CacheHierarchyLevel {
  L1 = 1,
  L2 = 2,
  L3 = 3,
};
129
130// A struct to represent cache information
131struct CacheInfo {
132 // Constructor
133 CacheInfo() = default;
137
138 virtual ~CacheInfo() = default;
139
140 // Get methods
141 uint32_t getSize() const { return size; }
142 uint32_t getLineSize() const { return line_size; }
144
145protected:
146 uint32_t size;
147 uint32_t line_size;
149 // @TODO: Add more fields as needed (e.g., associativity, num_banks,
150 // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
151 // latency, throughput, bandwidth)
152};
153
154struct uArch {
155 // Constructor
156 uArch(StringRef name, StringRef description,
159 for (const Instruction *instr : instructionRegistry)
160 this->instructionRegistry[instr->getInstructionKind()] = instr;
161 }
162 virtual ~uArch() = default;
163 StringRef getName() const { return name; }
164 StringRef getDescription() const { return description; }
165 virtual int getSubgroupSize() const = 0;
166 virtual unsigned getGeneralPackedFormatBitSize() const = 0;
167
169 auto it = instructionRegistry.find(instKind);
170 assert(it != instructionRegistry.end() &&
171 "Instruction not found in registry");
172 return it->second;
173 }
174
176 return instructionRegistry.contains(instr);
177 }
178
179protected:
180 StringRef name;
181 StringRef description;
182 llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
184};
185
186// A struct to represent shared memory information
188 // Constructor
189 SharedMemory(uint32_t size, uint32_t alignment)
191
192 // Get methods
193 uint32_t getSize() const { return size; }
194 uint32_t getAlignment() const { return alignment; }
195
196protected:
197 uint32_t size; // in bytes
198 uint32_t alignment; // in bytes
199 // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
200};
201
214
215//===----------------------------------------------------------------------===//
216// Interfaces
217//===----------------------------------------------------------------------===//
220 // Get supported Matrix shapes
222 getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
223 // @TODO: This method takes a context object as a parameter; this is to
224 // create the Type objects from the same context. Since type objects are
225 // uniqued in a specific context, for checks like "aType == bType" (where
226 // aType and bType are the same type) to work, both types must
227 // come from the same context.
228 //
229 // One alternative is to create an enum to represent each type, but this
230 // adds an extra burden on the user to convert these enums to specific types. In
231 // fact the utility that would convert enumToType() and vice versa would still
232 // have to use the context object.
233 //
234 // Until we have a better solution, we stick to passing the context object to
235 // this method.
237 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
238 virtual bool
239 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
240 std::pair<uint32_t, uint32_t> BShape,
241 std::pair<uint32_t, uint32_t> CShape,
242 std::pair<uint32_t, uint32_t> DShape, Type AType,
243 Type BType, Type CType, Type DType) = 0;
244 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
245 Type DType) = 0;
246 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
247 std::pair<uint32_t, uint32_t> BShape,
248 std::pair<uint32_t, uint32_t> CShape,
249 std::pair<uint32_t, uint32_t> DShape, Type AType,
250 Type BType, Type CType, Type DType) = 0;
254 virtual bool isLaneLayoutRowMajorOrder() const = 0;
255 virtual ~MMAInstructionInterface() = default;
256};
257
258//===----------------------------------------------------------------------===//
259// Common instructions (shared across architectures)
260//===----------------------------------------------------------------------===//
261
265 static bool classof(const Instruction *B) {
266 return B->getInstructionKind() == InstructionKind::LoadGather;
267 }
268
269 virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
271};
272
276 static bool classof(const Instruction *B) {
277 return B->getInstructionKind() == InstructionKind::StoreScatter;
278 }
279
280 virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
282};
283
284} // namespace uArch
285} // namespace xegpu
286} // namespace mlir
287
288#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
constexpr unsigned generalPackedFormatBitSize
Definition uArchBase.h:32
Include the generated interface declarations.
uint32_t getLineSize() const
Definition uArchBase.h:142
CacheHierarchyLevel hierarchy_level
Definition uArchBase.h:148
virtual ~CacheInfo()=default
CacheInfo(uint32_t size, uint32_t line_size, CacheHierarchyLevel hierarchy_level)
Definition uArchBase.h:134
CacheHierarchyLevel getHierarchyLevel() const
Definition uArchBase.h:143
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:56
static std::optional< InstructionKind > parseInstructionKind(llvm::StringRef str)
Definition uArchBase.h:84
const InstructionScope scope
Definition uArchBase.h:92
InstructionScope getScope() const
Definition uArchBase.h:62
static llvm::StringRef toString(InstructionKind instKind)
Definition uArchBase.h:63
InstructionKind getInstructionKind() const
Definition uArchBase.h:61
const InstructionKind instKind
Definition uArchBase.h:91
virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:265
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual bool isLaneLayoutRowMajorOrder() const =0
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType)=0
llvm::SmallVector< RegisterFileMode, 4 > mode
Definition uArchBase.h:123
RegisterFileInfo(uint32_t size, const llvm::SmallVector< RegisterFileMode, 4 > &mode, const llvm::SmallVector< uint32_t, 4 > &numRegs)
Definition uArchBase.h:104
llvm::SmallVector< uint32_t, 4 > numRegsPerThreadPerMode
Definition uArchBase.h:125
const llvm::SmallVector< RegisterFileMode, 4 > & getModes() const
Definition uArchBase.h:112
const llvm::SmallVector< uint32_t, 4 > & getNumRegsPerThreadPerMode() const
Definition uArchBase.h:116
SharedMemory(uint32_t size, uint32_t alignment)
Definition uArchBase.h:189
virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:276
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, uint32_t num_vector_units, uint32_t num_matrix_units)
Definition uArchBase.h:208
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:183
StringRef getDescription() const
Definition uArchBase.h:164
virtual unsigned getGeneralPackedFormatBitSize() const =0
bool isSupportedInstruction(InstructionKind instr) const
Definition uArchBase.h:175
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:156
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:168
virtual ~uArch()=default
StringRef getName() const
Definition uArchBase.h:163