//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}

bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
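
// Illustration (editorial sketch, not part of the upstream file): both memref
// types below are treated as shared memory by hasSharedMemoryAddressSpace(),
// assuming kSharedMemoryAddressSpace is the NVVM shared space (3); anything
// else, including a missing memory space, is rejected.
//
//   memref<16x32xf16, 3>                              // IntegerAttr
//   memref<16x32xf16, #gpu.address_space<workgroup>>  // gpu::AddressSpaceAttr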

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
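
// Illustration (editorial sketch): for an f32 destination (32-bit elements)
// the verifier above accepts dstElements of 1, 2, or 4, i.e. transfers of 4,
// 8, or 16 bytes; bypassL1 additionally requires the full 16-byte case. A
// conforming copy into workgroup memory would look roughly like:
//
//   %token = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4
//       : memref<4x5xf32> to memref<2x7x5xf32, 3>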

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}

/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // The verification for mma.sync covering various shapes and data types is
  // based on the fundamental tensor core shape.

  // "Fundamental" tensor core shapes:
  //   - For F32 (TF32), F16, S8, and S4 data types the fundamental tensor core
  //     operation is of shape 8-by-8-by-128b.
  //   - F64 is an exception and is of shape 8-by-8-by-256b.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // exception to the 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth; // 128b wide shapeK

    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }

  //
  // Basic verification
  //

  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be 2 dimensional vector";
  }

  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be 2 dimensional vector";
  }

  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be 2 dimensional vector";
  }

  auto [m, n, k] = mmaShape;

  // verify warp-wide size for vector a
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // verify tf32 tensor cores are enabled for only F32 datatype
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // verify shape of bVector
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // verify shape of cVector
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
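
// Worked example (editorial): for a dense f16 mma.sync with mmaShape
// [m, n, k] = [16, 8, 16], the code above yields shapeK = 128/16 = 8,
// numElementA = numElementB = 32/16 = 2, and tiles mTile = 2, nTile = 1,
// kTile = 2, so the per-thread operands must be shaped
//   A: vector<4x2xf16>   (mTile*kTile x numElementA)
//   B: vector<2x2xf16>   (kTile*nTile x numElementB)
//   C: vector<2x2xf16>   (mTile*nTile x numElementC)
// and the warp-wide counts check out: 4*2*32 = 256 = m*k for A, and
// 2*2*32 = 128 = k*n for B and m*n for C.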

LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
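
// Editorial sketch of an op accepted by this verifier (assembly format
// approximated from the op documentation):
//
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>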

//===----------------------------------------------------------------------===//
// NVGPU_MmaSparseSyncOp
//===----------------------------------------------------------------------===//
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}

LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         true);
}
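
// Editorial note: in sparse mode operand A carries only half the elements.
// For an f16 op with mmaShape [16, 8, 32] (shapeK = 8, mTile = 2, nTile = 1,
// kTile = 4) the verifier expects
//   A: vector<4x2xf16>   (mTile*kTile/2 x numElementA)
//   B: vector<4x2xf16>   (kTile*nTile   x numElementB)
//   C: vector<2x2xf16>   (mTile*nTile   x numElementC)
// together with the sparse metadata operand and a sparsity selector of 0 or 1.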

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from source in shared memory
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  //
  // verification
  //

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be 2 dimensional vector";
  }
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
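
// Editorial illustration: for f16 results (16-bit elements) each thread
// receives 32/16 = 2 elements per 8x8 tile, so numTiles = 4 requires a
// vector<4x2xf16> result. A conforming op would look roughly like:
//
//   %frag = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
//       : memref<?x?xf16, 3> -> vector<4x2xf16>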

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
  switch (kind) {
  case TensorMapSwizzleKind::SWIZZLE_32B:
    return 32;
  case TensorMapSwizzleKind::SWIZZLE_64B:
    return 64;
  case TensorMapSwizzleKind::SWIZZLE_128B:
    return 128;
  default:
    return 0;
  }
}

std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // Address space check for shared memory
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No verification if memref type is not provided
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check element type
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  int lastDimBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimBytes % kTMALastdimByte != 0) {
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";
  }
  return std::nullopt;
}
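
// Editorial illustration: with SWIZZLE_128B and an f16 element type the
// descriptor's last dimension must span exactly 128 bytes, i.e. 64 elements.
// A descriptor over memref<128x64xf16, 3> therefore passes the swizzle and
// 16-byte last-dimension checks above, whereas memref<128x32xf16, 3>
// (64 bytes in the last dimension) is rejected.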

LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }

  return success();
}
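
// Editorial note: the coordinate count must equal the rank of the tensor
// described by the descriptor, e.g. a descriptor over memref<128x64xf16, 3>
// (rank 2) requires exactly two coordinates, and kMaxTMATensorDimension caps
// how many dimensions TMA can address at all.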

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncStoreOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }

  return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_WarpgroupGenerateDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaOp
//===----------------------------------------------------------------------===//

LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // F32 += F16 * F16
  // F16 += F16 * F16
  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F32 += TF32 * TF32
  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
    return success();
  // s32 += i8 * i8
  if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
    return success();
  // s32 += i1 * i1
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
    return success();
  // F32 += BF16 * BF16
  // F16 += BF16 * BF16
  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F16 += f8 * f8
  // F32 += f8 * f8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();

  return failure();
}

LogicalResult isAllowedSizeM(int sizeM) {
  if (sizeM % kWgmmaSizeM)
    return failure();
  return success();
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}
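
// Editorial note: isAllowedSizeM accepts only multiples of kWgmmaSizeM, the M
// granularity of wgmma.mma_async (e.g. 64 or 128, assuming kWgmmaSizeM is 64).
// isAllowedSizeN draws N from the full 8-step table for floating-point operand
// types but from the shorter table for i8/i1, so N = 120 is valid for f16
// inputs yet rejected for i8 inputs.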

LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being ";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, f16/bf16 supported
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}
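
// Editorial illustration: a conforming op could take A as a 128x64 descriptor,
// B as a 64x128 descriptor, and C/D as 128x128 fragmented accumulators with
// f16 inputs and f32 accumulation: the inner dimensions agree (64), N = 128
// passes isAllowedSizeN, and f32 += f16 * f16 is an allowed data-type
// combination. The accumulator's M is checked separately via isAllowedSizeM in
// WarpgroupMmaInitAccumulatorOp below.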

LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have same size as results";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaInitAccumulatorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RcpOp
//===----------------------------------------------------------------------===//

LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only `rcp_approx` with `ftz` is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"