MLIR 22.0.0git
XeVMDialect.cpp
Go to the documentation of this file.
1//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MathExtras.h"
16
17using namespace mlir;
18using namespace mlir::xevm;
19
20#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
21#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
22
23namespace {
24static constexpr uint32_t subgroupSize = 16;
25
26template <typename Op>
27LogicalResult verifyMatrixInput(Op op) {
28 static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
29 BlockPrefetch2dOp>::value,
30 "Unexpected template parameter");
31
32 std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
33 std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
34 if (pitch && width && *pitch < *width)
35 return op->emitOpError(
36 "4th operand (base pitch) should be >= 2nd operand (base width)");
37
38 uint32_t elemSize = op.getElemSizeInBits();
39 if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
40 return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
41
42 uint32_t tileHeight = op.getTileHeight();
43 if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
44 return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
45
46 uint32_t vBlocks = op.getVBlocks();
47 if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
48 return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
49
50 return success();
51}
52
53LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
54 VectorType resTy = op.getRes().getType();
55 if (!resTy.getElementType().isIntOrFloat())
56 return op.emitOpError()
57 << "expecting result element type to be int or float";
58 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
59 unsigned resSize = resTy.getNumElements() * resElemTySize;
60 unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
61 op.getTileWidth() * op.getVBlocks() / subgroupSize;
62 if (resSize != expectedSize)
63 return op.emitOpError() << "result size of " << resSize
64 << " bits does not match the expected size of "
65 << expectedSize << " bits";
66
67 if (op.getTranspose() && op.getPackRegister())
68 return op.emitOpError("transpose and pack_register are mutually exclusive");
69
70 if (!op.getTranspose() && !op.getPackRegister()) {
71 uint32_t tileHeight = op.getTileHeight();
72 if (tileHeight < 1 || tileHeight > 32)
73 return op.emitOpError("expecting tile_height to be between 1 and 32");
74
75 uint32_t tileWidth = op.getTileWidth();
76 uint32_t vBlocks = op.getVBlocks();
77 switch (op.getElemSizeInBits()) {
78 case 8:
79 if (tileWidth < 4 || tileWidth > 64)
80 return op.emitOpError("expecting tile_width to be between 4 and 64");
81 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
82 return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
83 if (tileWidth * vBlocks > 64)
84 return op.emitOpError(
85 "tile_width * v_blocks should be less than or equal "
86 "to 64 for 8 bit elements");
87 break;
88 case 16:
89 if (tileWidth < 2 || tileWidth > 32)
90 return op.emitOpError("expecting tile_width to be between 2 and 32");
91 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
92 return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
93 if (tileWidth * vBlocks > 32)
94 return op.emitOpError(
95 "tile_width * v_blocks should be less than or equal "
96 "to 32 for 16 bit elements");
97 break;
98 case 32:
99 if (tileWidth < 1 || tileWidth > 16)
100 return op.emitOpError("expecting tile_width to be between 1 and 16");
101 if (vBlocks != 1 && vBlocks != 2)
102 return op.emitOpError("expecting v_blocks to be 1 or 2");
103 if (tileWidth * vBlocks > 16)
104 return op.emitOpError(
105 "tile_width * v_blocks should be less than or equal "
106 "to 16 for 32 bit elements");
107 break;
108 case 64:
109 if (tileWidth < 1 || tileWidth > 8)
110 return op.emitOpError("expecting tile_width to be between 1 and 8");
111 if (vBlocks != 1)
112 return op.emitOpError("expecting v_blocks to be 1");
113 break;
114 default:
115 return op.emitOpError(
116 "expecting elem_size_in_bits to be 8, 16, 32, or 64");
117 }
118
119 return success();
120 }
121
122 if (op.getTranspose()) {
123 assert(!op.getPackRegister() && "Expecting pack_register should be false");
124
125 uint32_t vBlocks = op.getVBlocks();
126 if (vBlocks != 1)
127 return op.emitOpError("expecting v_blocks to be 1");
128
129 uint32_t tileHeight = op.getTileHeight();
130 uint32_t tileWidth = op.getTileWidth();
131 switch (op.getElemSizeInBits()) {
132 case 32:
133 if (tileHeight < 1 || tileHeight > 32)
134 return op.emitOpError("expecting tile_height to be between 1 and 32");
135 if (tileWidth < 1 || tileWidth > 8)
136 return op.emitOpError("expecting tile_width to be between 1 and 8");
137 break;
138 case 64:
139 if (tileHeight != 8)
140 return op.emitOpError(
141 "expecting tile_height to be 8 for 64 bit elements");
142 if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
143 return op.emitOpError("expecting tile_width to be 1, 2, or 4");
144 break;
145 default:
146 return op.emitOpError("transpose is only supported for 32 and 64 bit "
147 "elements");
148 }
149
150 return success();
151 }
152
153 assert(op.getPackRegister() && !op.getTranspose() &&
154 "Expecting pack_register should be true and transpose should be "
155 "false");
156
157 uint32_t vBlocks = op.getVBlocks();
158 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
159 return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
160
161 uint32_t tileHeight = op.getTileHeight();
162 uint32_t tileWidth = op.getTileWidth();
163 switch (op.getElemSizeInBits()) {
164 case 8:
165 if (tileHeight < 4 || tileHeight > 32)
166 return op.emitOpError("expecting tile_height to be between 4 and 32");
167 if (tileWidth < 4 || tileWidth > 16)
168 return op.emitOpError("expecting tile_width to be between 4 and 16");
169 break;
170 case 16:
171 if (tileHeight < 2 || tileHeight > 32)
172 return op.emitOpError("expecting tile_height to be between 2 and 32");
173 if (tileWidth < 2 || tileWidth > 16)
174 return op.emitOpError("expecting tile_width to be between 2 and 16");
175 if (tileWidth * vBlocks > 32)
176 return op.emitOpError(
177 "tile_width * v_blocks should be less than or equal "
178 "to 32 for 16 bit elements");
179 break;
180 default:
181 return op.emitOpError("pack_register is only supported for 8 and 16 bit "
182 "elements");
183 }
184
185 return success();
186}
187
188static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
189 uint32_t tileHeight = op.getTileHeight();
190 if (tileHeight < 1 || tileHeight > 8)
191 return op.emitOpError("expecting tile_height to be between 1 and 8");
192
193 uint32_t tileWidth = op.getTileWidth();
194 switch (op.getElemSizeInBits()) {
195 case 8:
196 if (tileWidth < 4 || tileWidth > 64)
197 return op.emitOpError("expecting tile_width to be between 4 and 64");
198 break;
199 case 16:
200 if (tileWidth < 2 || tileWidth > 32)
201 return op.emitOpError("expecting tile_width to be between 2 and 32");
202 break;
203 case 32:
204 if (tileWidth < 1 || tileWidth > 16)
205 return op.emitOpError("expecting tile_width to be between 1 and 16");
206 break;
207 case 64:
208 if (tileWidth < 1 || tileWidth > 8)
209 return op.emitOpError("expecting tile_width to be between 1 and 8");
210 break;
211 default:
212 return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
213 }
214
215 uint32_t vBlocks = op.getVBlocks();
216 if (vBlocks != 1)
217 return op.emitOpError("expecting v_blocks to be 1");
218 return success();
219}
220
221} // namespace
222
223LogicalResult BlockLoad2dOp::verify() {
224 if (verify2DBlockLoadRestriction(*this).failed())
225 return failure();
226
227 if (verifyMatrixInput(*this).failed())
228 return failure();
229
230 VectorType resTy = getRes().getType();
231 if (!resTy.getElementType().isIntOrFloat())
232 return emitOpError() << "expecting result element type to be int of float";
233 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
234 if (getElemSizeInBits() == 32 || getPackRegister()) {
235 if (resElemTySize != 32)
236 return emitOpError() << "expecting result element type to be 32 bits";
237 }
238
239 uint32_t tileWidth = getTileWidth();
240 if (getPackRegister()) {
241 if (tileWidth != 16)
242 return emitOpError(
243 "tile_width when pack_register is true should be equal "
244 "to subgroup size (16 elements)");
245 return success();
246 }
247
248 return success();
249}
250
251LogicalResult BlockStore2dOp::verify() {
252 if (verify2DBlockStoreRestriction(*this).failed())
253 return failure();
254
255 if (verifyMatrixInput(*this).failed())
256 return failure();
257
258 uint32_t tileWidth = getTileWidth();
259 switch (getElemSizeInBits()) {
260 case 8:
261 if (tileWidth != 16 && tileWidth != 32)
262 return emitOpError("tile_width for 8 bit elements should be equal to "
263 "16 or 32");
264 break;
265 case 16:
266 if (tileWidth != 16)
267 return emitOpError("tile_width for 16 bit elements should be equal "
268 "to 16");
269 break;
270 case 32:
271 if (tileWidth != 16)
272 return emitOpError("tile_width for 32 bit elements should be equal "
273 "to 16");
274 break;
275 default:
276 llvm_unreachable("unexpected element size");
277 }
278
279 return success();
280}
281
282LogicalResult BlockPrefetch2dOp::verify() {
283 if (verifyMatrixInput(*this).failed())
284 return failure();
285
286 uint32_t tileWidth = getTileWidth();
287 switch (getElemSizeInBits()) {
288 case 8:
289 if (tileWidth != 16 && tileWidth != 32)
290 return emitOpError("tile_width for 8 bit elements should be equal to "
291 "16 or 32");
292 break;
293 case 16:
294 if (tileWidth != 16)
295 return emitOpError("tile_width for 16 bit elements should be equal "
296 "to 16");
297 break;
298 case 32:
299 if (tileWidth != 8 && tileWidth != 16)
300 return emitOpError(
301 "tile_width for 32 bit elements should be equal to 8 or 16");
302 break;
303 default:
304 llvm_unreachable("unexpected element size");
305 }
306
307 return success();
308}
309
310template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
311 OpType, BlockLoadOp, BlockStoreOp>::value>>
312LogicalResult verify1DBlockArg(OpType op) {
313 Type srcOrDstTy;
314 if constexpr (std::is_same_v<OpType, BlockLoadOp>)
315 srcOrDstTy = op.getResult().getType();
316 else
317 srcOrDstTy = op.getVal().getType();
318 VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
319 // scalar case is always valid
320 if (!vTy)
321 return success();
322 int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
323 if (elemTySize == 1) {
324 llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
325 if (validSizes.contains(vTy.getNumElements()))
326 return success();
327 else
328 return op.emitOpError(
329 "vector size must be 2, 4, 8 or 16 for 8-bit element type");
330 } else {
331 llvm::SmallSet<int, 3> validSizes{2, 4, 8};
332 if (validSizes.contains(vTy.getNumElements()))
333 return success();
334 else
335 return op.emitOpError(
336 "vector size must be 2, 4 or 8 for element type > 8 bits");
337 }
338}
339
340LogicalResult BlockLoadOp::verify() { return verify1DBlockArg(*this); }
341
342LogicalResult BlockStoreOp::verify() { return verify1DBlockArg(*this); }
343
344LogicalResult MMAOp::verify() {
345 if (getC()) {
346 if (getResult().getType() != getC().getType())
347 return emitOpError("type of C operand must match result type");
348 }
349 return success();
350}
351
352LogicalResult
353XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
354 StringRef triple, StringRef chip, DictionaryAttr flags,
355 ArrayAttr linkFiles) {
356 if (O < 0 || O > 3) {
357 return emitError()
358 << "The optimization level must be a number between 0 and 3.";
359 }
360 if (triple.empty()) {
361 return emitError() << "The target triple cannot be empty.";
362 }
363 if (chip.empty()) {
364 return emitError() << "The target chip cannot be empty.";
365 }
366 if (linkFiles) {
367 for (Attribute fileAttr : linkFiles) {
368 if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
369 StringRef filePath = fileStrAttr.getValue();
370 if (filePath.empty()) {
371 return emitError() << "File paths in linkFiles cannot be empty.";
372 }
373 if (!llvm::sys::fs::exists(filePath)) {
374 return emitError() << "File '" << filePath << "' does not exist.";
375 }
376 }
377 }
378 }
379 return success();
380}
381
/// Dialect initialization hook. It does three things:
///  - registers every XeVM operation listed by the TableGen-generated
///    GET_OP_LIST;
///  - registers every attribute listed by GET_ATTRDEF_LIST;
///  - declares that XeVMTargetAttr will provide gpu::TargetAttrInterface.
/// The interface implementation itself is not attached in this function; it
/// is presumably registered separately (e.g. by the XeVM target library).
void XeVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
      >();

  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
      >();
  // Promise only: using the interface before an implementation is registered
  // is reported by MLIR rather than silently ignored.
  declarePromisedInterface<mlir::gpu::TargetAttrInterface,
                           mlir::xevm::XeVMTargetAttr>();
}
395
396#define GET_OP_CLASSES
397#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
398
399#define GET_ATTRDEF_CLASSES
400#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
LogicalResult verify1DBlockArg(OpType op)
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152