12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/FileSystem.h"
14 #include "llvm/Support/MathExtras.h"
19 #include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
20 #include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
// Common operand/attribute checks shared by the 2D block load, store and
// prefetch ops: base pitch vs. base width, element size, tile height and
// v_blocks. Each failed check reports an op error.
// NOTE(review): this listing has embedded line numbers and is missing lines;
// the declarations of `pitch`/`width`, the first `return`, and the tail of
// the function are not visible here.
// NOTE(review): element sizes above 32 are rejected below, while
// verify2DBlockLoadRestriction / verify2DBlockStoreRestriction carry an
// "8, 16, 32, or 64" message — confirm whether 64 is meant to be accepted.
25 template <
typename Op>
26 LogicalResult verifyMatrixInput(
Op op) {
// Compile-time guard: only the three 2D block ops may instantiate this helper.
27 static_assert(llvm::is_one_of<
Op, BlockLoad2dOp, BlockStore2dOp,
28 BlockPrefetch2dOp>::value,
29 "Unexpected template parameter");
// The pitch/width comparison only happens when both values are available
// (both are tested for truthiness and then dereferenced).
33 if (pitch && width && *pitch < *width)
35 "4th operand (base pitch) should be >= 2nd operand (base width)");
// Element size: a power of two in the range [8, 32].
37 uint32_t elemSize = op.getElemSizeInBits();
38 if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
39 return op->
emitOpError(
"expecting 'elem_size_in_bits' to be 8, 16, or 32");
// Tile height: a power of two no greater than 32.
41 uint32_t tileHeight = op.getTileHeight();
42 if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
43 return op->
emitOpError(
"expecting tile_height to be 1, 2, 4, 8, 16, or 32");
// v_blocks: a power of two no greater than 8.
45 uint32_t vBlocks = op.getVBlocks();
46 if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
47 return op->
emitOpError(
"expecting v_blocks to be 1, 2, 4, or 8");
// Validates a BlockLoad2dOp: the result vector type/size, and the allowed
// tile_height / tile_width / v_blocks combinations for each of the three
// modes (plain, transpose, pack_register) and each element size.
// NOTE(review): unlike verify2DBlockStoreRestriction this helper is not
// declared `static` — confirm whether external linkage is intended.
// NOTE(review): the `case` labels, `break`s and a few conditions of the
// switches below are not visible in this listing; per-case comments are
// derived from the error messages where those name the element size.
52 LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
// The result element type must be an integer or float so that its bit width
// can be queried below.
53 VectorType resTy = op.getRes().getType();
54 if (!resTy.getElementType().isIntOrFloat())
55 return op.emitOpError()
56 <<
"expecting result element type to be int or float";
// Total result size in bits = element count * element bit width.
57 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
58 unsigned resSize = resTy.getNumElements() * resElemTySize;
// Expected size is a product starting with elem_size_in_bits * tile_height;
// the remaining factors are on a line not visible here.
59 unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
61 if (resSize != expectedSize)
62 return op.emitOpError() <<
"result size of " << resSize
63 <<
" bits does not match the expected size of "
64 << expectedSize <<
" bits";
// The two layout-changing modes cannot be combined.
66 if (op.getTranspose() && op.getPackRegister())
67 return op.emitOpError(
"transpose and pack_register are mutually exclusive");
// Mode 1: plain load (neither transpose nor pack_register).
69 if (!op.getTranspose() && !op.getPackRegister()) {
70 uint32_t tileHeight = op.getTileHeight();
71 if (tileHeight < 1 || tileHeight > 32)
72 return op.emitOpError(
"expecting tile_height to be between 1 and 32");
74 uint32_t tileWidth = op.getTileWidth();
75 uint32_t vBlocks = op.getVBlocks();
76 switch (op.getElemSizeInBits()) {
// Limits for 8 bit elements (per the message below).
78 if (tileWidth < 4 || tileWidth > 64)
79 return op.emitOpError(
"expecting tile_width to be between 4 and 64");
80 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
81 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
82 if (tileWidth * vBlocks > 64)
83 return op.emitOpError(
84 "tile_width * v_blocks should be less than or equal "
85 "to 64 for 8 bit elements");
// Limits for 16 bit elements (per the message below).
88 if (tileWidth < 2 || tileWidth > 32)
89 return op.emitOpError(
"expecting tile_width to be between 2 and 32");
90 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
91 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
92 if (tileWidth * vBlocks > 32)
93 return op.emitOpError(
94 "tile_width * v_blocks should be less than or equal "
95 "to 32 for 16 bit elements");
// Limits for 32 bit elements (per the message below).
98 if (tileWidth < 1 || tileWidth > 16)
99 return op.emitOpError(
"expecting tile_width to be between 1 and 16");
100 if (vBlocks != 1 && vBlocks != 2)
101 return op.emitOpError(
"expecting v_blocks to be 1 or 2");
102 if (tileWidth * vBlocks > 16)
103 return op.emitOpError(
104 "tile_width * v_blocks should be less than or equal "
105 "to 16 for 32 bit elements");
// Presumably the 64 bit case (the label and the v_blocks condition are not
// visible) — TODO confirm.
108 if (tileWidth < 1 || tileWidth > 8)
109 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
111 return op.emitOpError(
"expecting v_blocks to be 1");
// Any other element size is rejected.
114 return op.emitOpError(
115 "expecting elem_size_in_bits to be 8, 16, 32, or 64");
// Mode 2: transposed load.
121 if (op.getTranspose()) {
// Holds because the mutually-exclusive check above has already returned
// when both flags are set.
122 assert(!op.getPackRegister() &&
"Expecting pack_register should be false");
// The v_blocks condition guarding this error is not visible in this listing.
124 uint32_t vBlocks = op.getVBlocks();
126 return op.emitOpError(
"expecting v_blocks to be 1");
128 uint32_t tileHeight = op.getTileHeight();
129 uint32_t tileWidth = op.getTileWidth();
130 switch (op.getElemSizeInBits()) {
// Presumably the 32 bit case (label not visible) — TODO confirm.
132 if (tileHeight < 1 || tileHeight > 32)
133 return op.emitOpError(
"expecting tile_height to be between 1 and 32");
134 if (tileWidth < 1 || tileWidth > 8)
135 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
// 64 bit elements (per the message below); the tile_height condition is not
// visible.
139 return op.emitOpError(
140 "expecting tile_height to be 8 for 64 bit elements");
141 if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
142 return op.emitOpError(
"expecting tile_width to be 1, 2, or 4");
// Other element sizes are not supported with transpose (the message
// continues on a line not visible here).
145 return op.emitOpError(
"transpose is only supported for 32 and 64 bit "
// Mode 3: pack_register load; the assert restates what the two branches
// above imply (its message continues on a line not visible here).
152 assert(op.getPackRegister() && !op.getTranspose() &&
153 "Expecting pack_register should be true and transpose should be "
156 uint32_t vBlocks = op.getVBlocks();
157 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
158 return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
160 uint32_t tileHeight = op.getTileHeight();
161 uint32_t tileWidth = op.getTileWidth();
162 switch (op.getElemSizeInBits()) {
// Presumably the 8 bit case (label not visible) — TODO confirm.
164 if (tileHeight < 4 || tileHeight > 32)
165 return op.emitOpError(
"expecting tile_height to be between 4 and 32");
166 if (tileWidth < 4 || tileWidth > 16)
167 return op.emitOpError(
"expecting tile_width to be between 4 and 16");
// Limits for 16 bit elements (per the last message of this group).
170 if (tileHeight < 2 || tileHeight > 32)
171 return op.emitOpError(
"expecting tile_height to be between 2 and 32");
172 if (tileWidth < 2 || tileWidth > 16)
173 return op.emitOpError(
"expecting tile_width to be between 2 and 16");
174 if (tileWidth * vBlocks > 32)
175 return op.emitOpError(
176 "tile_width * v_blocks should be less than or equal "
177 "to 32 for 16 bit elements");
// Other element sizes are not supported with pack_register (the message
// continues on a line not visible here).
180 return op.emitOpError(
"pack_register is only supported for 8 and 16 bit "
// Validates a BlockStore2dOp: tile_height range, per-element-size tile_width
// range, and v_blocks.
// NOTE(review): the `case` labels of the switch and the v_blocks condition
// are not visible in this listing; which element size each tile_width range
// belongs to is not shown here (the ranges match the plain-load ranges in
// verify2DBlockLoadRestriction for 8/16/32/64 bits, in that order — TODO
// confirm).
187 static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
// Stores accept tile heights from 1 to 8.
188 uint32_t tileHeight = op.getTileHeight();
189 if (tileHeight < 1 || tileHeight > 8)
190 return op.emitOpError(
"expecting tile_height to be between 1 and 8");
// Allowed tile_width range depends on the element size.
192 uint32_t tileWidth = op.getTileWidth();
193 switch (op.getElemSizeInBits()) {
195 if (tileWidth < 4 || tileWidth > 64)
196 return op.emitOpError(
"expecting tile_width to be between 4 and 64");
199 if (tileWidth < 2 || tileWidth > 32)
200 return op.emitOpError(
"expecting tile_width to be between 2 and 32");
203 if (tileWidth < 1 || tileWidth > 16)
204 return op.emitOpError(
"expecting tile_width to be between 1 and 16");
207 if (tileWidth < 1 || tileWidth > 8)
208 return op.emitOpError(
"expecting tile_width to be between 1 and 8");
// Any other element size is rejected.
211 return op.emitOpError(
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
// The condition guarding this error is not visible in this listing.
214 uint32_t vBlocks = op.getVBlocks();
216 return op.emitOpError(
"expecting v_blocks to be 1");
// Body fragment of an op verifier; the function header and several lines are
// not visible in this listing. Since `*this` is passed to
// verify2DBlockLoadRestriction, the op is a BlockLoad2dOp — presumably this
// is BlockLoad2dOp::verify(); TODO confirm.
// Runs the load-specific restrictions first, then the shared matrix checks.
223 if (verify2DBlockLoadRestriction(*this).failed())
226 if (verifyMatrixInput(*this).failed())
// NOTE(review): verify2DBlockLoadRestriction performs the same isIntOrFloat
// check, so this one looks redundant if the call above returns on failure —
// confirm. Also, the message below reads "int of float"; the equivalent
// message in verify2DBlockLoadRestriction says "int or float" (likely typo).
229 VectorType resTy = getRes().getType();
230 if (!resTy.getElementType().isIntOrFloat())
231 return emitOpError() <<
"expecting result element type to be int of float";
// For 32 bit elements, or when pack_register is set, the result element type
// must be 32 bits wide.
232 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
233 if (getElemSizeInBits() == 32 || getPackRegister()) {
234 if (resElemTySize != 32)
235 return emitOpError() <<
"expecting result element type to be 32 bits";
// With pack_register, tile_width is tied to the subgroup size (16 per the
// message); the actual comparison is on lines not visible here.
238 uint32_t tileWidth = getTileWidth();
239 if (getPackRegister()) {
242 "tile_width when pack_register is true should be equal "
243 "to subgroup size (16 elements)");
// Body fragment of an op verifier; the function header and several lines are
// not visible in this listing. Since `*this` is passed to
// verify2DBlockStoreRestriction, the op is a BlockStore2dOp — presumably this
// is BlockStore2dOp::verify(); TODO confirm.
// Runs the store-specific restrictions first, then the shared matrix checks.
251 if (verify2DBlockStoreRestriction(*this).failed())
254 if (verifyMatrixInput(*this).failed())
// Per-element-size tile_width check; case labels, the 16/32 bit conditions
// and the ends of the messages are not visible here.
257 uint32_t tileWidth = getTileWidth();
258 switch (getElemSizeInBits()) {
// 8 bit elements: tile_width must be 16 or 32.
260 if (tileWidth != 16 && tileWidth != 32)
261 return emitOpError(
"tile_width for 8 bit elements should be equal to "
266 return emitOpError(
"tile_width for 16 bit elements should be equal "
271 return emitOpError(
"tile_width for 32 bit elements should be equal "
// Other element sizes are treated as impossible here — presumably because
// verifyMatrixInput above only admits 8, 16, or 32; TODO confirm.
275 llvm_unreachable(
"unexpected element size");
// Body fragment of an op verifier; the function header and several lines are
// not visible in this listing. Presumably BlockPrefetch2dOp::verify(), the
// remaining op type accepted by verifyMatrixInput — TODO confirm.
// Only the shared matrix checks run here; no op-specific restriction helper
// is called in the visible lines.
282 if (verifyMatrixInput(*this).failed())
// Per-element-size tile_width check; case labels and the 16 bit condition
// are not visible here.
285 uint32_t tileWidth = getTileWidth();
286 switch (getElemSizeInBits()) {
// 8 bit elements: tile_width must be 16 or 32.
288 if (tileWidth != 16 && tileWidth != 32)
289 return emitOpError(
"tile_width for 8 bit elements should be equal to "
294 return emitOpError(
"tile_width for 16 bit elements should be equal "
// 32 bit elements: tile_width must be 8 or 16.
298 if (tileWidth != 8 && tileWidth != 16)
300 "tile_width for 32 bit elements should be equal to 8 or 16");
// Other element sizes are treated as impossible here — presumably because
// verifyMatrixInput above only admits 8, 16, or 32; TODO confirm.
303 llvm_unreachable(
"unexpected element size");
// Tail fragment of another op verifier (header and condition not visible in
// this listing): reports a mismatch between the type of the "C" operand and
// the result type. NOTE(review): which op this belongs to cannot be told from
// the visible lines — TODO confirm.
312 return emitOpError(
"type of C operand must match result type");
319 StringRef triple, StringRef chip, DictionaryAttr flags,
320 ArrayAttr linkFiles) {
// Fragment of an attribute verifier: the start of the signature (including
// the declarations of `emitError` and `O`), the loop over `linkFiles`, and
// the closing lines are not visible in this listing. Presumably the target
// attribute's verify() — TODO confirm.
// Optimization level must lie in [0, 3].
321 if (O < 0 || O > 3) {
323 <<
"The optimization level must be a number between 0 and 3.";
// Target triple is mandatory.
325 if (triple.empty()) {
326 return emitError() <<
"The target triple cannot be empty.";
// Target chip is mandatory (the emptiness test itself is not visible here).
329 return emitError() <<
"The target chip cannot be empty.";
// Only entries that are string attributes are inspected in the visible code.
333 if (
auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
334 StringRef filePath = fileStrAttr.getValue();
335 if (filePath.empty()) {
336 return emitError() <<
"File paths in linkFiles cannot be empty.";
// NOTE(review): this queries the filesystem during verification, so the
// result depends on the machine/working directory where the IR is verified.
338 if (!llvm::sys::fs::exists(filePath)) {
339 return emitError() <<
"File '" << filePath <<
"' does not exist.";
// Dialect initialization hook. The visible lines pull in the generated op
// and attribute definition files and declare a promised interface.
// NOTE(review): the registration calls wrapping the two includes and the
// closing brace are not visible in this listing.
347 void XeVMDialect::initialize() {
// Generated operation file (the macro selecting what it expands to is not
// visible here).
350 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
// GET_ATTRDEF_LIST selects the attribute list section of the generated file.
354 #define GET_ATTRDEF_LIST
355 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
// Declares that XeVMTargetAttr is promised to implement the GPU
// TargetAttrInterface.
357 declarePromisedInterface<mlir::gpu::TargetAttrInterface,
358 mlir::xevm::XeVMTargetAttr>();
361 #define GET_OP_CLASSES
362 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
364 #define GET_ATTRDEF_CLASSES
365 #include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
Attributes are known-constant values of operations.
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
constexpr unsigned subgroupSize
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...