MLIR 23.0.0git
uArchBase.h
Go to the documentation of this file.
1//===- uArch.h --------------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Base uArch definition for different architectures.
11//
12//
13//===----------------------------------------------------------------------===//
14#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
15#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
16
17#include <any>
18#include <functional>
19#include <iostream>
20#include <map>
21#include <mutex>
22#include <shared_mutex>
23#include <tuple>
24
25#include "mlir/IR/Types.h"
26#include "llvm/ADT/SmallVector.h"
27
28namespace mlir {
29namespace xegpu {
30namespace uArch {
31
// Namespace-level constant: bit size of the general packed format (32).
// NOTE(review): uArch also declares a pure-virtual
// getGeneralPackedFormatBitSize(); confirm how implementations relate the two.
32constexpr unsigned generalPackedFormatBitSize{32};
33
34// An enum class to represent the scope of an instruction
// An enum class to identify a specific kind of instruction. Every Instruction
// stores one of these values (instKind), the uArch instruction registry is
// keyed by it, and the classof() helpers of the common instruction structs
// compare against it.
36enum class InstructionKind {
37 SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS) is a
38 // matrix multiply-add operation
39 Subgroup2DBlockStore, // Subgroup-level 2D block write instruction
40 Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
41 Subgroup2DBlockPrefetch, // Subgroup-level 2D block prefetch instruction
42 StoreScatter, // Lane-level store (scalar, vector)
43 LoadGather, // Lane-level load (scalar, vector)
44 StoreMatrix, // Lane-level matrix store to slm
45 LoadMatrix // Lane-level matrix load from slm
46 // @TODO: Add more instructions as needed
47};
48
49// A struct to represent basic information about an instruction.
50// The primary purpose of the Instruction struct is to provide a generic way to
51// represent information about an instruction and to use this information to
52// generate the uArch. Specific instructions in a uArch can inherit from this
53// struct and add more fields as needed.
57
58 ~Instruction() = default;
59 // Get methods
61 InstructionScope getScope() const { return scope; }
62 static llvm::StringRef toString(InstructionKind instKind) {
63 switch (instKind) {
65 return "dpas";
67 return "store_nd";
69 return "load_nd";
71 return "prefetch_nd";
73 return "store";
75 return "load";
77 return "store_matrix";
79 return "load_matrix";
80 }
81 llvm_unreachable("Unknown InstructionKind");
82 }
83
84 static std::optional<InstructionKind>
85 parseInstructionKind(llvm::StringRef str) {
86 if (str.equals_insensitive("dpas"))
88 return std::nullopt;
89 }
90
91protected:
92 const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
93 const InstructionScope scope; // scope of the instruction (e.g., lane,
94 // subgroup, workgroup, cluster)
95 // @TODO: Add more fields as needed
96};
97
// Register file modes (Small, Large); stored as uint8_t. The RegisterFileInfo
// `mode` field comment refers to these as the "small" / "large" GRF modes.
98enum class RegisterFileMode : uint8_t { Small, Large };
// Register file kinds (GRF, ARF); stored as uint8_t.
99enum class RegisterFileType : uint8_t { GRF, ARF };
100
101// A struct to represent register file information
103 // Constructor
104 RegisterFileInfo() = default;
109
110 // Get methods
111 uint32_t getSize() const { return size; }
112
114 return mode;
115 }
116
120
121protected:
122 uint32_t size; // size per register in bits
124 mode; // e.g., "small", "large" GRF modes
126 numRegsPerThreadPerMode; // number of registers per thread per mode
127};
128
// Cache levels. Each enumerator's value equals its level number
// (L1 = 1, L2 = 2, L3 = 3).
129enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
130
131// A struct to represent cache information
132struct CacheInfo {
133 // Constructor
134 CacheInfo() = default;
138
139 virtual ~CacheInfo() = default;
140
141 // Get methods
142 uint32_t getSize() const { return size; }
143 uint32_t getLineSize() const { return line_size; }
145
146protected:
147 uint32_t size;
148 uint32_t line_size;
150 // @TODO: Add more fields as needed (e.g., associativity, num_banks,
151 // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
152 // latency, throughput, bandwidth)
153};
154
155struct uArch {
156 // Constructor
157 uArch(StringRef name, StringRef description,
160 for (const Instruction *instr : instructionRegistry)
161 this->instructionRegistry[instr->getInstructionKind()] = instr;
162 }
163 virtual ~uArch() = default;
164 StringRef getName() const { return name; }
165 StringRef getDescription() const { return description; }
166 virtual int getSubgroupSize() const = 0;
167 virtual unsigned getGeneralPackedFormatBitSize() const = 0;
168
170 auto it = instructionRegistry.find(instKind);
171 assert(it != instructionRegistry.end() &&
172 "Instruction not found in registry");
173 return it->second;
174 }
175
177 return instructionRegistry.contains(instr);
178 }
179
180protected:
181 StringRef name;
182 StringRef description;
183 llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
185};
186
187// A struct to represent shared memory information
189 // Constructor
190 SharedMemory(uint32_t size, uint32_t alignment)
192
193 // Get methods
194 uint32_t getSize() const { return size; }
195 uint32_t getAlignment() const { return alignment; }
196
197protected:
198 uint32_t size; // in bytes
199 uint32_t alignment; // in bytes
200 // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
201};
202
215
216//===----------------------------------------------------------------------===//
217// Interfaces
218//===----------------------------------------------------------------------===//
221 // Get supported Matrix shapes
223 getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
224 // @TODO: This method takes a context object as a parameter; this is to
225 // create the Type objects from the same context. Since type objects are
226 // uniqued in a specific context, to do things like "aType == bType" (where
227 // aType and bType are both the same type) kind of checks, both types should
228 // be from the same context.
229 //
230 // One alternative to this is to create an enum to represent each type, but
231 // this adds an extra burden on the user to convert these enums to specific
232 // types. In fact the utility that would convert enumToType() and vice versa
233 // would still have to use the context object.
234 //
235 // Until we have a better solution, we stick to passing the context object to
236 // this method.
238 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
239 virtual bool
240 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
241 std::pair<uint32_t, uint32_t> BShape,
242 std::pair<uint32_t, uint32_t> CShape,
243 std::pair<uint32_t, uint32_t> DShape, Type AType,
244 Type BType, Type CType, Type DType) = 0;
245 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
246 Type DType) = 0;
247 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
248 std::pair<uint32_t, uint32_t> BShape,
249 std::pair<uint32_t, uint32_t> CShape,
250 std::pair<uint32_t, uint32_t> DShape, Type AType,
251 Type BType, Type CType, Type DType) = 0;
255
256 virtual ~MMAInstructionInterface() = default;
257};
258
259//===----------------------------------------------------------------------===//
260// Common instructions (shared across architectures)
261//===----------------------------------------------------------------------===//
262
266 static bool classof(const Instruction *B) {
267 return B->getInstructionKind() == InstructionKind::LoadGather;
268 }
269
270 virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
272};
273
277 static bool classof(const Instruction *B) {
278 return B->getInstructionKind() == InstructionKind::StoreScatter;
279 }
280
281 virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
283};
284
288 static bool classof(const Instruction *B) {
289 return B->getInstructionKind() == InstructionKind::LoadMatrix;
290 }
291
292 virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
294};
295
299 static bool classof(const Instruction *B) {
300 return B->getInstructionKind() == InstructionKind::StoreMatrix;
301 }
302
303 virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
305};
306
307} // namespace uArch
308} // namespace xegpu
309} // namespace mlir
310
311#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
constexpr unsigned generalPackedFormatBitSize
Definition uArchBase.h:32
Include the generated interface declarations.
uint32_t getLineSize() const
Definition uArchBase.h:143
CacheHierarchyLevel hierarchy_level
Definition uArchBase.h:149
virtual ~CacheInfo()=default
CacheInfo(uint32_t size, uint32_t line_size, CacheHierarchyLevel hierarchy_level)
Definition uArchBase.h:135
CacheHierarchyLevel getHierarchyLevel() const
Definition uArchBase.h:144
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:55
static std::optional< InstructionKind > parseInstructionKind(llvm::StringRef str)
Definition uArchBase.h:85
const InstructionScope scope
Definition uArchBase.h:93
InstructionScope getScope() const
Definition uArchBase.h:61
static llvm::StringRef toString(InstructionKind instKind)
Definition uArchBase.h:62
InstructionKind getInstructionKind() const
Definition uArchBase.h:60
const InstructionKind instKind
Definition uArchBase.h:92
virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:266
virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:288
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType)=0
llvm::SmallVector< RegisterFileMode, 4 > mode
Definition uArchBase.h:124
RegisterFileInfo(uint32_t size, const llvm::SmallVector< RegisterFileMode, 4 > &mode, const llvm::SmallVector< uint32_t, 4 > &numRegs)
Definition uArchBase.h:105
llvm::SmallVector< uint32_t, 4 > numRegsPerThreadPerMode
Definition uArchBase.h:126
const llvm::SmallVector< RegisterFileMode, 4 > & getModes() const
Definition uArchBase.h:113
const llvm::SmallVector< uint32_t, 4 > & getNumRegsPerThreadPerMode() const
Definition uArchBase.h:117
SharedMemory(uint32_t size, uint32_t alignment)
Definition uArchBase.h:190
static bool classof(const Instruction *B)
Definition uArchBase.h:299
virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const =0
virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:277
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, uint32_t num_vector_units, uint32_t num_matrix_units)
Definition uArchBase.h:209
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:184
StringRef getDescription() const
Definition uArchBase.h:165
virtual unsigned getGeneralPackedFormatBitSize() const =0
bool isSupportedInstruction(InstructionKind instr) const
Definition uArchBase.h:176
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:157
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:169
virtual ~uArch()=default
StringRef getName() const
Definition uArchBase.h:164