MLIR 23.0.0git
uArchBase.h
Go to the documentation of this file.
1//===- uArchBase.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Base uArch definition for different architectures.
11//
12//
13//===----------------------------------------------------------------------===//
14#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
15#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
16
17#include <any>
18#include <functional>
19#include <iostream>
20#include <map>
21#include <mutex>
22#include <shared_mutex>
23#include <tuple>
24
25#include "mlir/IR/Types.h"
26#include "llvm/ADT/SmallVector.h"
27
28namespace mlir {
29namespace xegpu {
30namespace uArch {
31
// Bit width (32) of the "general packed format". Each architecture also
// exposes its own value through uArch::getGeneralPackedFormatBitSize(); this
// constant is presumably the common default for those overrides (TODO
// confirm).
//
// Declared `inline constexpr` (C++17) rather than plain `constexpr`: a
// namespace-scope constexpr variable in a header has internal linkage, giving
// every translation unit its own copy and risking ODR violations when it is
// odr-used from inline functions. `inline` makes it a single entity program-
// wide while remaining a compile-time constant.
inline constexpr unsigned generalPackedFormatBitSize{32};
33
34// An enum class to represent the scope of an instruction
/// Identifies the concrete kind of instruction described by a uArch.
/// Enumerators are listed in their declaration (and therefore value) order.
enum class InstructionKind {
  /// Dot Product Accumulate Systolic (DPAS): a matrix multiply-add operation.
  SubgroupMatrixMultiplyAcc,
  /// Subgroup-level 2D block write instruction.
  Subgroup2DBlockStore,
  /// Subgroup-level 2D block load instruction.
  Subgroup2DBlockLoad,
  /// Subgroup-level 2D block prefetch instruction.
  Subgroup2DBlockPrefetch,
  /// Lane-level store (scalar, vector).
  StoreScatter,
  /// Lane-level load (scalar, vector).
  LoadGather,
  // @TODO: Add more instructions as needed
};
46
47// A struct to represent basic information about an instruction.
48// The primary purpose of the Instruction struct is to provide a generic way to
49// represent information about an instruction and to use this information to
50// generate the uArch. Specific instructions in a uArch can inherit from this
51// struct and add more fields as needed.
55
56 ~Instruction() = default;
57 // Get methods
59 InstructionScope getScope() const { return scope; }
60 static llvm::StringRef toString(InstructionKind instKind) {
61 switch (instKind) {
63 return "dpas";
65 return "store_nd";
67 return "load_nd";
69 return "prefetch_nd";
71 return "store";
73 return "load";
74 }
75 llvm_unreachable("Unknown InstructionKind");
76 }
77
78 static std::optional<InstructionKind>
79 parseInstructionKind(llvm::StringRef str) {
80 if (str.equals_insensitive("dpas"))
82 return std::nullopt;
83 }
84
85protected:
86 const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
87 const InstructionScope scope; // scope of the instruction (e.g., lane,
88 // subgroup, workgroup, cluster)
89 // @TODO: Add more fields as needed
90};
91
/// Operating mode of a register file; elsewhere in this header these are
/// described as the "small" and "large" GRF modes. Stored in one byte.
enum class RegisterFileMode : uint8_t {
  Small,
  Large,
};

/// Kind of register file. GRF/ARF are presumably the general and
/// architecture register files -- confirm against the Xe ISA documentation.
/// Stored in one byte.
enum class RegisterFileType : uint8_t {
  GRF,
  ARF,
};
94
95// A struct to represent register file information
97 // Constructor
98 RegisterFileInfo() = default;
103
104 // Get methods
105 uint32_t getSize() const { return size; }
106
108 return mode;
109 }
110
114
115protected:
116 uint32_t size; // size per register in bits
118 mode; // e.g., "small", "large" GRF modes
120 numRegsPerThreadPerMode; // number of registers per thread per mode
121};
122
/// Level of a cache in the memory hierarchy. Each enumerator's underlying
/// value equals its level number (L1 == 1, L2 == 2, L3 == 3).
enum class CacheHierarchyLevel {
  L1 = 1,
  L2 = 2,
  L3 = 3,
};
124
125// A struct to represent cache information
126struct CacheInfo {
127 // Constructor
128 CacheInfo() = default;
132
133 virtual ~CacheInfo() = default;
134
135 // Get methods
136 uint32_t getSize() const { return size; }
137 uint32_t getLineSize() const { return line_size; }
139
140protected:
141 uint32_t size;
142 uint32_t line_size;
144 // @TODO: Add more fields as needed (e.g., associativity, num_banks,
145 // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
146 // latency, throughput, bandwidth)
147};
148
149struct uArch {
150 // Constructor
151 uArch(StringRef name, StringRef description,
154 for (const Instruction *instr : instructionRegistry)
155 this->instructionRegistry[instr->getInstructionKind()] = instr;
156 }
157 virtual ~uArch() = default;
158 StringRef getName() const { return name; }
159 StringRef getDescription() const { return description; }
160 virtual int getSubgroupSize() const = 0;
161 virtual unsigned getGeneralPackedFormatBitSize() const = 0;
162
164 auto it = instructionRegistry.find(instKind);
165 assert(it != instructionRegistry.end() &&
166 "Instruction not found in registry");
167 return it->second;
168 }
169
171 return instructionRegistry.contains(instr);
172 }
173
174protected:
175 StringRef name;
176 StringRef description;
177 llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
179};
180
181// A struct to represent shared memory information
183 // Constructor
184 SharedMemory(uint32_t size, uint32_t alignment)
186
187 // Get methods
188 uint32_t getSize() const { return size; }
189 uint32_t getAlignment() const { return alignment; }
190
191protected:
192 uint32_t size; // in bytes
193 uint32_t alignment; // in bytes
194 // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
195};
196
209
210//===----------------------------------------------------------------------===//
211// Interfaces
212//===----------------------------------------------------------------------===//
215 // Get supported Matrix shapes
217 getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
218 // @TODO: This method takes a context object as a parameter so that the Type
219 // objects it returns are created in that context. Since type objects are
220 // uniqued within a specific context, checks like "aType == bType" (where
221 // aType and bType denote the same type) only hold if both types come from
222 // the same context.
223 //
224 // One alternative is to create an enum to represent each type, but this
225 // adds an extra burden on users to convert these enums to specific types.
226 // In fact, a utility converting enumToType() and vice versa would still
227 // have to use the context object.
228 //
229 // Until we have a better solution, we stick to passing a context object to
230 // this method.
232 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
233 virtual bool
234 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
235 std::pair<uint32_t, uint32_t> BShape,
236 std::pair<uint32_t, uint32_t> CShape,
237 std::pair<uint32_t, uint32_t> DShape, Type AType,
238 Type BType, Type CType, Type DType) = 0;
239 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
240 Type DType) = 0;
241 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
242 std::pair<uint32_t, uint32_t> BShape,
243 std::pair<uint32_t, uint32_t> CShape,
244 std::pair<uint32_t, uint32_t> DShape, Type AType,
245 Type BType, Type CType, Type DType) = 0;
249
250 virtual ~MMAInstructionInterface() = default;
251};
252
253//===----------------------------------------------------------------------===//
254// Common instructions (shared across architectures)
255//===----------------------------------------------------------------------===//
256
260 static bool classof(const Instruction *B) {
261 return B->getInstructionKind() == InstructionKind::LoadGather;
262 }
263
264 virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
266};
267
271 static bool classof(const Instruction *B) {
272 return B->getInstructionKind() == InstructionKind::StoreScatter;
273 }
274
275 virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
277};
278
279} // namespace uArch
280} // namespace xegpu
281} // namespace mlir
282
283#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
constexpr unsigned generalPackedFormatBitSize
Definition uArchBase.h:32
Include the generated interface declarations.
uint32_t getLineSize() const
Definition uArchBase.h:137
CacheHierarchyLevel hierarchy_level
Definition uArchBase.h:143
virtual ~CacheInfo()=default
CacheInfo(uint32_t size, uint32_t line_size, CacheHierarchyLevel hierarchy_level)
Definition uArchBase.h:129
CacheHierarchyLevel getHierarchyLevel() const
Definition uArchBase.h:138
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:53
static std::optional< InstructionKind > parseInstructionKind(llvm::StringRef str)
Definition uArchBase.h:79
const InstructionScope scope
Definition uArchBase.h:87
InstructionScope getScope() const
Definition uArchBase.h:59
static llvm::StringRef toString(InstructionKind instKind)
Definition uArchBase.h:60
InstructionKind getInstructionKind() const
Definition uArchBase.h:58
const InstructionKind instKind
Definition uArchBase.h:86
virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:260
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType)=0
llvm::SmallVector< RegisterFileMode, 4 > mode
Definition uArchBase.h:118
RegisterFileInfo(uint32_t size, const llvm::SmallVector< RegisterFileMode, 4 > &mode, const llvm::SmallVector< uint32_t, 4 > &numRegs)
Definition uArchBase.h:99
llvm::SmallVector< uint32_t, 4 > numRegsPerThreadPerMode
Definition uArchBase.h:120
const llvm::SmallVector< RegisterFileMode, 4 > & getModes() const
Definition uArchBase.h:107
const llvm::SmallVector< uint32_t, 4 > & getNumRegsPerThreadPerMode() const
Definition uArchBase.h:111
SharedMemory(uint32_t size, uint32_t alignment)
Definition uArchBase.h:184
virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const =0
static bool classof(const Instruction *B)
Definition uArchBase.h:271
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, uint32_t num_vector_units, uint32_t num_matrix_units)
Definition uArchBase.h:203
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:178
StringRef getDescription() const
Definition uArchBase.h:159
virtual unsigned getGeneralPackedFormatBitSize() const =0
bool isSupportedInstruction(InstructionKind instr) const
Definition uArchBase.h:170
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:151
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:163
virtual ~uArch()=default
StringRef getName() const
Definition uArchBase.h:158