MLIR 23.0.0git
NVGPUDialect.cpp
Go to the documentation of this file.
1//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the NVGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
24
25using namespace mlir;
26using namespace mlir::nvgpu;
27
28#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
29
/// Registers all TableGen-generated types, attributes, and operations of the
/// NVGPU dialect. Called once when the dialect is loaded into a context.
void NVGPUDialect::initialize() {
  // Types (e.g. tensor map descriptors, accumulator fragments).
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  // Attributes (e.g. rounding-mode and swizzle-kind enums).
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  // Operations defined in NVGPUOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}
44
45bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
46 if (!memorySpace)
47 return false;
48 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
49 return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
50 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
51 return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
52 return false;
53}
54
55bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
56 Attribute memorySpace = type.getMemorySpace();
57 return isSharedMemoryAddressSpace(memorySpace);
58}
59
60//===----------------------------------------------------------------------===//
61// NVGPU_DeviceAsyncCopyOp
62//===----------------------------------------------------------------------===//
63
64LogicalResult DeviceAsyncCopyOp::verify() {
65 auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
66 auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
67
68 if (!srcMemref.isLastDimUnitStride())
69 return emitError("source memref most minor dim must have unit stride");
70 if (!dstMemref.isLastDimUnitStride())
71 return emitError("destination memref most minor dim must have unit stride");
72 if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
73 return emitError()
74 << "destination memref must have a memory space attribute of "
75 "IntegerAttr("
76 << NVGPUDialect::kSharedMemoryAddressSpace
77 << ") or gpu::AddressSpaceAttr(Workgroup)";
78 if (dstMemref.getElementType() != srcMemref.getElementType())
79 return emitError("source and destination must have the same element type");
80 if (size_t(srcMemref.getRank()) != getSrcIndices().size())
81 return emitOpError() << "expected " << srcMemref.getRank()
82 << " source indices, got " << getSrcIndices().size();
83 if (size_t(dstMemref.getRank()) != getDstIndices().size())
84 return emitOpError() << "expected " << dstMemref.getRank()
85 << " destination indices, got "
86 << getDstIndices().size();
87 int64_t dstElements = getDstElements().getZExtValue();
88 int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
89 if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
90 unsigned dstWidth = dstMemref.getElementTypeBitWidth();
92 diag << "Requested copy elements is " << dstElements << " with width "
93 << dstMemref.getElementTypeBitWidth()
94 << ". But copy elements could be one of ";
95 if ((32 / dstWidth) > 0)
96 diag << (32 / dstWidth) << ", ";
97 if ((64 / dstWidth) > 0)
98 diag << (64 / dstWidth) << ", ";
99 if ((128 / dstWidth) > 0)
100 diag << (128 / dstWidth) << ".";
101 return diag;
102 }
103 if (getBypassL1().has_value()) {
104 int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
105 if (getBypassL1().value() && sizeInBytes != 16) {
106 return emitOpError() << "bypassL1 does not satify alignment for "
107 << dstMemref << " with destination element "
108 << dstElements
109 << ". Unset bypassL1, or set "
110 "destination element to "
111 << req;
112 }
113 }
114 return success();
115}
116
117//===----------------------------------------------------------------------===//
118// NVGPU_MmaSyncOp
119//===----------------------------------------------------------------------===//
/// Convenience builder taking the mma shape as an ArrayAttr. The result type
/// is taken from the accumulator `matrixC`, and tf32 mode is left disabled
/// (a null UnitAttr means the attribute is absent).
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}
126
127void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
128 ::mlir::OperationState &odsState, Value matrixA,
129 Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
130 bool tf32Enabled) {
131 build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
132 odsBuilder.getI64ArrayAttr(mmaShape),
133 tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
134}
135
136/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
137static LogicalResult verifyMmaSyncOp(Operation *op,
141 const std::array<int64_t, 3> &mmaShape,
142 bool tf32Enabled, bool sparse = false) {
143 // The verification for mma.sync covering various shapes and data types is
144 // based on the fundamental tensor core shape.
145
146 // "Fundamental" tensor core shapes:
147 // - For F32 (TF32), F16, S8, and S4 data
148 // types the fundamental tensor core operation is of shape 8-by-8-by-128b.
149 // - F64 is an exception and is of shape 8-by-8-by-256b.
150 int64_t shapeM = 8;
151 int64_t shapeN = 8;
152 int64_t shapeK; // set based on data type (128b for all data types except F64)
153
154 // Number of elements A, B, and C per thread per fundamental tensor core tile
155 int64_t numElementA; // set based on data type (32b except F64)
156 int64_t numElementB; // set based on data type (32b except F64)
157 int64_t numElementC{2}; // two accumulator elements per fundamental tile
158
159 // nvgpu.mma.sync vector operands (per thread)
160 auto aVector = matrixA.getType();
161 auto bVector = matrixB.getType();
162 auto cVector = matrixC.getType();
163
164 // vector shapes
165 ArrayRef<int64_t> aShape = aVector.getShape();
166 ArrayRef<int64_t> bShape = bVector.getShape();
167 ArrayRef<int64_t> cShape = cVector.getShape();
168
169 // vector element type
170 Type aType = aVector.getElementType();
171
172 // Certain data types are not allowed in sparse mode.
173 if (sparse && aType.isF64())
174 return op->emitError() << "f64 is not supported for sparse mode";
175
176 if (aType.isF64()) {
177 // exception to 8-by-8-128b fundamental tensor core tile size
178 shapeK = 4;
179 numElementA = 1;
180 numElementB = 1;
181 } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
182 aType.isInteger(8) || aType.isInteger(4)) {
183 // 8-by-8-128b fundamental tensor core tile size
184 int operandBitwidth = aType.getIntOrFloatBitWidth();
185 shapeK = 128 / operandBitwidth; // 128b wide shapeK
186
187 numElementA = 32 / operandBitwidth; // 32b wide operand A
188 numElementB = 32 / operandBitwidth; // 32b wide operand B
189 } else {
190 return op->emitError()
191 << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
192 "supported by "
193 << op->getName();
194 }
195
196 //
197 // Basic verification
198 //
199
200 if (aShape.size() != 2) {
201 return op->emitError() << "matrixA must be 2 dimensional vector";
202 }
203
204 if (bShape.size() != 2) {
205 return op->emitError() << "matrixB must be 2 dimensional vector";
206 }
207
208 if (cShape.size() != 2) {
209 return op->emitError() << "matrixC must be 2 dimensional vector";
210 }
211
212 auto [m, n, k] = mmaShape;
213
214 // verify warp-wide size for vector a
215 int64_t sparseFactor = sparse ? 2 : 1;
216 if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
217 return op->emitOpError()
218 << "expected " << m * k << " warp-wide matrix A elements";
219
220 // verify warp-wide size for vector b
221 if (bShape[0] * bShape[1] * kWarpSize != k * n)
222 return op->emitOpError()
223 << "expected " << k * n << " warp-wide matrix B elements";
224
225 // verify warp-wide size for vector c
226 if (cShape[0] * cShape[1] * kWarpSize != m * n)
227 return op->emitOpError()
228 << "expected " << m * n << " warp-wide matrix C elements";
229
230 // verify tf32 tensor cores are enabled for only F32 datatype
231 if (tf32Enabled && !(aType.isF32()))
232 return op->emitOpError()
233 << "expected tf32 tensor cores only for F32 operands";
234
235 //
236 // Extended verification
237 //
238
239 // tiles of fundamental tensor core operations
240 int64_t mTile = m / shapeM;
241 int64_t nTile = n / shapeN;
242 int64_t kTile = k / shapeK;
243
244 // verify shape of aVector
245 if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
246 (aShape[1] != numElementA))
247 return op->emitOpError() << "expected matrix A to be shaped ("
248 << mTile * kTile << " x " << numElementA << ")";
249
250 // verify shape of bVector
251 if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
252 return op->emitOpError() << "expected matrix B to be shaped ("
253 << kTile * nTile << " x " << numElementB << ")";
254
255 // verify shape of cVector
256 if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
257 return op->emitOpError() << "expected matrix C to be shaped ("
258 << mTile * nTile << " x " << numElementC << ")";
259
260 return success();
261}
262
263LogicalResult MmaSyncOp::verify() {
264 if (getMmaShape().size() != 3)
265 return emitOpError() << "mmaShape must have exactly 3 elements";
266
267 return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
268 getMatrixC(), getMmaShapeAsArray(),
269 getOperation()->hasAttr(getTf32EnabledAttrName()));
270}
271
272//===----------------------------------------------------------------------===//
273// NVGPU_MmaSparseSyncOp
274//===----------------------------------------------------------------------===//
/// Convenience builder for the sparse mma. The result type is taken from the
/// accumulator `matrixC`; the sparsity selector defaults to 0 and tf32 mode
/// is left disabled (null UnitAttr).
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}
282
283LogicalResult MmaSparseSyncOp::verify() {
284 unsigned sparsitySelector = getSparsitySelector();
285 if (sparsitySelector > 1)
286 return emitOpError() << "sparsity selector should be 0 or 1";
287
288 if (getMmaShape().size() != 3)
289 return emitOpError() << "mmaShape must have exactly 3 elements";
290
291 return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
292 getMatrixC(), getMmaShapeAsArray(),
293 getOperation()->hasAttr(getTf32EnabledAttrName()),
294 true);
295}
296
297//===----------------------------------------------------------------------===//
298// NVGPU_LdMatrixOp
299//===----------------------------------------------------------------------===//
300LogicalResult LdMatrixOp::verify() {
301 // ldmatrix reads data from source in shared memory
302 auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
303
304 // ldmatrix writes data to result/destination in vector registers
305 auto resVector = llvm::cast<VectorType>(getRes().getType());
306
307 // vector register shape, element type, and bitwidth
308 ArrayRef<int64_t> resShape = resVector.getShape();
309 Type resType = resVector.getElementType();
310 int64_t elementBitWidth = resType.getIntOrFloatBitWidth();
311
312 // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
313 int64_t numElementsPer32b = 32 / elementBitWidth;
314
315 // number of 8-by-8 tiles
316 int64_t numTiles = getNumTiles();
317
318 // transpose elements in vector registers at 16b granularity when true
319 bool isTranspose = getTranspose();
320
321 //
322 // verification
323 //
324
325 if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
326 return emitError()
327 << "expected nvgpu.ldmatrix srcMemref must have a memory space "
328 "attribute of IntegerAttr("
329 << NVGPUDialect::kSharedMemoryAddressSpace
330 << ") or gpu::AddressSpaceAttr(Workgroup)";
331 if (elementBitWidth > 32)
332 return emitError() << "nvgpu.ldmatrix works for 32b or lower";
333 if (isTranspose && !(elementBitWidth == 16))
334 return emitError()
335 << "nvgpu.ldmatrix transpose works only at 16b granularity";
336 if (resShape.size() != 2) {
337 return emitError() << "results must be 2 dimensional vector";
338 }
339 if (!(resShape[1] == numElementsPer32b))
340 return emitError() << "expected vector register shape[1] = "
341 << numElementsPer32b;
342 if (!(resShape[0] == numTiles))
343 return emitError()
344 << "expected vector register shape[0] and numTiles to match";
345
346 return success();
347}
348
349//===----------------------------------------------------------------------===//
350// NVGPU_TmaAsyncLoadOp
351//===----------------------------------------------------------------------===//
352
353static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
354 switch (kind) {
355 case TensorMapSwizzleKind::SWIZZLE_32B:
356 return 32;
357 case TensorMapSwizzleKind::SWIZZLE_64B:
358 return 64;
359 case TensorMapSwizzleKind::SWIZZLE_128B:
360 return 128;
361 default:
362 return 0;
363 }
364}
365
/// Shared verification for ops carrying a TMA tensor-map descriptor.
/// Checks the descriptor itself (interleave, address space, static shape,
/// dimension bounds, swizzle byte width) and, when `memrefType` is given,
/// that the destination/source memref is compatible with it. Returns an
/// in-flight diagnostic on failure, std::nullopt on success.
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation: interleaved tensor maps are not implemented.
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The descriptor's tensor must live in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being.
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  // Every dimension must be non-zero and within the TMA hardware limit.
  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  // With swizzling enabled, the innermost dimension must span exactly the
  // swizzle width in bytes (only meaningful for rank > 1).
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No further verification if a memref type is not provided.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Descriptor and memref must agree on element type.
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }

  // The memref side of the transfer must also be in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  // Ranks and shapes must match exactly.
  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  // TMA requires the innermost dimension to cover a multiple of
  // kTMALastdimByte (16) bytes.
  int lastDimBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimBytes % kTMALastdimByte != 0) {
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";
  }
  return std::nullopt;
}
438
439LogicalResult TmaAsyncLoadOp::verify() {
440 std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
441 *this, getTensorMapDescriptor().getType(), getDst().getType());
442 if (error.has_value())
443 return error.value();
444
445 if (getCoordinates().size() > kMaxTMATensorDimension) {
446 return emitError() << "Maximum " << kMaxTMATensorDimension
447 << " coordinates are supported.";
448 }
449 if (getCoordinates().size() !=
450 size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
451 return emitError() << "number of coordinates do not match with the rank of "
452 "tensor descriptor map.";
453 }
454
455 return success();
456}
457
458//===----------------------------------------------------------------------===//
459// NVGPU_TmaAsyncStoreOp
460//===----------------------------------------------------------------------===//
461
462LogicalResult TmaAsyncStoreOp::verify() {
463 std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
464 *this, getTensorMapDescriptor().getType(), getSrc().getType());
465 if (error.has_value())
466 return error.value();
467
468 if (getCoordinates().size() > kMaxTMATensorDimension) {
469 return emitError() << "Maximum " << kMaxTMATensorDimension
470 << " coordinates are supported.";
471 }
472 if (getCoordinates().size() !=
473 size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
474 return emitError() << "number of coordinates do not match with the rank of "
475 "tensor descriptor map.";
476 }
477
478 return success();
479}
480
481LogicalResult TmaCreateDescriptorOp::verify() {
482 if (getBoxDimensions().size() > kMaxTMATensorDimension) {
483 return emitError() << "Maximum " << kMaxTMATensorDimension
484 << " coordinates are supported.";
485 }
486
487 std::optional<InFlightDiagnostic> error =
488 verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
489 if (error.has_value())
490 return error.value();
491
492 return success();
493}
494
495//===----------------------------------------------------------------------===//
496// NVGPU_WarpgroupGenerateDescriptorOp
497//===----------------------------------------------------------------------===//
498
/// Verifies nvgpu.warpgroup.generate.descriptor: the tensor map must pass the
/// shared TMA checks and, for now, use 128B swizzling with no interleaving.
LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  // Common tensor-map checks (interleave, address space, shape, swizzle).
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  // Only 128B swizzling is accepted for now.
  // NOTE(review): the "supports only ... is supported" phrasing is redundant
  // in both diagnostics below; left as-is since tests may match the text.
  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "supports only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  // Interleaved layouts are rejected here as well as in the shared checks.
  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "supports only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}
523
524//===----------------------------------------------------------------------===//
525// WarpgroupMmaOp
526//===----------------------------------------------------------------------===//
527
528LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
529 // F32 += F16 + F16
530 // F16 += F16 + F16
531 if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
532 return success();
533 // F32 += TF32 + TF32
534 if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
535 return success();
536 // s32 += i8 + i8
537 if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
538 return success();
539 // s32 += i1 + i1
540 if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
541 return success();
542 // F32 += BF16 + BF16
543 // F16 += BF16 + BF16
544 if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
545 return success();
546 // F16 += f8 + f8
547 // F32 += f8 + f8
548 if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
549 isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
550 (typeD.isF32() || typeD.isF16()))
551 return success();
552
553 return failure();
554}
555
556LogicalResult isAllowedSizeM(int sizeM) {
557 if (sizeM % kWgmmaSizeM)
558 return failure();
559 return success();
560}
561
562LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
563 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
564 72, 80, 88, 96, 104, 112, 120, 128,
565 136, 144, 152, 160, 168, 176, 184, 192,
566 200, 208, 216, 224, 232, 240, 248, 256};
567 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
568 80, 96, 112, 128, 144, 160,
569 176, 192, 208, 224, 240, 256};
570 if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
571 isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
572 if (llvm::is_contained(allowedN, sizeN))
573 return success();
574
575 if (typeA.isInteger(8) || typeA.isInteger(1))
576 if (llvm::is_contained(allowedNshort, sizeN))
577 return success();
578 return failure();
579}
580
/// Verifies nvgpu.warpgroup.mma: operand layouts, 2-D ranks, conformable
/// (m, k) x (k, n) -> (m, n) shapes, a supported data-type combination, a
/// legal N size, and the current f32-accumulator / f16-bf16-input limitation.
LogicalResult WarpgroupMmaOp::verify() {
  // NOTE(review): this only rejects the (A transposed, B not transposed)
  // combination; (A=T, B=T) and (A=N, B=N) pass through — confirm whether
  // that is the intended layout restriction.
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being ";
  // A and B come via descriptors over shared memory; C/D are register
  // fragments held warp-group wide.
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  // The accumulator input and result must have identical fragment types.
  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }

  // Conformability: A is (m, k), B is (k, n), C is (m, n).
  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
                         << " )";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
                         << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
                         << " )";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1]
                         << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
                         << " )";

  // The element-type triple must be a wgmma-supported combination.
  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N against the per-type list of legal wgmma N sizes.
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, only f32 accumulation of f16/bf16 inputs is implemented.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}
637
638LogicalResult WarpgroupMmaStoreOp::verify() {
639 MemRefType dstMemrefType = getDstMemref().getType();
640 VectorType vtype = getMatrixD().getType().getFragmented();
641
642 // Limitation
643 if (!vtype.getElementType().isF32()) {
644 return emitOpError()
645 << "hit a limitation: only f32 results for the time being";
646 }
647 if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
648 vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
649 return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
650 << "] values. However, destination memref["
651 << dstMemrefType.getDimSize(0) << "]["
652 << dstMemrefType.getDimSize(1)
653 << "] does not have same size as results";
654 }
655 return success();
656}
657
658//===----------------------------------------------------------------------===//
659// WarpgroupMmaInitAccumulatorOp
660//===----------------------------------------------------------------------===//
661
662LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
663 WarpgroupAccumulatorType accType = getMatrixC().getType();
664 int64_t sizeM = accType.getFragmented().getDimSize(0);
665 int64_t sizeN = accType.getFragmented().getDimSize(1);
666 Type elemType = accType.getFragmented().getElementType();
667
668 if (failed(isAllowedSizeM(sizeM)) ||
669 failed(isAllowedSizeN(sizeN, elemType))) {
670 return emitOpError() << "has type " << accType.getFragmented()
671 << ". It does not fit into warp-group "
672 "level (wgmma) matrix multiplication instruction "
673 "(or not supported yet)";
674 }
675 return success();
676}
677
678//===----------------------------------------------------------------------===//
679// RcpOp
680//===----------------------------------------------------------------------===//
681
682LogicalResult RcpOp::verify() {
683 RcpRoundingModeAttr rounding = getRoundingAttr();
684 bool ftz = getFtz();
685 // Currently, only `rcp_approx` and `ftz` is supported.
686 if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
687 return emitOpError() << "has a limitation. " << rounding
688 << " or non-ftz is not supported yet.";
689 }
690 return success();
691}
692
693//===----------------------------------------------------------------------===//
694// TableGen'd dialect, type, and op definitions
695//===----------------------------------------------------------------------===//
696
697#define GET_ATTRDEF_CLASSES
698#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
699
700#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"
701
702#define GET_OP_CLASSES
703#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
704
705#define GET_TYPEDEF_CLASSES
706#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
static std::string diag(const llvm::Value &value)
LogicalResult isAllowedSizeM(int sizeM)
static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue< VectorType > matrixA, TypedValue< VectorType > matrixB, TypedValue< VectorType > matrixC, const std::array< int64_t, 3 > &mmaShape, bool tf32Enabled, bool sparse=false)
Performs verification for MmaSyncOp and MmaSparseSyncOp.
std::optional< InFlightDiagnostic > verifyTmaDescriptorWithMemref(Operation *op, TensorMapDescriptorType descType, std::optional< MemRefType > memrefType=std::nullopt)
LogicalResult isAllowedSizeN(int sizeN, Type typeA)
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB)
static unsigned getSwizzleBytes(TensorMapSwizzleKind kind)
constexpr unsigned kTMALastdimByte
The bytes in the last dimension of the tensor map must be a multiple of 16.
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
constexpr int kWarpSize
constexpr unsigned kMaxTMATensorDimension
Maximum TMA tile dimension (tensorRank) must be non-zero and less than or equal to the maximum suppor...
constexpr unsigned kMaxTMADimension
Maximum TMA tile size (boxDim), which specifies number of elements to be traversed along each of the ...
Attributes are known-constant values of operations.
Definition Attributes.h:25
UnitAttr getUnitAttr()
Definition Builders.cpp:102
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isTF32() const
Definition Types.cpp:39
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
This represents an operation in an abstracted form, suitable for use with the builder APIs.