MLIR 23.0.0git
XeVMDialect.cpp
Go to the documentation of this file.
1//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
13#include "llvm/ADT/SmallSet.h"
14#include "llvm/ADT/TypeSwitch.h"
15#include "llvm/Support/FileSystem.h"
16#include "llvm/Support/MathExtras.h"
17
18using namespace mlir;
19using namespace mlir::xevm;
20
21#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
22#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
23
namespace {
/// Subgroup size assumed by the 2D block verifiers. The expected result size
/// in verify2DBlockLoadRestriction is the total block size divided by this
/// value (presumably the share held by each work item — confirm).
static constexpr uint32_t subgroupSize = 16;

/// Checks shared by BlockLoad2dOp, BlockStore2dOp and BlockPrefetch2dOp:
///   - when base width and base pitch are both constants, pitch >= width;
///   - elem_size_in_bits is 8, 16 or 32;
///   - tile_height is a power of two no greater than 32;
///   - v_blocks is a power of two no greater than 8.
/// Emits an op error and returns failure on the first violated rule.
///
/// NOTE(review): verify2DBlockLoadRestriction and
/// verify2DBlockStoreRestriction accept 64-bit element sizes, but this helper
/// always rejects them. Since the op verifiers require both to pass, those
/// 64-bit branches can never lead to a successful verification. Confirm this
/// is intended.
template <typename Op>
LogicalResult verifyMatrixInput(Op op) {
  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
                                BlockPrefetch2dOp>::value,
                "Unexpected template parameter");

  // The pitch/width relation can only be checked when both operands fold to
  // constant integers; otherwise it is left unchecked.
  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
  if (pitch && width && *pitch < *width)
    return op->emitOpError(
        "4th operand (base pitch) should be >= 2nd operand (base width)");

  // Only 8, 16 and 32 pass: powers of two in the closed range [8, 32].
  uint32_t elemSize = op.getElemSizeInBits();
  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");

  // isPowerOf2_32(0) is false, so a zero tile height is rejected as well.
  uint32_t tileHeight = op.getTileHeight();
  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");

  uint32_t vBlocks = op.getVBlocks();
  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");

  return success();
}

/// Restrictions specific to BlockLoad2dOp.
///
/// The result vector must hold exactly
///   elem_size_in_bits * tile_height * tile_width * v_blocks / subgroupSize
/// bits. The tile and v_blocks limits are then checked for one of three
/// mutually exclusive modes: plain load (neither flag set), transpose, or
/// pack_register. Setting both transpose and pack_register is an error.
LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
  VectorType resTy = op.getRes().getType();
  // getIntOrFloatBitWidth() below requires an integer or float element type.
  if (!resTy.getElementType().isIntOrFloat())
    return op.emitOpError()
           << "expecting result element type to be int or float";
  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
  unsigned resSize = resTy.getNumElements() * resElemTySize;
  unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
                          op.getTileWidth() * op.getVBlocks() / subgroupSize;
  if (resSize != expectedSize)
    return op.emitOpError() << "result size of " << resSize
                            << " bits does not match the expected size of "
                            << expectedSize << " bits";

  if (op.getTranspose() && op.getPackRegister())
    return op.emitOpError("transpose and pack_register are mutually exclusive");

  // Plain load: neither transpose nor pack_register.
  if (!op.getTranspose() && !op.getPackRegister()) {
    uint32_t tileHeight = op.getTileHeight();
    if (tileHeight < 1 || tileHeight > 32)
      return op.emitOpError("expecting tile_height to be between 1 and 32");

    // Per element size: tile_width range, allowed v_blocks, and a cap on
    // tile_width * v_blocks.
    uint32_t tileWidth = op.getTileWidth();
    uint32_t vBlocks = op.getVBlocks();
    switch (op.getElemSizeInBits()) {
    case 8:
      if (tileWidth < 4 || tileWidth > 64)
        return op.emitOpError("expecting tile_width to be between 4 and 64");
      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
      if (tileWidth * vBlocks > 64)
        return op.emitOpError(
            "tile_width * v_blocks should be less than or equal "
            "to 64 for 8 bit elements");
      break;
    case 16:
      if (tileWidth < 2 || tileWidth > 32)
        return op.emitOpError("expecting tile_width to be between 2 and 32");
      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
      if (tileWidth * vBlocks > 32)
        return op.emitOpError(
            "tile_width * v_blocks should be less than or equal "
            "to 32 for 16 bit elements");
      break;
    case 32:
      if (tileWidth < 1 || tileWidth > 16)
        return op.emitOpError("expecting tile_width to be between 1 and 16");
      if (vBlocks != 1 && vBlocks != 2)
        return op.emitOpError("expecting v_blocks to be 1 or 2");
      if (tileWidth * vBlocks > 16)
        return op.emitOpError(
            "tile_width * v_blocks should be less than or equal "
            "to 16 for 32 bit elements");
      break;
    case 64:
      if (tileWidth < 1 || tileWidth > 8)
        return op.emitOpError("expecting tile_width to be between 1 and 8");
      if (vBlocks != 1)
        return op.emitOpError("expecting v_blocks to be 1");
      break;
    default:
      return op.emitOpError(
          "expecting elem_size_in_bits to be 8, 16, 32, or 64");
    }

    return success();
  }

  // Transpose mode: only 32- and 64-bit elements, always a single v_block.
  if (op.getTranspose()) {
    assert(!op.getPackRegister() && "Expecting pack_register should be false");

    uint32_t vBlocks = op.getVBlocks();
    if (vBlocks != 1)
      return op.emitOpError("expecting v_blocks to be 1");

    uint32_t tileHeight = op.getTileHeight();
    uint32_t tileWidth = op.getTileWidth();
    switch (op.getElemSizeInBits()) {
    case 32:
      if (tileHeight < 1 || tileHeight > 32)
        return op.emitOpError("expecting tile_height to be between 1 and 32");
      if (tileWidth < 1 || tileWidth > 8)
        return op.emitOpError("expecting tile_width to be between 1 and 8");
      break;
    case 64:
      if (tileHeight != 8)
        return op.emitOpError(
            "expecting tile_height to be 8 for 64 bit elements");
      if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
        return op.emitOpError("expecting tile_width to be 1, 2, or 4");
      break;
    default:
      return op.emitOpError("transpose is only supported for 32 and 64 bit "
                            "elements");
    }

    return success();
  }

  // pack_register mode: only 8- and 16-bit elements.
  assert(op.getPackRegister() && !op.getTranspose() &&
         "Expecting pack_register should be true and transpose should be "
         "false");

  uint32_t vBlocks = op.getVBlocks();
  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
    return op.emitOpError("expecting v_blocks to be 1, 2, or 4");

  uint32_t tileHeight = op.getTileHeight();
  uint32_t tileWidth = op.getTileWidth();
  switch (op.getElemSizeInBits()) {
  case 8:
    if (tileHeight < 4 || tileHeight > 32)
      return op.emitOpError("expecting tile_height to be between 4 and 32");
    if (tileWidth < 4 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 4 and 16");
    break;
  case 16:
    if (tileHeight < 2 || tileHeight > 32)
      return op.emitOpError("expecting tile_height to be between 2 and 32");
    if (tileWidth < 2 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 2 and 16");
    if (tileWidth * vBlocks > 32)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 32 for 16 bit elements");
    break;
  default:
    return op.emitOpError("pack_register is only supported for 8 and 16 bit "
                          "elements");
  }

  return success();
}

/// Restrictions specific to BlockStore2dOp:
///   - tile_height is between 1 and 8;
///   - tile_width lies in the range allowed for elem_size_in_bits
///     (8: 4..64, 16: 2..32, 32: 1..16, 64: 1..8);
///   - v_blocks is exactly 1.
/// Any other elem_size_in_bits is rejected.
static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
  uint32_t tileHeight = op.getTileHeight();
  if (tileHeight < 1 || tileHeight > 8)
    return op.emitOpError("expecting tile_height to be between 1 and 8");

  uint32_t tileWidth = op.getTileWidth();
  switch (op.getElemSizeInBits()) {
  case 8:
    if (tileWidth < 4 || tileWidth > 64)
      return op.emitOpError("expecting tile_width to be between 4 and 64");
    break;
  case 16:
    if (tileWidth < 2 || tileWidth > 32)
      return op.emitOpError("expecting tile_width to be between 2 and 32");
    break;
  case 32:
    if (tileWidth < 1 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 1 and 16");
    break;
  case 64:
    if (tileWidth < 1 || tileWidth > 8)
      return op.emitOpError("expecting tile_width to be between 1 and 8");
    break;
  default:
    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
  }

  uint32_t vBlocks = op.getVBlocks();
  if (vBlocks != 1)
    return op.emitOpError("expecting v_blocks to be 1");
  return success();
}

} // namespace
223
224LogicalResult BlockLoad2dOp::verify() {
225 if (verify2DBlockLoadRestriction(*this).failed())
226 return failure();
227
228 if (verifyMatrixInput(*this).failed())
229 return failure();
230
231 VectorType resTy = getRes().getType();
232 if (!resTy.getElementType().isIntOrFloat())
233 return emitOpError() << "expecting result element type to be int of float";
234 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
235 if (getElemSizeInBits() == 32 || getPackRegister()) {
236 if (resElemTySize != 32)
237 return emitOpError() << "expecting result element type to be 32 bits";
238 }
239
240 uint32_t tileWidth = getTileWidth();
241 if (getPackRegister()) {
242 if (tileWidth != 16)
243 return emitOpError(
244 "tile_width when pack_register is true should be equal "
245 "to subgroup size (16 elements)");
246 return success();
247 }
248
249 return success();
250}
251
252LogicalResult BlockStore2dOp::verify() {
253 if (verify2DBlockStoreRestriction(*this).failed())
254 return failure();
255
256 if (verifyMatrixInput(*this).failed())
257 return failure();
258
259 uint32_t tileWidth = getTileWidth();
260 switch (getElemSizeInBits()) {
261 case 8:
262 if (tileWidth != 16 && tileWidth != 32)
263 return emitOpError("tile_width for 8 bit elements should be equal to "
264 "16 or 32");
265 break;
266 case 16:
267 if (tileWidth != 16)
268 return emitOpError("tile_width for 16 bit elements should be equal "
269 "to 16");
270 break;
271 case 32:
272 if (tileWidth != 16)
273 return emitOpError("tile_width for 32 bit elements should be equal "
274 "to 16");
275 break;
276 default:
277 llvm_unreachable("unexpected element size");
278 }
279
280 return success();
281}
282
283LogicalResult BlockPrefetch2dOp::verify() {
284 if (verifyMatrixInput(*this).failed())
285 return failure();
286
287 uint32_t tileWidth = getTileWidth();
288 switch (getElemSizeInBits()) {
289 case 8:
290 if (tileWidth != 16 && tileWidth != 32)
291 return emitOpError("tile_width for 8 bit elements should be equal to "
292 "16 or 32");
293 break;
294 case 16:
295 if (tileWidth != 16)
296 return emitOpError("tile_width for 16 bit elements should be equal "
297 "to 16");
298 break;
299 case 32:
300 if (tileWidth != 8 && tileWidth != 16)
301 return emitOpError(
302 "tile_width for 32 bit elements should be equal to 8 or 16");
303 break;
304 default:
305 llvm_unreachable("unexpected element size");
306 }
307
308 return success();
309}
310
311template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
312 OpType, BlockLoadOp, BlockStoreOp>::value>>
313LogicalResult verify1DBlockArg(OpType op) {
314 Type srcOrDstTy;
315 if constexpr (std::is_same_v<OpType, BlockLoadOp>)
316 srcOrDstTy = op.getResult().getType();
317 else
318 srcOrDstTy = op.getVal().getType();
319 VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
320 // scalar case is always valid
321 if (!vTy)
322 return success();
323 int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
324 if (elemTySize == 1) {
325 llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
326 if (validSizes.contains(vTy.getNumElements()))
327 return success();
328 else
329 return op.emitOpError(
330 "vector size must be 2, 4, 8 or 16 for 8-bit element type");
331 } else {
332 llvm::SmallSet<int, 3> validSizes{2, 4, 8};
333 if (validSizes.contains(vTy.getNumElements()))
334 return success();
335 else
336 return op.emitOpError(
337 "vector size must be 2, 4 or 8 for element type > 8 bits");
338 }
339}
340
341LogicalResult BlockLoadOp::verify() { return verify1DBlockArg(*this); }
342
343LogicalResult BlockStoreOp::verify() { return verify1DBlockArg(*this); }
344
345LogicalResult MMAOp::verify() {
346 if (getC()) {
347 if (getResult().getType() != getC().getType())
348 return emitOpError("type of C operand must match result type");
349 }
350 return success();
351}
352
353LogicalResult MMAMxOp::verify() {
354 if (getC()) {
355 if (getResult().getType() != getC().getType())
356 return emitOpError("type of C operand must match result type");
357 }
358 return success();
359}
360
361LogicalResult TruncfOp::verify() {
362 Type srcTy = getSrc().getType();
363 Type dstTy = getDst().getType();
364 if (isa<VectorType>(srcTy) != isa<VectorType>(dstTy))
365 return emitOpError("both src and dst should be vector types or both should "
366 "be scalar types");
367 if (getElementTypeOrSelf(srcTy).getIntOrFloatBitWidth() <=
368 getElementTypeOrSelf(dstTy).getIntOrFloatBitWidth())
369 return emitError(
370 "dst element bitwidth should be less than src element bitwidth");
371 return success();
372}
373
374LogicalResult ExtfOp::verify() {
375 Type srcTy = getSrc().getType();
376 Type dstTy = getDst().getType();
377 if (isa<VectorType>(srcTy) != isa<VectorType>(dstTy))
378 return emitOpError("both src and dst should be vector types or both should "
379 "be scalar types");
380 if (getElementTypeOrSelf(srcTy).getIntOrFloatBitWidth() >=
381 getElementTypeOrSelf(dstTy).getIntOrFloatBitWidth())
382 return emitError(
383 "dst element bitwidth should be greater than src element bitwidth");
384 return success();
385}
386
387LogicalResult
388XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
389 StringRef triple, StringRef chip, DictionaryAttr flags,
390 ArrayAttr linkFiles) {
391 if (O < 0 || O > 3) {
392 return emitError()
393 << "The optimization level must be a number between 0 and 3.";
394 }
395 if (triple.empty()) {
396 return emitError() << "The target triple cannot be empty.";
397 }
398 if (chip.empty()) {
399 return emitError() << "The target chip cannot be empty.";
400 }
401 if (linkFiles) {
402 for (Attribute fileAttr : linkFiles) {
403 if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
404 StringRef filePath = fileStrAttr.getValue();
405 if (filePath.empty()) {
406 return emitError() << "File paths in linkFiles cannot be empty.";
407 }
408 if (!llvm::sys::fs::exists(filePath)) {
409 return emitError() << "File '" << filePath << "' does not exist.";
410 }
411 }
412 }
413 }
414 return success();
415}
416
/// Registers the XeVM dialect's TableGen-generated operations and attributes.
/// It also declares gpu::TargetAttrInterface for XeVMTargetAttr as a promised
/// interface: the implementation is expected to be attached elsewhere
/// (e.g. by a separately registered extension — confirm where).
void XeVMDialect::initialize() {
  // GET_OP_LIST expands the generated file into a comma-separated list of op
  // classes used as template arguments.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
      >();

  // GET_ATTRDEF_LIST expands to the list of generated attribute classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
      >();
  declarePromisedInterface<mlir::gpu::TargetAttrInterface,
                           mlir::xevm::XeVMTargetAttr>();
}
430
// Definitions of the TableGen-generated XeVM operation classes.
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"

// Definitions of the TableGen-generated XeVM attribute classes.
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
LogicalResult verify1DBlockArg(OpType op)
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147