MLIR 22.0.0git
uArchBase.h
Go to the documentation of this file.
1//===- uArchBase.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Base uArch definition for different architectures.
11//
12//
13//===----------------------------------------------------------------------===//
14#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
15#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
16
17#include <any>
18#include <functional>
19#include <iostream>
20#include <map>
21#include <mutex>
22#include <shared_mutex>
23#include <tuple>
24
25#include "mlir/IR/Types.h"
26#include "llvm/ADT/SmallVector.h"
27
28namespace mlir {
29namespace xegpu {
30namespace uArch {
31
// Bit size (32) of the general packed format.
// NOTE(review): presumably the value uArch implementations report from
// uArch::getGeneralPackedFormatBitSize() -- confirm against the targets.
// `inline` (C++17) gives this header-defined constant a single definition
// shared by all translation units instead of one internal-linkage copy per TU.
inline constexpr unsigned generalPackedFormatBitSize{32};
33
34// An enum class to represent the scope of an instruction
/// Kinds of instructions that a uArch can register and describe.
enum class InstructionKind {
  /// Dot Product Accumulate Systolic (DPAS): a matrix multiply-add operation.
  SubgroupMatrixMultiplyAcc,
  /// Subgroup-level 2D block write instruction.
  Subgroup2DBlockStore,
  /// Subgroup-level 2D block load instruction.
  Subgroup2DBlockLoad,
  /// Subgroup-level 2D block prefetch instruction.
  Subgroup2DBlockPrefetch,
  // @TODO: Add more instructions as needed
};
44
45// A struct to represent basic information about an instruction.
46// The primary purpose of the Instruction struct is to provide a generic way to
47// represent information about an instruction and to use this information to
48// generate the uArch. Specific instructions in a uArch can inherit from this
49// struct and add more fields as needed.
53
54 ~Instruction() = default;
55 // Get methods
57 InstructionScope getScope() const { return scope; }
58 static llvm::StringRef toString(InstructionKind instKind) {
59 switch (instKind) {
61 return "dpas";
63 return "store_nd";
65 return "load_nd";
67 return "prefetch_nd";
68 }
69 llvm_unreachable("Unknown InstructionKind");
70 }
71
72 static std::optional<InstructionKind>
73 parseInstructionKind(llvm::StringRef str) {
74 if (str.equals_insensitive("dpas"))
76 return std::nullopt;
77 }
78
79protected:
80 const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
81 const InstructionScope scope; // scope of the instruction (e.g., lane,
82 // subgroup, workgroup, cluster)
83 // @TODO: Add more fields as needed
84};
85
/// Register-file operating modes (as stored in RegisterFileInfo).
enum class RegisterFileMode : uint8_t {
  Small,
  Large,
};
/// Register-file kinds.
/// NOTE(review): presumably GRF = general register file and
/// ARF = architecture register file -- confirm.
enum class RegisterFileType : uint8_t {
  GRF,
  ARF,
};
88
89// A struct to represent register file information
91 // Constructor
92 RegisterFileInfo() = default;
97
98 // Get methods
99 uint32_t getSize() const { return size; }
100
102 return mode;
103 }
104
108
109protected:
110 uint32_t size; // size per register in bits
112 mode; // e.g., "small", "large" GRF modes
114 numRegsPerThreadPerMode; // number of registers per thread per mode
115};
116
/// Cache hierarchy levels; each enumerator's value equals its level number.
enum class CacheHierarchyLevel {
  L1 = 1,
  L2 = 2,
  L3 = 3,
};
118
119// A struct to represent cache information
120struct CacheInfo {
121 // Constructor
122 CacheInfo() = default;
126
127 virtual ~CacheInfo() = default;
128
129 // Get methods
130 uint32_t getSize() const { return size; }
131 uint32_t getLineSize() const { return line_size; }
133
134protected:
135 uint32_t size;
136 uint32_t line_size;
138 // @TODO: Add more fields as needed (e.g., associativity, num_banks,
139 // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
140 // latency, throughput, bandwidth)
141};
142
143struct uArch {
144 // Constructor
145 uArch(StringRef name, StringRef description,
148 for (const Instruction *instr : instructionRegistry)
149 this->instructionRegistry[instr->getInstructionKind()] = instr;
150 }
151 virtual ~uArch() = default;
152 StringRef getName() const { return name; }
153 StringRef getDescription() const { return description; }
154 virtual int getSubgroupSize() const = 0;
155 virtual unsigned getGeneralPackedFormatBitSize() const = 0;
156
158 auto it = instructionRegistry.find(instKind);
159 assert(it != instructionRegistry.end() &&
160 "Instruction not found in registry");
161 return it->second;
162 }
163
165 return instructionRegistry.contains(instr);
166 }
167
168protected:
169 StringRef name;
170 StringRef description;
171 llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
173};
174
175// A struct to represent shared memory information
177 // Constructor
178 SharedMemory(uint32_t size, uint32_t alignment)
180
181 // Get methods
182 uint32_t getSize() const { return size; }
183 uint32_t getAlignment() const { return alignment; }
184
185protected:
186 uint32_t size; // in bytes
187 uint32_t alignment; // in bytes
188 // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
189};
190
203
204//===----------------------------------------------------------------------===//
205// Interfaces
206//===----------------------------------------------------------------------===//
209 // Get supported Matrix shapes
211 getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
212 // @TODO: This method takes a context object as a parameter so that the
213 // Type objects are created in the same context. Since type objects are
214 // uniqued per context, checks like "aType == bType" (where aType and bType
215 // represent the same type) only work if both types come from the same
216 // context.
217 //
218 // One alternative is to create an enum to represent each type, but this
219 // adds an extra burden on the user to convert these enums to specific types.
220 // In fact, the utility that would convert enumToType() and vice versa would
221 // still have to use the context object.
222 //
223 // Until we have a better solution, we stick to passing the context object to
224 // this method.
226 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
227 virtual bool
228 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
229 std::pair<uint32_t, uint32_t> BShape,
230 std::pair<uint32_t, uint32_t> CShape,
231 std::pair<uint32_t, uint32_t> DShape, Type AType,
232 Type BType, Type CType, Type DType) = 0;
233 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
234 Type DType) = 0;
235 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
236 std::pair<uint32_t, uint32_t> BShape,
237 std::pair<uint32_t, uint32_t> CShape,
238 std::pair<uint32_t, uint32_t> DShape, Type AType,
239 Type BType, Type CType, Type DType) = 0;
243
244 virtual ~MMAInstructionInterface() = default;
245};
246
247} // namespace uArch
248} // namespace xegpu
249} // namespace mlir
250
251#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
constexpr unsigned generalPackedFormatBitSize
Definition uArchBase.h:32
Include the generated interface declarations.
uint32_t getLineSize() const
Definition uArchBase.h:131
CacheHierarchyLevel hierarchy_level
Definition uArchBase.h:137
virtual ~CacheInfo()=default
CacheInfo(uint32_t size, uint32_t line_size, CacheHierarchyLevel hierarchy_level)
Definition uArchBase.h:123
CacheHierarchyLevel getHierarchyLevel() const
Definition uArchBase.h:132
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:51
static std::optional< InstructionKind > parseInstructionKind(llvm::StringRef str)
Definition uArchBase.h:73
const InstructionScope scope
Definition uArchBase.h:81
InstructionScope getScope() const
Definition uArchBase.h:57
static llvm::StringRef toString(InstructionKind instKind)
Definition uArchBase.h:58
InstructionKind getInstructionKind() const
Definition uArchBase.h:56
const InstructionKind instKind
Definition uArchBase.h:80
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType)=0
llvm::SmallVector< RegisterFileMode, 4 > mode
Definition uArchBase.h:112
RegisterFileInfo(uint32_t size, const llvm::SmallVector< RegisterFileMode, 4 > &mode, const llvm::SmallVector< uint32_t, 4 > &numRegs)
Definition uArchBase.h:93
llvm::SmallVector< uint32_t, 4 > numRegsPerThreadPerMode
Definition uArchBase.h:114
const llvm::SmallVector< RegisterFileMode, 4 > & getModes() const
Definition uArchBase.h:101
const llvm::SmallVector< uint32_t, 4 > & getNumRegsPerThreadPerMode() const
Definition uArchBase.h:105
SharedMemory(uint32_t size, uint32_t alignment)
Definition uArchBase.h:178
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, uint32_t num_vector_units, uint32_t num_matrix_units)
Definition uArchBase.h:197
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:172
StringRef getDescription() const
Definition uArchBase.h:153
virtual unsigned getGeneralPackedFormatBitSize() const =0
bool isSupportedInstruction(InstructionKind instr) const
Definition uArchBase.h:164
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:145
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:157
virtual ~uArch()=default
StringRef getName() const
Definition uArchBase.h:152