12 #include "llvm/ADT/SmallSet.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 #include "llvm/Support/FileSystem.h"
15 #include "llvm/Support/MathExtras.h"
20 #include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
21 #include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
// Common checks shared by the 2D block ops. The static_assert below limits Op
// to BlockLoad2dOp, BlockStore2dOp and BlockPrefetch2dOp.
// NOTE(review): this listing omits some original lines (the embedded numbers
// jump, e.g. 30 -> 34), so the definitions of `pitch`/`width`, the head of the
// first `return`, and the end of the function are not visible here.
26 template <
typename Op>
27 LogicalResult verifyMatrixInput(
Op op) {
28 static_assert(llvm::is_one_of<
Op, BlockLoad2dOp, BlockStore2dOp,
29 BlockPrefetch2dOp>::value,
30 "Unexpected template parameter");
// The pitch/width comparison is only made when both values are present: each
// is tested for truthiness before being dereferenced.
34 if (pitch && width && *pitch < *width)
36 "4th operand (base pitch) should be >= 2nd operand (base width)");
// Only power-of-two element sizes in [8, 32] pass, i.e. 8, 16 or 32 bits.
38 uint32_t elemSize = op.getElemSizeInBits();
39 if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
40 return op->
emitOpError(
"expecting 'elem_size_in_bits' to be 8, 16, or 32");
// tile_height must be a power of two and at most 32.
42 uint32_t tileHeight = op.getTileHeight();
43 if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
44 return op->
emitOpError(
"expecting tile_height to be 1, 2, 4, 8, 16, or 32");
// v_blocks must be a power of two and at most 8.
46 uint32_t vBlocks = op.getVBlocks();
47 if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
48 return op->
emitOpError(
"expecting v_blocks to be 1, 2, 4, or 8");
// Checks a BlockLoad2dOp's result type and its tile_height / tile_width /
// v_blocks attributes. The visible code has three parts: a plain load (neither
// transpose nor pack_register), a transposed load, and a packed load.
// NOTE(review): this listing omits several original lines (switch `case`
// labels, some `if` conditions, closing braces, the final return). Where a
// branch's element size is stated below, it is taken from the error message
// inside that branch; otherwise it is marked as presumed.
53 LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
54 VectorType resTy = op.getRes().getType();
55 if (!resTy.getElementType().isIntOrFloat())
56 return op.emitOpError()
57 <<
"expecting result element type to be int or float";
// Total result size in bits = element count * element bit width.
58 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
59 unsigned resSize = resTy.getNumElements() * resElemTySize;
// The rest of this product is on an original line not present in the listing.
60 unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
62 if (resSize != expectedSize)
63 return op.emitOpError() <<
"result size of " << resSize
64 <<
" bits does not match the expected size of "
65 << expectedSize <<
" bits";
// transpose and pack_register may not both be set.
67 if (op.getTranspose() && op.getPackRegister())
68 return op.emitOpError(
"transpose and pack_register are mutually exclusive");
// Plain load: neither transpose nor pack_register.
70 if (!op.getTranspose() && !op.getPackRegister()) {
71 uint32_t tileHeight = op.getTileHeight();
72 if (tileHeight < 1 || tileHeight > 32)
73 return op.emitOpError(
"expecting tile_height to be between 1 and 32");
75 uint32_t tileWidth = op.getTileWidth();
76 uint32_t vBlocks = op.getVBlocks();
// Limits per element size; the `case` labels are not present in the listing.
77 switch (op.getElemSizeInBits()) {
// 8-bit branch (per the message below): width 4..64, v_blocks in {1,2,4},
// width * v_blocks <= 64.
79 if (tileWidth < 4 || tileWidth > 64)
80 return op.emitOpError(
"expecting tile_width to be between 4 and 64");
81 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
82 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
83 if (tileWidth * vBlocks > 64)
84 return op.emitOpError(
85 "tile_width * v_blocks should be less than or equal "
86 "to 64 for 8 bit elements");
// 16-bit branch (per the message below): width 2..32, v_blocks in {1,2,4},
// width * v_blocks <= 32.
89 if (tileWidth < 2 || tileWidth > 32)
90 return op.emitOpError(
"expecting tile_width to be between 2 and 32");
91 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
92 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
93 if (tileWidth * vBlocks > 32)
94 return op.emitOpError(
95 "tile_width * v_blocks should be less than or equal "
96 "to 32 for 16 bit elements");
// 32-bit branch (per the message below): width 1..16, v_blocks in {1,2},
// width * v_blocks <= 16.
99 if (tileWidth < 1 || tileWidth > 16)
100 return op.emitOpError(
"expecting tile_width to be between 1 and 16");
101 if (vBlocks != 1 && vBlocks != 2)
102 return op.emitOpError(
"expecting v_blocks to be 1 or 2");
103 if (tileWidth * vBlocks > 16)
104 return op.emitOpError(
105 "tile_width * v_blocks should be less than or equal "
106 "to 16 for 32 bit elements");
// Presumably the 64-bit branch (label and v_blocks condition not shown):
// width 1..8 and a v_blocks check that reports "to be 1".
109 if (tileWidth < 1 || tileWidth > 8)
110 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
112 return op.emitOpError(
"expecting v_blocks to be 1");
// Presumably the default branch: any other element size is rejected.
// NOTE(review): verifyMatrixInput in this file rejects elem_size_in_bits > 32,
// while this message lists 64 as valid -- confirm which is intended.
115 return op.emitOpError(
116 "expecting elem_size_in_bits to be 8, 16, 32, or 64");
// Transposed load.
122 if (op.getTranspose()) {
123 assert(!op.getPackRegister() &&
"Expecting pack_register should be false");
// The condition guarding this v_blocks error is not present in the listing.
125 uint32_t vBlocks = op.getVBlocks();
127 return op.emitOpError(
"expecting v_blocks to be 1");
129 uint32_t tileHeight = op.getTileHeight();
130 uint32_t tileWidth = op.getTileWidth();
131 switch (op.getElemSizeInBits()) {
// Presumably the 32-bit branch (label not shown): height 1..32, width 1..8.
133 if (tileHeight < 1 || tileHeight > 32)
134 return op.emitOpError(
"expecting tile_height to be between 1 and 32");
135 if (tileWidth < 1 || tileWidth > 8)
136 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
// 64-bit branch (per the message below): the height condition is not shown;
// width must be 1, 2 or 4.
140 return op.emitOpError(
141 "expecting tile_height to be 8 for 64 bit elements");
142 if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
143 return op.emitOpError(
"expecting tile_width to be 1, 2, or 4");
// Other element sizes are rejected; the rest of this message is on a line not
// present in the listing.
146 return op.emitOpError(
"transpose is only supported for 32 and 64 bit "
// Packed load: the assert documents that pack_register is set and transpose
// is not at this point.
153 assert(op.getPackRegister() && !op.getTranspose() &&
154 "Expecting pack_register should be true and transpose should be "
157 uint32_t vBlocks = op.getVBlocks();
158 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
159 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
161 uint32_t tileHeight = op.getTileHeight();
162 uint32_t tileWidth = op.getTileWidth();
163 switch (op.getElemSizeInBits()) {
// Presumably the 8-bit branch (label not shown): height 4..32, width 4..16.
165 if (tileHeight < 4 || tileHeight > 32)
166 return op.emitOpError(
"expecting tile_height to be between 4 and 32");
167 if (tileWidth < 4 || tileWidth > 16)
168 return op.emitOpError(
"expecting tile_width to be between 4 and 16");
// 16-bit branch (per the message below): height 2..32, width 2..16,
// width * v_blocks <= 32.
171 if (tileHeight < 2 || tileHeight > 32)
172 return op.emitOpError(
"expecting tile_height to be between 2 and 32");
173 if (tileWidth < 2 || tileWidth > 16)
174 return op.emitOpError(
"expecting tile_width to be between 2 and 16");
175 if (tileWidth * vBlocks > 32)
176 return op.emitOpError(
177 "tile_width * v_blocks should be less than or equal "
178 "to 32 for 16 bit elements");
// Other element sizes are rejected; the rest of this message is on a line not
// present in the listing.
181 return op.emitOpError(
"pack_register is only supported for 8 and 16 bit "
// Checks a BlockStore2dOp's tile_height, tile_width and v_blocks attributes.
// NOTE(review): this listing omits the switch `case` labels, the condition
// guarding the v_blocks error, and the end of the function.
188 static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
// tile_height must be in 1..8.
189 uint32_t tileHeight = op.getTileHeight();
190 if (tileHeight < 1 || tileHeight > 8)
191 return op.emitOpError(
"expecting tile_height to be between 1 and 8");
// tile_width range depends on the element size. The four visible ranges are
// 4..64, 2..32, 1..16 and 1..8; they presumably belong to the 8, 16, 32 and
// 64-bit cases in that order (labels not shown -- confirm).
193 uint32_t tileWidth = op.getTileWidth();
194 switch (op.getElemSizeInBits()) {
196 if (tileWidth < 4 || tileWidth > 64)
197 return op.emitOpError(
"expecting tile_width to be between 4 and 64");
200 if (tileWidth < 2 || tileWidth > 32)
201 return op.emitOpError(
"expecting tile_width to be between 2 and 32");
204 if (tileWidth < 1 || tileWidth > 16)
205 return op.emitOpError(
"expecting tile_width to be between 1 and 16");
208 if (tileWidth < 1 || tileWidth > 8)
209 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
// Presumably the default branch: any other element size is rejected.
212 return op.emitOpError(
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
// The condition guarding this error is on a line not present in the listing.
215 uint32_t vBlocks = op.getVBlocks();
217 return op.emitOpError(
"expecting v_blocks to be 1");
// Body fragment of an op verifier (it uses `*this`; presumably
// BlockLoad2dOp::verify -- the signature line is not present in this listing).
// It runs the load-specific restriction check, then the shared matrix-input
// check, then checks the result element type and tile_width.
// NOTE(review): the statements that follow each `.failed()` test, and the
// condition guarding the final tile_width message, are on omitted lines.
224 if (verify2DBlockLoadRestriction(*this).failed())
227 if (verifyMatrixInput(*this).failed())
// NOTE(review): verify2DBlockLoadRestriction makes the same isIntOrFloat()
// check earlier, so this branch looks redundant -- confirm before removing.
// The message previously read "int of float"; it now matches the wording used
// by verify2DBlockLoadRestriction.
230 VectorType resTy = getRes().getType();
231 if (!resTy.getElementType().isIntOrFloat())
232 return emitOpError() <<
"expecting result element type to be int or float";
// For 32-bit elements, or when pack_register is set, the result element type
// must be 32 bits wide.
233 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
234 if (getElemSizeInBits() == 32 || getPackRegister()) {
235 if (resElemTySize != 32)
236 return emitOpError() <<
"expecting result element type to be 32 bits";
// With pack_register set, tile_width is tied to the subgroup size (16
// elements, per the message below).
239 uint32_t tileWidth = getTileWidth();
240 if (getPackRegister()) {
243 "tile_width when pack_register is true should be equal "
244 "to subgroup size (16 elements)");
// Body fragment of an op verifier (it uses `*this`; presumably
// BlockStore2dOp::verify -- the signature line is not present in this listing).
// It runs the store-specific restriction check, then the shared matrix-input
// check, then a tile_width check per element size.
// NOTE(review): `case` labels, most conditions and the tails of the messages
// are on omitted lines.
252 if (verify2DBlockStoreRestriction(*this).failed())
255 if (verifyMatrixInput(*this).failed())
258 uint32_t tileWidth = getTileWidth();
259 switch (getElemSizeInBits()) {
// 8-bit elements (per the message below): tile_width must be 16 or 32.
261 if (tileWidth != 16 && tileWidth != 32)
262 return emitOpError(
"tile_width for 8 bit elements should be equal to "
267 return emitOpError(
"tile_width for 16 bit elements should be equal "
272 return emitOpError(
"tile_width for 32 bit elements should be equal "
// Any other element size is treated as unreachable here.
276 llvm_unreachable(
"unexpected element size");
// Body fragment of an op verifier (it uses `*this`; presumably
// BlockPrefetch2dOp::verify -- the signature line is not present in this
// listing). It runs the shared matrix-input check, then a tile_width check per
// element size.
// NOTE(review): `case` labels, some conditions and `return` heads are on
// omitted lines.
283 if (verifyMatrixInput(*this).failed())
286 uint32_t tileWidth = getTileWidth();
287 switch (getElemSizeInBits()) {
// 8-bit elements (per the message below): tile_width must be 16 or 32.
289 if (tileWidth != 16 && tileWidth != 32)
290 return emitOpError(
"tile_width for 8 bit elements should be equal to "
295 return emitOpError(
"tile_width for 16 bit elements should be equal "
// 32-bit elements: tile_width must be 8 or 16.
299 if (tileWidth != 8 && tileWidth != 16)
301 "tile_width for 32 bit elements should be equal to 8 or 16");
// Any other element size is treated as unreachable here.
304 llvm_unreachable(
"unexpected element size");
// Template restricted (via enable_if) to BlockLoadOp and BlockStoreOp. It
// checks the vector type of the loaded result or the stored value.
// NOTE(review): the signature line, the declaration of `srcOrDstTy`, the
// handling of a failed dyn_cast, and the success returns are on lines not
// present in this listing.
310 template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
311 OpType, BlockLoadOp, BlockStoreOp>::value>>
// A load is checked by its result type; otherwise the stored value's type.
314 if constexpr (std::is_same_v<OpType, BlockLoadOp>)
315 srcOrDstTy = op.getResult().getType();
317 srcOrDstTy = op.getVal().getType();
318 VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
// Element size in whole bytes; integer division makes widths below 8 bits
// yield 0, which takes the "> 8 bits" branch below.
322 int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
323 if (elemTySize == 1) {
// One-byte elements: the vector must have 2, 4, 8 or 16 elements.
324 llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
325 if (validSizes.contains(vTy.getNumElements()))
328 return op.emitOpError(
329 "vector size must be 2, 4, 8 or 16 for 8-bit element type");
// Other element sizes: the vector must have 2, 4 or 8 elements.
331 llvm::SmallSet<int, 3> validSizes{2, 4, 8};
332 if (validSizes.contains(vTy.getNumElements()))
335 return op.emitOpError(
336 "vector size must be 2, 4 or 8 for element type > 8 bits");
// Lone statement from an op verifier whose signature and guarding condition
// are not present in this listing. Per the message, it reports a mismatch
// between the C operand's type and the result type.
347 return emitOpError(
"type of C operand must match result type");
// Fragment of an attribute verifier (presumably XeVMTargetAttr::verify -- the
// first line of the signature is not present in this listing). It checks the
// optimization level `O`, the triple, the chip and the linkFiles entries.
// NOTE(review): the chip condition, the loop over linkFiles, the handling of
// non-string entries and the final return are on omitted lines.
354 StringRef triple, StringRef chip, DictionaryAttr flags,
355 ArrayAttr linkFiles) {
// The optimization level must lie in 0..3.
356 if (O < 0 || O > 3) {
358 <<
"The optimization level must be a number between 0 and 3.";
360 if (triple.empty()) {
361 return emitError() <<
"The target triple cannot be empty.";
364 return emitError() <<
"The target chip cannot be empty.";
// Entries that are StringAttr are treated as file paths.
368 if (
auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
369 StringRef filePath = fileStrAttr.getValue();
370 if (filePath.empty()) {
371 return emitError() <<
"File paths in linkFiles cannot be empty.";
// This queries the filesystem, so the outcome depends on which files exist
// where the verifier runs.
373 if (!llvm::sys::fs::exists(filePath)) {
374 return emitError() <<
"File '" << filePath <<
"' does not exist.";
// Dialect initialization. The visible code pulls in the generated op and
// attribute .inc files and declares gpu::TargetAttrInterface as a promised
// interface for xevm::XeVMTargetAttr.
// NOTE(review): the registration calls that wrap the two includes, the macro
// define before the first include, and the closing brace are on lines not
// present in this listing.
382 void XeVMDialect::initialize() {
385 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
389 #define GET_ATTRDEF_LIST
390 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
392 declarePromisedInterface<mlir::gpu::TargetAttrInterface,
393 mlir::xevm::XeVMTargetAttr>();
396 #define GET_OP_CLASSES
397 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
399 #define GET_ATTRDEF_CLASSES
400 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
LogicalResult verify1DBlockArg(OpType op)
Attributes are known-constant values of operations.
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
constexpr unsigned subgroupSize
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.