MLIR  22.0.0git
XeVMDialect.cpp
Go to the documentation of this file.
1 //===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/FileSystem.h"
14 #include "llvm/Support/MathExtras.h"
15 
16 using namespace mlir;
17 using namespace mlir::xevm;
18 
19 #include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
20 #include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
21 
22 namespace {
/// Subgroup size assumed by the 2D block verifiers in this file: the expected
/// bit count of a block-load result is the tile's total bit count divided by
/// this value.
constexpr uint32_t subgroupSize{16};
24 
25 template <typename Op>
26 LogicalResult verifyMatrixInput(Op op) {
27  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
28  BlockPrefetch2dOp>::value,
29  "Unexpected template parameter");
30 
31  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
32  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
33  if (pitch && width && *pitch < *width)
34  return op->emitOpError(
35  "4th operand (base pitch) should be >= 2nd operand (base width)");
36 
37  uint32_t elemSize = op.getElemSizeInBits();
38  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
39  return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
40 
41  uint32_t tileHeight = op.getTileHeight();
42  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
43  return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
44 
45  uint32_t vBlocks = op.getVBlocks();
46  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
47  return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
48 
49  return success();
50 }
51 
52 LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
53  VectorType resTy = op.getRes().getType();
54  if (!resTy.getElementType().isIntOrFloat())
55  return op.emitOpError()
56  << "expecting result element type to be int or float";
57  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
58  unsigned resSize = resTy.getNumElements() * resElemTySize;
59  unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
60  op.getTileWidth() * op.getVBlocks() / subgroupSize;
61  if (resSize != expectedSize)
62  return op.emitOpError() << "result size of " << resSize
63  << " bits does not match the expected size of "
64  << expectedSize << " bits";
65 
66  if (op.getTranspose() && op.getPackRegister())
67  return op.emitOpError("transpose and pack_register are mutually exclusive");
68 
69  if (!op.getTranspose() && !op.getPackRegister()) {
70  uint32_t tileHeight = op.getTileHeight();
71  if (tileHeight < 1 || tileHeight > 32)
72  return op.emitOpError("expecting tile_height to be between 1 and 32");
73 
74  uint32_t tileWidth = op.getTileWidth();
75  uint32_t vBlocks = op.getVBlocks();
76  switch (op.getElemSizeInBits()) {
77  case 8:
78  if (tileWidth < 4 || tileWidth > 64)
79  return op.emitOpError("expecting tile_width to be between 4 and 64");
80  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
81  return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
82  if (tileWidth * vBlocks > 64)
83  return op.emitOpError(
84  "tile_width * v_blocks should be less than or equal "
85  "to 64 for 8 bit elements");
86  break;
87  case 16:
88  if (tileWidth < 2 || tileWidth > 32)
89  return op.emitOpError("expecting tile_width to be between 2 and 32");
90  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
91  return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
92  if (tileWidth * vBlocks > 32)
93  return op.emitOpError(
94  "tile_width * v_blocks should be less than or equal "
95  "to 32 for 16 bit elements");
96  break;
97  case 32:
98  if (tileWidth < 1 || tileWidth > 16)
99  return op.emitOpError("expecting tile_width to be between 1 and 16");
100  if (vBlocks != 1 && vBlocks != 2)
101  return op.emitOpError("expecting v_blocks to be 1 or 2");
102  if (tileWidth * vBlocks > 16)
103  return op.emitOpError(
104  "tile_width * v_blocks should be less than or equal "
105  "to 16 for 32 bit elements");
106  break;
107  case 64:
108  if (tileWidth < 1 || tileWidth > 8)
109  return op.emitOpError("expecting tile_width to be between 1 and 8");
110  if (vBlocks != 1)
111  return op.emitOpError("expecting v_blocks to be 1");
112  break;
113  default:
114  return op.emitOpError(
115  "expecting elem_size_in_bits to be 8, 16, 32, or 64");
116  }
117 
118  return success();
119  }
120 
121  if (op.getTranspose()) {
122  assert(!op.getPackRegister() && "Expecting pack_register should be false");
123 
124  uint32_t vBlocks = op.getVBlocks();
125  if (vBlocks != 1)
126  return op.emitOpError("expecting v_blocks to be 1");
127 
128  uint32_t tileHeight = op.getTileHeight();
129  uint32_t tileWidth = op.getTileWidth();
130  switch (op.getElemSizeInBits()) {
131  case 32:
132  if (tileHeight < 1 || tileHeight > 32)
133  return op.emitOpError("expecting tile_height to be between 1 and 32");
134  if (tileWidth < 1 || tileWidth > 8)
135  return op.emitOpError("expecting tile_width to be between 1 and 8");
136  break;
137  case 64:
138  if (tileHeight != 8)
139  return op.emitOpError(
140  "expecting tile_height to be 8 for 64 bit elements");
141  if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
142  return op.emitOpError("expecting tile_width to be 1, 2, or 4");
143  break;
144  default:
145  return op.emitOpError("transpose is only supported for 32 and 64 bit "
146  "elements");
147  }
148 
149  return success();
150  }
151 
152  assert(op.getPackRegister() && !op.getTranspose() &&
153  "Expecting pack_register should be true and transpose should be "
154  "false");
155 
156  uint32_t vBlocks = op.getVBlocks();
157  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
158  return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
159 
160  uint32_t tileHeight = op.getTileHeight();
161  uint32_t tileWidth = op.getTileWidth();
162  switch (op.getElemSizeInBits()) {
163  case 8:
164  if (tileHeight < 4 || tileHeight > 32)
165  return op.emitOpError("expecting tile_height to be between 4 and 32");
166  if (tileWidth < 4 || tileWidth > 16)
167  return op.emitOpError("expecting tile_width to be between 4 and 16");
168  break;
169  case 16:
170  if (tileHeight < 2 || tileHeight > 32)
171  return op.emitOpError("expecting tile_height to be between 2 and 32");
172  if (tileWidth < 2 || tileWidth > 16)
173  return op.emitOpError("expecting tile_width to be between 2 and 16");
174  if (tileWidth * vBlocks > 32)
175  return op.emitOpError(
176  "tile_width * v_blocks should be less than or equal "
177  "to 32 for 16 bit elements");
178  break;
179  default:
180  return op.emitOpError("pack_register is only supported for 8 and 16 bit "
181  "elements");
182  }
183 
184  return success();
185 }
186 
187 static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
188  uint32_t tileHeight = op.getTileHeight();
189  if (tileHeight < 1 || tileHeight > 8)
190  return op.emitOpError("expecting tile_height to be between 1 and 8");
191 
192  uint32_t tileWidth = op.getTileWidth();
193  switch (op.getElemSizeInBits()) {
194  case 8:
195  if (tileWidth < 4 || tileWidth > 64)
196  return op.emitOpError("expecting tile_width to be between 4 and 64");
197  break;
198  case 16:
199  if (tileWidth < 2 || tileWidth > 32)
200  return op.emitOpError("expecting tile_width to be between 2 and 32");
201  break;
202  case 32:
203  if (tileWidth < 1 || tileWidth > 16)
204  return op.emitOpError("expecting tile_width to be between 1 and 16");
205  break;
206  case 64:
207  if (tileWidth < 1 || tileWidth > 8)
208  return op.emitOpError("expecting tile_width to be between 1 and 8");
209  break;
210  default:
211  return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
212  }
213 
214  uint32_t vBlocks = op.getVBlocks();
215  if (vBlocks != 1)
216  return op.emitOpError("expecting v_blocks to be 1");
217  return success();
218 }
219 
220 } // namespace
221 
222 LogicalResult BlockLoad2dOp::verify() {
223  if (verify2DBlockLoadRestriction(*this).failed())
224  return failure();
225 
226  if (verifyMatrixInput(*this).failed())
227  return failure();
228 
229  VectorType resTy = getRes().getType();
230  if (!resTy.getElementType().isIntOrFloat())
231  return emitOpError() << "expecting result element type to be int of float";
232  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
233  if (getElemSizeInBits() == 32 || getPackRegister()) {
234  if (resElemTySize != 32)
235  return emitOpError() << "expecting result element type to be 32 bits";
236  }
237 
238  uint32_t tileWidth = getTileWidth();
239  if (getPackRegister()) {
240  if (tileWidth != 16)
241  return emitOpError(
242  "tile_width when pack_register is true should be equal "
243  "to subgroup size (16 elements)");
244  return success();
245  }
246 
247  return success();
248 }
249 
250 LogicalResult BlockStore2dOp::verify() {
251  if (verify2DBlockStoreRestriction(*this).failed())
252  return failure();
253 
254  if (verifyMatrixInput(*this).failed())
255  return failure();
256 
257  uint32_t tileWidth = getTileWidth();
258  switch (getElemSizeInBits()) {
259  case 8:
260  if (tileWidth != 16 && tileWidth != 32)
261  return emitOpError("tile_width for 8 bit elements should be equal to "
262  "16 or 32");
263  break;
264  case 16:
265  if (tileWidth != 16)
266  return emitOpError("tile_width for 16 bit elements should be equal "
267  "to 16");
268  break;
269  case 32:
270  if (tileWidth != 16)
271  return emitOpError("tile_width for 32 bit elements should be equal "
272  "to 16");
273  break;
274  default:
275  llvm_unreachable("unexpected element size");
276  }
277 
278  return success();
279 }
280 
281 LogicalResult BlockPrefetch2dOp::verify() {
282  if (verifyMatrixInput(*this).failed())
283  return failure();
284 
285  uint32_t tileWidth = getTileWidth();
286  switch (getElemSizeInBits()) {
287  case 8:
288  if (tileWidth != 16 && tileWidth != 32)
289  return emitOpError("tile_width for 8 bit elements should be equal to "
290  "16 or 32");
291  break;
292  case 16:
293  if (tileWidth != 16)
294  return emitOpError("tile_width for 16 bit elements should be equal "
295  "to 16");
296  break;
297  case 32:
298  if (tileWidth != 8 && tileWidth != 16)
299  return emitOpError(
300  "tile_width for 32 bit elements should be equal to 8 or 16");
301  break;
302  default:
303  llvm_unreachable("unexpected element size");
304  }
305 
306  return success();
307 }
308 
309 LogicalResult MMAOp::verify() {
310  if (getC()) {
311  if (getResult().getType() != getC().getType())
312  return emitOpError("type of C operand must match result type");
313  }
314  return success();
315 }
316 
317 LogicalResult
319  StringRef triple, StringRef chip, DictionaryAttr flags,
320  ArrayAttr linkFiles) {
321  if (O < 0 || O > 3) {
322  return emitError()
323  << "The optimization level must be a number between 0 and 3.";
324  }
325  if (triple.empty()) {
326  return emitError() << "The target triple cannot be empty.";
327  }
328  if (chip.empty()) {
329  return emitError() << "The target chip cannot be empty.";
330  }
331  if (linkFiles) {
332  for (Attribute fileAttr : linkFiles) {
333  if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
334  StringRef filePath = fileStrAttr.getValue();
335  if (filePath.empty()) {
336  return emitError() << "File paths in linkFiles cannot be empty.";
337  }
338  if (!llvm::sys::fs::exists(filePath)) {
339  return emitError() << "File '" << filePath << "' does not exist.";
340  }
341  }
342  }
343  }
344  return success();
345 }
346 
347 void XeVMDialect::initialize() {
348  addOperations<
349 #define GET_OP_LIST
350 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
351  >();
352 
353  addAttributes<
354 #define GET_ATTRDEF_LIST
355 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
356  >();
357  declarePromisedInterface<mlir::gpu::TargetAttrInterface,
358  mlir::xevm::XeVMTargetAttr>();
359 }
360 
361 #define GET_OP_CLASSES
362 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
363 
364 #define GET_ATTRDEF_CLASSES
365 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:836
This provides public APIs that all operations should have.
constexpr unsigned subgroupSize
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423