MLIR  22.0.0git
uArchBase.h
Go to the documentation of this file.
1 //===- uArchBase.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // Base uArch definition for different architectures.
11 //
12 //
13 //===----------------------------------------------------------------------===//
14 #ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
15 #define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
16 
17 #include <any>
18 #include <functional>
19 #include <iostream>
20 #include <map>
21 #include <mutex>
22 #include <shared_mutex>
23 #include <tuple>
24 
25 #include "mlir/IR/Types.h"
26 #include "llvm/ADT/SmallVector.h"
27 
28 namespace mlir {
29 namespace xegpu {
30 namespace uArch {
31 
32 constexpr unsigned generalPackedFormatBitSize{32}; // bit size of the general packed format (assumed from the name; confirm)
33 
34 // An enum class to represent the scope of an instruction
36 enum class InstructionKind { // kinds of instruction a uArch can describe
37  SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS), a
38  // subgroup-level matrix multiply-add operation
39  Subgroup2DBlockStore, // Subgroup-level 2D block store (write) instruction
40  Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
41  Subgroup2DBlockPrefetch // Subgroup-level 2D block prefetch instruction
42  // @TODO: Add more instructions as needed
43 };
44 
45 // A struct to represent basic information about an instruction.
46 // The primary purpose of the Instruction struct is to provide a generic way to
47 // represent information about an instruction and to use this information to
48 // generate the uArch. Specific instructions in a uArch can inherit from this
49 // struct and add more fields as needed.
50 struct Instruction {
52  : instKind(kind), scope(scope) {}
53 
54  ~Instruction() = default;
55  // Get methods
57  InstructionScope getScope() const { return scope; }
58  static llvm::StringRef toString(InstructionKind instKind) {
59  switch (instKind) {
61  return "dpas";
63  return "store_nd";
65  return "load_nd";
67  return "prefetch_nd";
68  }
69  llvm_unreachable("Unknown InstructionKind");
70  }
71 
72  static std::optional<InstructionKind>
73  parseInstructionKind(llvm::StringRef str) {
74  if (str.equals_insensitive("dpas"))
76  return std::nullopt;
77  }
78 
79 protected:
80  const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
81  const InstructionScope scope; // scope of the instruction (e.g., lane,
82  // subgroup, workgroup, cluster)
83  // @TODO: Add more fields as needed
84 };
85 
86 enum class RegisterFileMode : uint8_t { Small, Large }; // "small" / "large" GRF modes
87 enum class RegisterFileType : uint8_t { GRF, ARF }; // kind of register file
88 
89 // A struct to represent register file information
91  // Constructor
92  RegisterFileInfo() = default;
95  const llvm::SmallVector<uint32_t, 4> &numRegs)
96  : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
97 
98  // Get methods
99  uint32_t getSize() const { return size; }
100 
102  return mode;
103  }
104 
107  }
108 
109 protected:
110  uint32_t size; // size per register in bits
112  mode; // e.g., "small", "large" GRF modes
114  numRegsPerThreadPerMode; // number of registers per thread per mode
115 };
116 
117 enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 }; // enumerator value equals the cache level number
118 
119 // A struct to represent cache information
120 struct CacheInfo {
121  // Constructor
122  CacheInfo() = default;
123  CacheInfo(uint32_t size, uint32_t line_size,
126 
127  virtual ~CacheInfo() = default;
128 
129  // Get methods
130  uint32_t getSize() const { return size; }
131  uint32_t getLineSize() const { return line_size; }
133 
134 protected:
135  uint32_t size;
136  uint32_t line_size;
138  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
139  // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
140  // latency, throughput, bandwidth)
141 };
142 
143 struct uArch {
144  // Constructor
145  uArch(StringRef name, StringRef description,
148  for (const Instruction *instr : instructionRegistry)
149  this->instructionRegistry[instr->getInstructionKind()] = instr;
150  }
151  virtual ~uArch() = default;
152  StringRef getName() const { return name; }
153  StringRef getDescription() const { return description; }
154  virtual int getSubgroupSize() const = 0;
155  virtual unsigned getGeneralPackedFormatBitSize() const = 0;
156 
157  const Instruction *getInstruction(InstructionKind instKind) const {
158  auto it = instructionRegistry.find(instKind);
159  assert(it != instructionRegistry.end() &&
160  "Instruction not found in registry");
161  return it->second;
162  }
163 
165  return instructionRegistry.contains(instr);
166  }
167 
168 protected:
169  StringRef name;
170  StringRef description;
171  llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
173 };
174 
175 // A struct to represent shared memory information: a size and an alignment
176 struct SharedMemory {
177  // Constructor: stores the given size and alignment (both in bytes) as-is
178  SharedMemory(uint32_t size, uint32_t alignment)
179  : size(size), alignment(alignment) {}
180 
181  // Get methods (read-only accessors)
182  uint32_t getSize() const { return size; } // size in bytes
183  uint32_t getAlignment() const { return alignment; } // alignment in bytes
184 
185 protected:
186  uint32_t size; // in bytes
187  uint32_t alignment; // in bytes
188  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
189 };
190 
191 struct XeCoreInfo {
192  uint32_t num_threads;
196 
198  uint32_t num_vector_units, uint32_t num_matrix_units)
201  }
202 };
203 
204 //===----------------------------------------------------------------------===//
205 // Interfaces
206 //===----------------------------------------------------------------------===//
209  // Get supported Matrix shapes
211  getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
212  // @TODO: This method takes a context object as a parameter; this is to
213  // create the Type objects from the same context. Since type objects are
214  // uniqued in a specific context, to do things like "aType == bType" (where
215  // aType and bType are both the same type) kind of checks, both types should
216  // be from the same context.
217  //
218  // One alternative to this is to create an enum to represent each type, but
219  // this adds an extra burden on the user to convert these enums to specific
220  // types. In fact the utility that would convert enumToType() and vice versa
221  // would still have to use the context object.
222  //
223  // Until we have a better solution, we stick to passing the context object to
224  // this method.
226  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
227  virtual bool
228  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
229  std::pair<uint32_t, uint32_t> BShape,
230  std::pair<uint32_t, uint32_t> CShape,
231  std::pair<uint32_t, uint32_t> DShape, Type AType,
232  Type BType, Type CType, Type DType) = 0;
233  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
234  Type DType) = 0;
235  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
236  std::pair<uint32_t, uint32_t> BShape,
237  std::pair<uint32_t, uint32_t> CShape,
238  std::pair<uint32_t, uint32_t> DShape, Type AType,
239  Type BType, Type CType, Type DType) = 0;
243 
244  virtual ~MMAInstructionInterface() = default;
245 };
246 
247 } // namespace uArch
248 } // namespace xegpu
249 } // namespace mlir
250 
251 #endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
union mlir::linalg::@1257::ArityGroupAndKind::Kind kind
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
constexpr unsigned generalPackedFormatBitSize
Definition: uArchBase.h:32
Include the generated interface declarations.
uint32_t getLineSize() const
Definition: uArchBase.h:131
CacheHierarchyLevel hierarchy_level
Definition: uArchBase.h:137
virtual ~CacheInfo()=default
CacheInfo(uint32_t size, uint32_t line_size, CacheHierarchyLevel hierarchy_level)
Definition: uArchBase.h:123
CacheHierarchyLevel getHierarchyLevel() const
Definition: uArchBase.h:132
uint32_t getSize() const
Definition: uArchBase.h:130
Instruction(InstructionKind kind, InstructionScope scope)
Definition: uArchBase.h:51
const InstructionScope scope
Definition: uArchBase.h:81
static std::optional< InstructionKind > parseInstructionKind(llvm::StringRef str)
Definition: uArchBase.h:73
InstructionScope getScope() const
Definition: uArchBase.h:57
static llvm::StringRef toString(InstructionKind instKind)
Definition: uArchBase.h:58
InstructionKind getInstructionKind() const
Definition: uArchBase.h:56
const InstructionKind instKind
Definition: uArchBase.h:80
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType)=0
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
const llvm::SmallVector< uint32_t, 4 > & getNumRegsPerThreadPerMode() const
Definition: uArchBase.h:105
llvm::SmallVector< RegisterFileMode, 4 > mode
Definition: uArchBase.h:112
RegisterFileInfo(uint32_t size, const llvm::SmallVector< RegisterFileMode, 4 > &mode, const llvm::SmallVector< uint32_t, 4 > &numRegs)
Definition: uArchBase.h:93
llvm::SmallVector< uint32_t, 4 > numRegsPerThreadPerMode
Definition: uArchBase.h:114
const llvm::SmallVector< RegisterFileMode, 4 > & getModes() const
Definition: uArchBase.h:101
uint32_t getAlignment() const
Definition: uArchBase.h:183
SharedMemory(uint32_t size, uint32_t alignment)
Definition: uArchBase.h:178
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, uint32_t num_vector_units, uint32_t num_matrix_units)
Definition: uArchBase.h:197
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition: uArchBase.h:172
const Instruction * getInstruction(InstructionKind instKind) const
Definition: uArchBase.h:157
StringRef getDescription() const
Definition: uArchBase.h:153
virtual unsigned getGeneralPackedFormatBitSize() const =0
bool isSupportedInstruction(InstructionKind instr) const
Definition: uArchBase.h:164
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition: uArchBase.h:145
virtual ~uArch()=default
StringRef getName() const
Definition: uArchBase.h:152