MLIR  22.0.0git
XeVMDialect.cpp
Go to the documentation of this file.
1 //===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
12 #include "llvm/ADT/SmallSet.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 #include "llvm/Support/FileSystem.h"
15 #include "llvm/Support/MathExtras.h"
16 
17 using namespace mlir;
18 using namespace mlir::xevm;
19 
20 #include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
21 #include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
22 
23 namespace {
24 static constexpr uint32_t subgroupSize = 16;
25 
26 template <typename Op>
27 LogicalResult verifyMatrixInput(Op op) {
28  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
29  BlockPrefetch2dOp>::value,
30  "Unexpected template parameter");
31 
32  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
33  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
34  if (pitch && width && *pitch < *width)
35  return op->emitOpError(
36  "4th operand (base pitch) should be >= 2nd operand (base width)");
37 
38  uint32_t elemSize = op.getElemSizeInBits();
39  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
40  return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
41 
42  uint32_t tileHeight = op.getTileHeight();
43  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
44  return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
45 
46  uint32_t vBlocks = op.getVBlocks();
47  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
48  return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
49 
50  return success();
51 }
52 
53 LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
54  VectorType resTy = op.getRes().getType();
55  if (!resTy.getElementType().isIntOrFloat())
56  return op.emitOpError()
57  << "expecting result element type to be int or float";
58  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
59  unsigned resSize = resTy.getNumElements() * resElemTySize;
60  unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
61  op.getTileWidth() * op.getVBlocks() / subgroupSize;
62  if (resSize != expectedSize)
63  return op.emitOpError() << "result size of " << resSize
64  << " bits does not match the expected size of "
65  << expectedSize << " bits";
66 
67  if (op.getTranspose() && op.getPackRegister())
68  return op.emitOpError("transpose and pack_register are mutually exclusive");
69 
70  if (!op.getTranspose() && !op.getPackRegister()) {
71  uint32_t tileHeight = op.getTileHeight();
72  if (tileHeight < 1 || tileHeight > 32)
73  return op.emitOpError("expecting tile_height to be between 1 and 32");
74 
75  uint32_t tileWidth = op.getTileWidth();
76  uint32_t vBlocks = op.getVBlocks();
77  switch (op.getElemSizeInBits()) {
78  case 8:
79  if (tileWidth < 4 || tileWidth > 64)
80  return op.emitOpError("expecting tile_width to be between 4 and 64");
81  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
82  return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
83  if (tileWidth * vBlocks > 64)
84  return op.emitOpError(
85  "tile_width * v_blocks should be less than or equal "
86  "to 64 for 8 bit elements");
87  break;
88  case 16:
89  if (tileWidth < 2 || tileWidth > 32)
90  return op.emitOpError("expecting tile_width to be between 2 and 32");
91  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
92  return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
93  if (tileWidth * vBlocks > 32)
94  return op.emitOpError(
95  "tile_width * v_blocks should be less than or equal "
96  "to 32 for 16 bit elements");
97  break;
98  case 32:
99  if (tileWidth < 1 || tileWidth > 16)
100  return op.emitOpError("expecting tile_width to be between 1 and 16");
101  if (vBlocks != 1 && vBlocks != 2)
102  return op.emitOpError("expecting v_blocks to be 1 or 2");
103  if (tileWidth * vBlocks > 16)
104  return op.emitOpError(
105  "tile_width * v_blocks should be less than or equal "
106  "to 16 for 32 bit elements");
107  break;
108  case 64:
109  if (tileWidth < 1 || tileWidth > 8)
110  return op.emitOpError("expecting tile_width to be between 1 and 8");
111  if (vBlocks != 1)
112  return op.emitOpError("expecting v_blocks to be 1");
113  break;
114  default:
115  return op.emitOpError(
116  "expecting elem_size_in_bits to be 8, 16, 32, or 64");
117  }
118 
119  return success();
120  }
121 
122  if (op.getTranspose()) {
123  assert(!op.getPackRegister() && "Expecting pack_register should be false");
124 
125  uint32_t vBlocks = op.getVBlocks();
126  if (vBlocks != 1)
127  return op.emitOpError("expecting v_blocks to be 1");
128 
129  uint32_t tileHeight = op.getTileHeight();
130  uint32_t tileWidth = op.getTileWidth();
131  switch (op.getElemSizeInBits()) {
132  case 32:
133  if (tileHeight < 1 || tileHeight > 32)
134  return op.emitOpError("expecting tile_height to be between 1 and 32");
135  if (tileWidth < 1 || tileWidth > 8)
136  return op.emitOpError("expecting tile_width to be between 1 and 8");
137  break;
138  case 64:
139  if (tileHeight != 8)
140  return op.emitOpError(
141  "expecting tile_height to be 8 for 64 bit elements");
142  if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
143  return op.emitOpError("expecting tile_width to be 1, 2, or 4");
144  break;
145  default:
146  return op.emitOpError("transpose is only supported for 32 and 64 bit "
147  "elements");
148  }
149 
150  return success();
151  }
152 
153  assert(op.getPackRegister() && !op.getTranspose() &&
154  "Expecting pack_register should be true and transpose should be "
155  "false");
156 
157  uint32_t vBlocks = op.getVBlocks();
158  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
159  return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
160 
161  uint32_t tileHeight = op.getTileHeight();
162  uint32_t tileWidth = op.getTileWidth();
163  switch (op.getElemSizeInBits()) {
164  case 8:
165  if (tileHeight < 4 || tileHeight > 32)
166  return op.emitOpError("expecting tile_height to be between 4 and 32");
167  if (tileWidth < 4 || tileWidth > 16)
168  return op.emitOpError("expecting tile_width to be between 4 and 16");
169  break;
170  case 16:
171  if (tileHeight < 2 || tileHeight > 32)
172  return op.emitOpError("expecting tile_height to be between 2 and 32");
173  if (tileWidth < 2 || tileWidth > 16)
174  return op.emitOpError("expecting tile_width to be between 2 and 16");
175  if (tileWidth * vBlocks > 32)
176  return op.emitOpError(
177  "tile_width * v_blocks should be less than or equal "
178  "to 32 for 16 bit elements");
179  break;
180  default:
181  return op.emitOpError("pack_register is only supported for 8 and 16 bit "
182  "elements");
183  }
184 
185  return success();
186 }
187 
188 static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
189  uint32_t tileHeight = op.getTileHeight();
190  if (tileHeight < 1 || tileHeight > 8)
191  return op.emitOpError("expecting tile_height to be between 1 and 8");
192 
193  uint32_t tileWidth = op.getTileWidth();
194  switch (op.getElemSizeInBits()) {
195  case 8:
196  if (tileWidth < 4 || tileWidth > 64)
197  return op.emitOpError("expecting tile_width to be between 4 and 64");
198  break;
199  case 16:
200  if (tileWidth < 2 || tileWidth > 32)
201  return op.emitOpError("expecting tile_width to be between 2 and 32");
202  break;
203  case 32:
204  if (tileWidth < 1 || tileWidth > 16)
205  return op.emitOpError("expecting tile_width to be between 1 and 16");
206  break;
207  case 64:
208  if (tileWidth < 1 || tileWidth > 8)
209  return op.emitOpError("expecting tile_width to be between 1 and 8");
210  break;
211  default:
212  return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
213  }
214 
215  uint32_t vBlocks = op.getVBlocks();
216  if (vBlocks != 1)
217  return op.emitOpError("expecting v_blocks to be 1");
218  return success();
219 }
220 
221 } // namespace
222 
223 LogicalResult BlockLoad2dOp::verify() {
224  if (verify2DBlockLoadRestriction(*this).failed())
225  return failure();
226 
227  if (verifyMatrixInput(*this).failed())
228  return failure();
229 
230  VectorType resTy = getRes().getType();
231  if (!resTy.getElementType().isIntOrFloat())
232  return emitOpError() << "expecting result element type to be int of float";
233  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
234  if (getElemSizeInBits() == 32 || getPackRegister()) {
235  if (resElemTySize != 32)
236  return emitOpError() << "expecting result element type to be 32 bits";
237  }
238 
239  uint32_t tileWidth = getTileWidth();
240  if (getPackRegister()) {
241  if (tileWidth != 16)
242  return emitOpError(
243  "tile_width when pack_register is true should be equal "
244  "to subgroup size (16 elements)");
245  return success();
246  }
247 
248  return success();
249 }
250 
251 LogicalResult BlockStore2dOp::verify() {
252  if (verify2DBlockStoreRestriction(*this).failed())
253  return failure();
254 
255  if (verifyMatrixInput(*this).failed())
256  return failure();
257 
258  uint32_t tileWidth = getTileWidth();
259  switch (getElemSizeInBits()) {
260  case 8:
261  if (tileWidth != 16 && tileWidth != 32)
262  return emitOpError("tile_width for 8 bit elements should be equal to "
263  "16 or 32");
264  break;
265  case 16:
266  if (tileWidth != 16)
267  return emitOpError("tile_width for 16 bit elements should be equal "
268  "to 16");
269  break;
270  case 32:
271  if (tileWidth != 16)
272  return emitOpError("tile_width for 32 bit elements should be equal "
273  "to 16");
274  break;
275  default:
276  llvm_unreachable("unexpected element size");
277  }
278 
279  return success();
280 }
281 
282 LogicalResult BlockPrefetch2dOp::verify() {
283  if (verifyMatrixInput(*this).failed())
284  return failure();
285 
286  uint32_t tileWidth = getTileWidth();
287  switch (getElemSizeInBits()) {
288  case 8:
289  if (tileWidth != 16 && tileWidth != 32)
290  return emitOpError("tile_width for 8 bit elements should be equal to "
291  "16 or 32");
292  break;
293  case 16:
294  if (tileWidth != 16)
295  return emitOpError("tile_width for 16 bit elements should be equal "
296  "to 16");
297  break;
298  case 32:
299  if (tileWidth != 8 && tileWidth != 16)
300  return emitOpError(
301  "tile_width for 32 bit elements should be equal to 8 or 16");
302  break;
303  default:
304  llvm_unreachable("unexpected element size");
305  }
306 
307  return success();
308 }
309 
310 template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
311  OpType, BlockLoadOp, BlockStoreOp>::value>>
312 LogicalResult verify1DBlockArg(OpType op) {
313  Type srcOrDstTy;
314  if constexpr (std::is_same_v<OpType, BlockLoadOp>)
315  srcOrDstTy = op.getResult().getType();
316  else
317  srcOrDstTy = op.getVal().getType();
318  VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
319  // scalar case is always valid
320  if (!vTy)
321  return success();
322  int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
323  if (elemTySize == 1) {
324  llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
325  if (validSizes.contains(vTy.getNumElements()))
326  return success();
327  else
328  return op.emitOpError(
329  "vector size must be 2, 4, 8 or 16 for 8-bit element type");
330  } else {
331  llvm::SmallSet<int, 3> validSizes{2, 4, 8};
332  if (validSizes.contains(vTy.getNumElements()))
333  return success();
334  else
335  return op.emitOpError(
336  "vector size must be 2, 4 or 8 for element type > 8 bits");
337  }
338 }
339 
340 LogicalResult BlockLoadOp::verify() { return verify1DBlockArg(*this); }
341 
342 LogicalResult BlockStoreOp::verify() { return verify1DBlockArg(*this); }
343 
344 LogicalResult MMAOp::verify() {
345  if (getC()) {
346  if (getResult().getType() != getC().getType())
347  return emitOpError("type of C operand must match result type");
348  }
349  return success();
350 }
351 
// Verifies the parameters of the XeVM target attribute. It returns failure
// through emitError() on the first violated rule:
//  - the optimization level O must be in [0, 3];
//  - the target triple and the target chip must be non-empty;
//  - every StringAttr entry of linkFiles must be a non-empty path naming an
//    existing file. Entries that are not StringAttr are skipped without error,
//    and `flags` is not inspected here.
// NOTE(review): the declarator line naming this function is not visible in
// this listing. Given these parameters and the XeVMTargetAttr registered in
// XeVMDialect::initialize(), it is presumably XeVMTargetAttr::verify --
// confirm against the source.
352 LogicalResult
354  StringRef triple, StringRef chip, DictionaryAttr flags,
355  ArrayAttr linkFiles) {
356  if (O < 0 || O > 3) {
357  return emitError()
358  << "The optimization level must be a number between 0 and 3.";
359  }
360  if (triple.empty()) {
361  return emitError() << "The target triple cannot be empty.";
362  }
363  if (chip.empty()) {
364  return emitError() << "The target chip cannot be empty.";
365  }
// linkFiles is optional; a null ArrayAttr skips the file checks entirely.
366  if (linkFiles) {
367  for (Attribute fileAttr : linkFiles) {
368  if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
369  StringRef filePath = fileStrAttr.getValue();
370  if (filePath.empty()) {
371  return emitError() << "File paths in linkFiles cannot be empty.";
372  }
// NOTE(review): this queries the filesystem, so whether the attribute
// verifies depends on the machine (and working directory, for relative paths)
// where verification runs.
373  if (!llvm::sys::fs::exists(filePath)) {
374  return emitError() << "File '" << filePath << "' does not exist.";
375  }
376  }
377  }
378  }
379  return success();
380 }
381 
// Registers the XeVM dialect's operations and attributes. The type lists are
// produced by the ODS-generated .cpp.inc files included under the
// GET_OP_LIST / GET_ATTRDEF_LIST macros.
382 void XeVMDialect::initialize() {
383  addOperations<
384 #define GET_OP_LIST
385 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
386  >();
387 
388  addAttributes<
389 #define GET_ATTRDEF_LIST
390 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
391  >();
// XeVMTargetAttr is declared to implement gpu::TargetAttrInterface, but the
// implementation is not attached here. It is expected to be registered
// separately (e.g. as an external model) before the interface is used.
392  declarePromisedInterface<mlir::gpu::TargetAttrInterface,
393  mlir::xevm::XeVMTargetAttr>();
394 }
395 
396 #define GET_OP_CLASSES
397 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
398 
399 #define GET_ATTRDEF_CLASSES
400 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
LogicalResult verify1DBlockArg(OpType op)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:837
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
constexpr unsigned subgroupSize
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423