MLIR  22.0.0git
uArchBase.h
Go to the documentation of this file.
1 //===- uArchBase.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // Base uArch definition for different architectures.
11 //
12 //
13 //===----------------------------------------------------------------------===//
14 #ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
15 #define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
16 
17 #include <any>
18 #include <functional>
19 #include <iostream>
20 #include <map>
21 #include <mutex>
22 #include <shared_mutex>
23 #include <tuple>
24 
25 #include "mlir/IR/Types.h"
26 #include "llvm/ADT/SmallVector.h"
27 
28 namespace mlir {
29 namespace xegpu {
30 namespace uArch {
31 
32 // An enum class to represent the scope of an instruction
34 enum class InstructionKind {
35  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
36  // multiply-add operation
37  // @TODO: Add more instructions as needed
38 };
39 
40 // A struct to represent basic information about an instruction.
41 // The primary purpose of the Instruction struct is to provide a generic way to
42 // represent information about an instruction and to use this information to
43 // generate the uArch. Specific instructions in a uArch can inherit from this
44 // struct and add more fields as needed.
45 struct Instruction {
47  : instKind(kind), scope(scope) {}
48 
49  virtual ~Instruction() = default;
50  // Get methods
53  static llvm::StringRef toString(InstructionKind instKind) {
54  switch (instKind) {
56  return "dpas";
57  }
58  llvm_unreachable("Unknown InstructionKind");
59  }
60 
61  static std::optional<InstructionKind>
62  parseInstructionKind(llvm::StringRef str) {
63  if (str.equals_insensitive("dpas"))
64  return InstructionKind::DPAS;
65  return std::nullopt;
66  }
67 
68 protected:
69  InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
70  InstructionScope scope; // scope of the instruction (e.g., lane, subgroup,
71  // workgroup, cluster)
72  // @TODO: Add more fields as needed
73 };
74 
75 enum class RegisterFileMode : uint8_t { Small, Large };
76 enum class RegisterFileType : uint8_t { GRF, ARF };
77 
78 // A struct to represent register file information
80  // Constructor
81  RegisterFileInfo() = default;
84  const llvm::SmallVector<uint32_t, 4> &numRegs)
85  : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
86 
87  // Get methods
88  uint32_t getSize() const { return size; }
89 
91  return mode;
92  }
93 
96  }
97 
98 protected:
99  uint32_t size; // size per register in bits
101  mode; // e.g., "small", "large" GRF modes
103  numRegsPerThreadPerMode; // number of registers per thread per mode
104 };
105 
106 enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
107 
108 // A struct to represent cache information
109 struct CacheInfo {
110  // Constructor
111  CacheInfo() = default;
112  CacheInfo(uint32_t size, uint32_t line_size,
115 
116  virtual ~CacheInfo() = default;
117 
118  // Get methods
119  uint32_t getSize() const { return size; }
120  uint32_t getLineSize() const { return line_size; }
122 
123 protected:
124  uint32_t size;
125  uint32_t line_size;
127  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
128  // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
129  // latency, throughput, bandwidth)
130 };
131 
132 // A struct to represent the uArch
133 // This struct is used to represent the microarchitecture of a target device.
134 struct uArch {
135  // Constructor
137  const std::string &name, const std::string &description,
138  const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
140  const std::map<InstructionKind, std::shared_ptr<Instruction>>
141  &instructions = {})
145 
146  // Get methods
147  const std::string &getName() const { return name; }
148 
149  const std::string &getDescription() const { return description; }
150 
151  const std::map<RegisterFileType, RegisterFileInfo> &
153  return registerFileInfo;
154  }
155 
157  return cacheInfo;
158  }
159 
160  const std::map<InstructionKind, std::shared_ptr<Instruction>> &
161  getInstructions() const {
162  return instructions;
163  }
164 
165  // Get the name of the supported instruction names for that
166  // architecture. It returns the names of the instructions added to the uArch.
168  llvm::SmallVector<StringRef, 8> instructionNames;
169  for (const auto &inst : instructions) {
170  instructionNames.push_back(Instruction::toString(inst.first));
171  }
172  return instructionNames;
173  }
174 
175  // Checks if an instruction is supported in this uArch
177  return instructions.find(instr) != instructions.end();
178  }
179 
180 protected:
181  std::string name; // Name of the uArch, similar to target triple
182  std::string description;
183  std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
185  std::map<InstructionKind, std::shared_ptr<Instruction>>
186  instructions; // set of instructions supported by the uArch
187 };
188 
189 // A struct to represent shared memory information
190 struct SharedMemory {
191  // Constructor
192  SharedMemory(uint32_t size, uint32_t alignment)
193  : size(size), alignment(alignment) {}
194 
195  // Get methods
196  uint32_t getSize() const { return size; }
197  uint32_t getAlignment() const { return alignment; }
198 
199 protected:
200  uint32_t size; // in bytes
201  uint32_t alignment; // in bytes
202  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
203 };
204 
205 struct XeCoreInfo {
206  uint32_t num_threads;
210 
212  uint32_t num_vector_units, uint32_t num_matrix_units)
215  }
216 };
217 
218 //===----------------------------------------------------------------------===//
219 // Interfaces
220 //===----------------------------------------------------------------------===//
223  // Get supported Matrix shapes
225  getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
226  // @TODO: This method takes a context object as a parameter so that the
227  // Type objects it returns are created in that context. Since Type objects
228  // are uniqued per context, equality checks such as "aType == bType" (where
229  // aType and bType denote the same type) only hold when both types come
230  // from the same context.
231  //
232  // One alternative is to define an enum that represents each type, but this
233  // puts an extra burden on the user to convert these enums to concrete types.
234  // Moreover, any utility that converts between enums and types (enumToType()
235  // and vice versa) would still need the context object.
236  //
237  // Until we have a better solution, we keep passing the context object to
238  // this method.
240  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
241  virtual bool
242  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
243  std::pair<uint32_t, uint32_t> BShape,
244  std::pair<uint32_t, uint32_t> CShape,
245  std::pair<uint32_t, uint32_t> DShape, Type AType,
246  Type BType, Type CType, Type DType) = 0;
247  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
248  Type DType) = 0;
249  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
250  std::pair<uint32_t, uint32_t> BShape,
251  std::pair<uint32_t, uint32_t> CShape,
252  std::pair<uint32_t, uint32_t> DShape, Type AType,
253  Type BType, Type CType, Type DType) = 0;
257 
258  virtual ~MMAInstructionInterface() = default;
259 };
260 
261 } // namespace uArch
262 } // namespace xegpu
263 } // namespace mlir
264 
265 #endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
union mlir::linalg::@1247::ArityGroupAndKind::Kind kind
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
uint32_t getLineSize() const
Definition: uArchBase.h:120
CacheHierarchyLevel hierarchy_level
Definition: uArchBase.h:126
virtual ~CacheInfo()=default
CacheInfo(uint32_t size, uint32_t line_size, CacheHierarchyLevel hierarchy_level)
Definition: uArchBase.h:112
CacheHierarchyLevel getHierarchyLevel() const
Definition: uArchBase.h:121
uint32_t getSize() const
Definition: uArchBase.h:119
Instruction(InstructionKind kind, InstructionScope scope)
Definition: uArchBase.h:46
static std::optional< InstructionKind > parseInstructionKind(llvm::StringRef str)
Definition: uArchBase.h:62
static llvm::StringRef toString(InstructionKind instKind)
Definition: uArchBase.h:53
virtual ~Instruction()=default
InstructionKind getInstructionKind()
Definition: uArchBase.h:51
InstructionScope getScope()
Definition: uArchBase.h:52
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type)=0
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType)=0
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType)=0
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type)=0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type)=0
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType)=0
const llvm::SmallVector< uint32_t, 4 > & getNumRegsPerThreadPerMode() const
Definition: uArchBase.h:94
llvm::SmallVector< RegisterFileMode, 4 > mode
Definition: uArchBase.h:101
RegisterFileInfo(uint32_t size, const llvm::SmallVector< RegisterFileMode, 4 > &mode, const llvm::SmallVector< uint32_t, 4 > &numRegs)
Definition: uArchBase.h:82
llvm::SmallVector< uint32_t, 4 > numRegsPerThreadPerMode
Definition: uArchBase.h:103
const llvm::SmallVector< RegisterFileMode, 4 > & getModes() const
Definition: uArchBase.h:90
uint32_t getAlignment() const
Definition: uArchBase.h:197
SharedMemory(uint32_t size, uint32_t alignment)
Definition: uArchBase.h:192
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory, uint32_t num_vector_units, uint32_t num_matrix_units)
Definition: uArchBase.h:211
const std::map< RegisterFileType, RegisterFileInfo > & getRegisterFileInfo() const
Definition: uArchBase.h:152
const llvm::SmallVector< CacheInfo, 4 > & getCacheInfo() const
Definition: uArchBase.h:156
const std::string & getDescription() const
Definition: uArchBase.h:149
const std::map< InstructionKind, std::shared_ptr< Instruction > > & getInstructions() const
Definition: uArchBase.h:161
const std::string & getName() const
Definition: uArchBase.h:147
llvm::SmallVector< StringRef, 8 > getSupportedInstructionNames() const
Definition: uArchBase.h:167
bool checkSupportedInstruction(InstructionKind instr) const
Definition: uArchBase.h:176
std::map< InstructionKind, std::shared_ptr< Instruction > > instructions
Definition: uArchBase.h:186
llvm::SmallVector< CacheInfo, 4 > cacheInfo
Definition: uArchBase.h:184
uArch(const std::string &name, const std::string &description, const std::map< RegisterFileType, RegisterFileInfo > &registerFileInfo={}, const llvm::SmallVector< CacheInfo, 4 > &cacheInfo={}, const std::map< InstructionKind, std::shared_ptr< Instruction >> &instructions={})
Definition: uArchBase.h:136
std::map< RegisterFileType, RegisterFileInfo > registerFileInfo
Definition: uArchBase.h:183