13#include "llvm/ADT/SmallSet.h"
14#include "llvm/ADT/TypeSwitch.h"
15#include "llvm/Support/FileSystem.h"
16#include "llvm/Support/MathExtras.h"
21#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
22#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
// Sub-group size assumed by the 2D block verifiers in this file: the expected
// result size of a 2D block load is the whole block's bit count divided by
// this value, and pack_register loads require tile_width to equal it.
25static constexpr uint32_t subgroupSize = 16;
// Checks shared by the 2D block ops (the static_assert below limits Op to
// BlockLoad2dOp, BlockStore2dOp and BlockPrefetch2dOp). Emits an op error if:
//  - both the base pitch (4th operand) and base width (2nd operand) are
//    available and pitch < width;
//  - elem_size_in_bits is not 8, 16 or 32;
//  - tile_height is not a power of two no greater than 32;
//  - v_blocks is not a power of two no greater than 8.
28LogicalResult verifyMatrixInput(
Op op) {
29 static_assert(llvm::is_one_of<
Op, BlockLoad2dOp, BlockStore2dOp,
30 BlockPrefetch2dOp>::value,
31 "Unexpected template parameter");
35 if (pitch && width && *pitch < *width)
37 "4th operand (base pitch) should be >= 2nd operand (base width)");
39 uint32_t elemSize = op.getElemSizeInBits();
// Powers of two in [8, 32], i.e. 8, 16 or 32. Note this excludes 64-bit
// elements even though the load/store restriction helpers handle them.
40 if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
41 return op->
emitOpError(
"expecting 'elem_size_in_bits' to be 8, 16, or 32");
43 uint32_t tileHeight = op.getTileHeight();
// isPowerOf2_32(0) is false, so a zero tile_height is rejected as well.
44 if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
45 return op->
emitOpError(
"expecting tile_height to be 1, 2, 4, 8, 16, or 32");
47 uint32_t vBlocks = op.getVBlocks();
// Likewise a zero v_blocks fails the power-of-two test.
48 if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
49 return op->
emitOpError(
"expecting v_blocks to be 1, 2, 4, or 8");
// Element-size- and mode-specific shape restrictions for BlockLoad2dOp.
// Checks, in order (each failure returns an op error):
//  1. the result element type is int or float, and the result vector's total
//     bit count equals
//     elem_size_in_bits * tile_height * tile_width * v_blocks / subgroupSize;
//  2. transpose and pack_register are not both set;
//  3. per-mode limits on tile_height, tile_width and v_blocks, keyed on
//     elem_size_in_bits: plain (neither flag), transpose, or pack_register.
// NOTE(review): BlockLoad2dOp::verify runs this before verifyMatrixInput,
// which rejects elem_size_in_bits other than 8/16/32, so the 64-bit limits
// below presumably never admit an op - confirm whether 64-bit loads are
// meant to be supported.
54LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
55 VectorType resTy = op.getRes().getType();
56 if (!resTy.getElementType().isIntOrFloat())
57 return op.emitOpError()
58 <<
"expecting result element type to be int or float";
59 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
60 unsigned resSize = resTy.getNumElements() * resElemTySize;
// Integer division: a block bit count not divisible by subgroupSize is
// truncated before the comparison.
61 unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
62 op.getTileWidth() * op.getVBlocks() / subgroupSize;
63 if (resSize != expectedSize)
64 return op.emitOpError() <<
"result size of " << resSize
65 <<
" bits does not match the expected size of "
66 << expectedSize <<
" bits";
// The two layout modes cannot be combined.
68 if (op.getTranspose() && op.getPackRegister())
69 return op.emitOpError(
"transpose and pack_register are mutually exclusive");
// Plain load: neither transpose nor pack_register.
71 if (!op.getTranspose() && !op.getPackRegister()) {
72 uint32_t tileHeight = op.getTileHeight();
73 if (tileHeight < 1 || tileHeight > 32)
74 return op.emitOpError(
"expecting tile_height to be between 1 and 32");
76 uint32_t tileWidth = op.getTileWidth();
77 uint32_t vBlocks = op.getVBlocks();
78 switch (op.getElemSizeInBits()) {
// 8-bit elements (per the diagnostic below).
80 if (tileWidth < 4 || tileWidth > 64)
81 return op.emitOpError(
"expecting tile_width to be between 4 and 64");
82 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
83 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
84 if (tileWidth * vBlocks > 64)
85 return op.emitOpError(
86 "tile_width * v_blocks should be less than or equal "
87 "to 64 for 8 bit elements");
// 16-bit elements (per the diagnostic below).
90 if (tileWidth < 2 || tileWidth > 32)
91 return op.emitOpError(
"expecting tile_width to be between 2 and 32");
92 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
93 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
94 if (tileWidth * vBlocks > 32)
95 return op.emitOpError(
96 "tile_width * v_blocks should be less than or equal "
97 "to 32 for 16 bit elements");
// 32-bit elements (per the diagnostic below).
100 if (tileWidth < 1 || tileWidth > 16)
101 return op.emitOpError(
"expecting tile_width to be between 1 and 16");
102 if (vBlocks != 1 && vBlocks != 2)
103 return op.emitOpError(
"expecting v_blocks to be 1 or 2");
104 if (tileWidth * vBlocks > 16)
105 return op.emitOpError(
106 "tile_width * v_blocks should be less than or equal "
107 "to 16 for 32 bit elements");
// Presumably 64-bit elements: tile_width in [1, 8] and v_blocks == 1.
110 if (tileWidth < 1 || tileWidth > 8)
111 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
113 return op.emitOpError(
"expecting v_blocks to be 1");
// Any other element size is rejected.
116 return op.emitOpError(
117 "expecting elem_size_in_bits to be 8, 16, 32, or 64");
// Transpose load: v_blocks must be 1 and only 32- and 64-bit elements are
// accepted (per the diagnostics below).
123 if (op.getTranspose()) {
// Combining both flags was rejected above, so this is an internal invariant.
124 assert(!op.getPackRegister() &&
"Expecting pack_register should be false");
126 uint32_t vBlocks = op.getVBlocks();
128 return op.emitOpError(
"expecting v_blocks to be 1");
130 uint32_t tileHeight = op.getTileHeight();
131 uint32_t tileWidth = op.getTileWidth();
132 switch (op.getElemSizeInBits()) {
// Presumably 32-bit elements.
134 if (tileHeight < 1 || tileHeight > 32)
135 return op.emitOpError(
"expecting tile_height to be between 1 and 32");
136 if (tileWidth < 1 || tileWidth > 8)
137 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
// 64-bit elements: tile_height must be 8 and tile_width 1, 2 or 4.
141 return op.emitOpError(
142 "expecting tile_height to be 8 for 64 bit elements");
143 if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
144 return op.emitOpError(
"expecting tile_width to be 1, 2, or 4");
// Any other element size cannot be transposed.
147 return op.emitOpError(
"transpose is only supported for 32 and 64 bit "
// Remaining case: pack_register set, transpose clear; only 8- and 16-bit
// elements are accepted (per the final diagnostic).
154 assert(op.getPackRegister() && !op.getTranspose() &&
155 "Expecting pack_register should be true and transpose should be "
158 uint32_t vBlocks = op.getVBlocks();
159 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
160 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
162 uint32_t tileHeight = op.getTileHeight();
163 uint32_t tileWidth = op.getTileWidth();
164 switch (op.getElemSizeInBits()) {
// Presumably 8-bit elements.
166 if (tileHeight < 4 || tileHeight > 32)
167 return op.emitOpError(
"expecting tile_height to be between 4 and 32");
168 if (tileWidth < 4 || tileWidth > 16)
169 return op.emitOpError(
"expecting tile_width to be between 4 and 16");
// 16-bit elements (per the diagnostic below).
172 if (tileHeight < 2 || tileHeight > 32)
173 return op.emitOpError(
"expecting tile_height to be between 2 and 32");
174 if (tileWidth < 2 || tileWidth > 16)
175 return op.emitOpError(
"expecting tile_width to be between 2 and 16");
176 if (tileWidth * vBlocks > 32)
177 return op.emitOpError(
178 "tile_width * v_blocks should be less than or equal "
179 "to 32 for 16 bit elements");
// Any other element size cannot use pack_register.
182 return op.emitOpError(
"pack_register is only supported for 8 and 16 bit "
// Shape restrictions for BlockStore2dOp:
//  - tile_height in [1, 8];
//  - a tile_width range keyed on elem_size_in_bits (8/16/32/64 bit);
//  - v_blocks == 1 (per the final diagnostic).
// Each failure returns an op error.
// Assuming the four groups below are 8/16/32/64-bit in that order (matching
// the default diagnostic), each tile_width upper bound is 512 bits per row.
// NOTE(review): BlockStore2dOp::verify also runs verifyMatrixInput, which
// limits elem_size_in_bits to 8/16/32, so the 64-bit range below appears
// unable to admit a valid op - confirm intent.
189static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
190 uint32_t tileHeight = op.getTileHeight();
191 if (tileHeight < 1 || tileHeight > 8)
192 return op.emitOpError(
"expecting tile_height to be between 1 and 8");
194 uint32_t tileWidth = op.getTileWidth();
195 switch (op.getElemSizeInBits()) {
// Presumably 8-bit elements.
197 if (tileWidth < 4 || tileWidth > 64)
198 return op.emitOpError(
"expecting tile_width to be between 4 and 64");
// Presumably 16-bit elements.
201 if (tileWidth < 2 || tileWidth > 32)
202 return op.emitOpError(
"expecting tile_width to be between 2 and 32");
// Presumably 32-bit elements.
205 if (tileWidth < 1 || tileWidth > 16)
206 return op.emitOpError(
"expecting tile_width to be between 1 and 16");
// Presumably 64-bit elements.
209 if (tileWidth < 1 || tileWidth > 8)
210 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
// Any other element size is rejected.
213 return op.emitOpError(
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
216 uint32_t vBlocks = op.getVBlocks();
218 return op.emitOpError(
"expecting v_blocks to be 1");
// Verifier for BlockLoad2dOp. Runs verify2DBlockLoadRestriction, then the
// shared verifyMatrixInput checks, then result-type rules:
//  - the result element type must be int or float;
//  - with 32-bit elements or pack_register, result elements must be 32 bits;
//  - with pack_register, tile_width must equal the subgroup size
//    (16 elements).
224LogicalResult BlockLoad2dOp::verify() {
225 if (verify2DBlockLoadRestriction(*this).failed())
228 if (verifyMatrixInput(*this).failed())
231 VectorType resTy = getRes().getType();
// NOTE(review): verify2DBlockLoadRestriction (run first) already rejects a
// non-int/float result element type, so this check looks redundant. Its
// message also reads "int of float", presumably a typo for "int or float".
232 if (!resTy.getElementType().isIntOrFloat())
233 return emitOpError() <<
"expecting result element type to be int of float";
234 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
// 32-bit element loads and packed loads both produce 32-bit result elements.
235 if (getElemSizeInBits() == 32 || getPackRegister()) {
236 if (resElemTySize != 32)
237 return emitOpError() <<
"expecting result element type to be 32 bits";
240 uint32_t tileWidth = getTileWidth();
241 if (getPackRegister()) {
244 "tile_width when pack_register is true should be equal "
245 "to subgroup size (16 elements)");
// Verifier for BlockStore2dOp. Runs verify2DBlockStoreRestriction and the
// shared verifyMatrixInput checks, then constrains tile_width per element
// size (for 8-bit elements it must be 16 or 32).
252LogicalResult BlockStore2dOp::verify() {
253 if (verify2DBlockStoreRestriction(*this).failed())
256 if (verifyMatrixInput(*this).failed())
259 uint32_t tileWidth = getTileWidth();
260 switch (getElemSizeInBits()) {
262 if (tileWidth != 16 && tileWidth != 32)
263 return emitOpError(
"tile_width for 8 bit elements should be equal to "
268 return emitOpError(
"tile_width for 16 bit elements should be equal "
273 return emitOpError(
"tile_width for 32 bit elements should be equal "
// verifyMatrixInput only admits 8/16/32-bit elements, so any other size is
// treated as unreachable here (assuming its failure returned early above).
277 llvm_unreachable(
"unexpected element size");
// Verifier for BlockPrefetch2dOp. Runs the shared verifyMatrixInput checks,
// then constrains tile_width per element size:
//  - 8-bit elements: 16 or 32;
//  - 32-bit elements: 8 or 16.
283LogicalResult BlockPrefetch2dOp::verify() {
284 if (verifyMatrixInput(*this).failed())
287 uint32_t tileWidth = getTileWidth();
288 switch (getElemSizeInBits()) {
290 if (tileWidth != 16 && tileWidth != 32)
291 return emitOpError(
"tile_width for 8 bit elements should be equal to "
296 return emitOpError(
"tile_width for 16 bit elements should be equal "
300 if (tileWidth != 8 && tileWidth != 16)
302 "tile_width for 32 bit elements should be equal to 8 or 16");
// verifyMatrixInput only admits 8/16/32-bit elements, so any other size is
// treated as unreachable here (assuming its failure returned early above).
305 llvm_unreachable(
"unexpected element size");
// Checks the vector value of a 1D block op: the result of a BlockLoadOp or
// the stored value of a BlockStoreOp (the enable_if limits OpType to these).
// With 8-bit elements the vector must have 2, 4, 8 or 16 elements; with
// wider elements it must have 2, 4 or 8.
// NOTE(review): elemTySize is the element width in whole bytes, so element
// types narrower than 8 bits give 0 and presumably take the "> 8 bits" path.
// Confirm that this is intended.
311template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
312 OpType, BlockLoadOp, BlockStoreOp>::value>>
// Loads validate their result type; stores validate the value being stored.
315 if constexpr (std::is_same_v<OpType, BlockLoadOp>)
316 srcOrDstTy = op.getResult().getType();
318 srcOrDstTy = op.getVal().getType();
319 VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
323 int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
324 if (elemTySize == 1) {
325 llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
326 if (validSizes.contains(vTy.getNumElements()))
329 return op.emitOpError(
330 "vector size must be 2, 4, 8 or 16 for 8-bit element type");
332 llvm::SmallSet<int, 3> validSizes{2, 4, 8};
333 if (validSizes.contains(vTy.getNumElements()))
336 return op.emitOpError(
337 "vector size must be 2, 4 or 8 for element type > 8 bits");
// Verifier for MMAOp: rejects an op whose C operand type differs from its
// result type (per the diagnostic below).
345LogicalResult MMAOp::verify() {
348 return emitOpError(
"type of C operand must match result type");
// Verifier for MMAMxOp: like MMAOp, rejects an op whose C operand type
// differs from its result type (per the diagnostic below).
353LogicalResult MMAMxOp::verify() {
356 return emitOpError(
"type of C operand must match result type");
// Verifier for TruncfOp:
//  - src and dst must be either both vectors or both non-vectors;
//  - per the final diagnostic, dst elements must be narrower than src
//    elements.
361LogicalResult TruncfOp::verify() {
362 Type srcTy = getSrc().getType();
363 Type dstTy = getDst().getType();
364 if (isa<VectorType>(srcTy) != isa<VectorType>(dstTy))
365 return emitOpError(
"both src and dst should be vector types or both should "
370 "dst element bitwidth should be less than src element bitwidth");
// Verifier for ExtfOp, the mirror of TruncfOp:
//  - src and dst must be either both vectors or both non-vectors;
//  - per the final diagnostic, dst elements must be wider than src elements.
374LogicalResult ExtfOp::verify() {
375 Type srcTy = getSrc().getType();
376 Type dstTy = getDst().getType();
377 if (isa<VectorType>(srcTy) != isa<VectorType>(dstTy))
378 return emitOpError(
"both src and dst should be vector types or both should "
383 "dst element bitwidth should be greater than src element bitwidth");
389 StringRef triple, StringRef chip, DictionaryAttr flags,
// NOTE(review): the head of this signature is missing here; it is presumably
// the XeVM target attribute verifier - confirm.
// Checks:
//  - the optimization level O is in [0, 3];
//  - triple and chip are non-empty;
//  - every StringAttr entry of linkFiles is a non-empty path to an existing
//    file.
391 if (O < 0 || O > 3) {
393 <<
"The optimization level must be a number between 0 and 3.";
395 if (triple.empty()) {
396 return emitError() <<
"The target triple cannot be empty.";
399 return emitError() <<
"The target chip cannot be empty.";
402 for (Attribute fileAttr : linkFiles) {
403 if (
auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
404 StringRef filePath = fileStrAttr.getValue();
405 if (filePath.empty()) {
406 return emitError() <<
"File paths in linkFiles cannot be empty.";
// This queries the filesystem, so the verdict depends on the files present
// when verification runs.
408 if (!llvm::sys::fs::exists(filePath)) {
409 return emitError() <<
"File '" << filePath <<
"' does not exist.";
// Registers the dialect's ops and attributes from the TableGen-generated
// .inc files. It also declares that an implementation of
// gpu::TargetAttrInterface for XeVMTargetAttr is promised, i.e. provided by a
// separately registered extension rather than in this function.
417void XeVMDialect::initialize() {
420#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
424#define GET_ATTRDEF_LIST
425#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
427 declarePromisedInterface<mlir::gpu::TargetAttrInterface,
428 mlir::xevm::XeVMTargetAttr>();
431#define GET_OP_CLASSES
432#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
434#define GET_ATTRDEF_CLASSES
435#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
LogicalResult verify1DBlockArg(OpType op)
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::function_ref< Fn > function_ref