//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
  declarePromisedInterfaces<memref::IndexedAccessOpInterface, LdMatrixOp>();
  declarePromisedInterfaces<memref::IndexedMemCopyOpInterface,
                            DeviceAsyncCopyOp>();
}

bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
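
// Illustrative note (not from the original source): examples of memref types
// the two predicates above classify as shared memory, assuming
// NVGPUDialect::kSharedMemoryAddressSpace == 3:
//
//   memref<4x8xf16, 3>                             -> shared (IntegerAttr)
//   memref<4x8xf16, #gpu.address_space<workgroup>> -> shared (gpu attr)
//   memref<4x8xf16>                                -> not shared (no space)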

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
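
// Example accepted by this verifier (a sketch, not from the source; SSA names
// are hypothetical). Four f32 elements are 16 bytes, one of {4, 8, 16}:
//
//   %token = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0], 4
//       : memref<128x128xf32> to memref<48x32xf32, 3>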

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}

/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // The verification for mma.sync covering various shapes and data types is
  // based on the fundamental tensor core shape.

  // "Fundamental" tensor core shapes:
  // - For F32 (TF32), F16, S8, and S4 data types the fundamental tensor core
  //   operation is of shape 8-by-8-by-128b.
  // - F64 is an exception and is of shape 8-by-8-by-256b.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // exception to the 8-by-8-by-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-by-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth; // 128b wide shapeK

    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }

  //
  // Basic verification
  //

  if (aShape.size() != 2)
    return op->emitError() << "matrixA must be a 2-dimensional vector";

  if (bShape.size() != 2)
    return op->emitError() << "matrixB must be a 2-dimensional vector";

  if (cShape.size() != 2)
    return op->emitError() << "matrixC must be a 2-dimensional vector";

  auto [m, n, k] = mmaShape;

  // verify warp-wide size for vector a
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // verify tf32 tensor cores are enabled for only F32 datatype
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;
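
  // Worked example (illustrative, not from the source): for f16 operands with
  // mmaShape = [16, 8, 16], shapeK = 128/16 = 8, so mTile = 2, nTile = 1,
  // kTile = 2, and numElementA = numElementB = 32/16 = 2. The checks below
  // then expect per-thread operands vector<4x2xf16> for A (mTile * kTile = 4),
  // vector<2x2xf16> for B (kTile * nTile = 2), and vector<2x2xf16> for C
  // (mTile * nTile = 2, numElementC = 2).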

  // verify shape of aVector
  if ((aShape[0] != mTile * kTile / sparseFactor) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // verify shape of bVector
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // verify shape of cVector
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}

LogicalResult MmaSyncOp::verify() {
  if (getMmaShape().size() != 3)
    return emitOpError() << "mmaShape must have exactly 3 elements";

  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
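
// Example accepted by MmaSyncOp::verify (a sketch; SSA names are
// hypothetical), matching the m16n8k16 f16 shapes worked out above:
//
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
//         -> vector<2x2xf16>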

//===----------------------------------------------------------------------===//
// NVGPU_MmaSparseSyncOp
//===----------------------------------------------------------------------===//
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}

LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";

  if (getMmaShape().size() != 3)
    return emitOpError() << "mmaShape must have exactly 3 elements";

  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         /*sparse=*/true);
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from source in shared memory
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  //
  // verification
  //

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref to have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2)
    return emitError() << "results must be a 2-dimensional vector";
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
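
// Example accepted by this verifier (a sketch; names are hypothetical): four
// 8x8 tiles of f16 give numElementsPer32b = 2, so the result must be
// vector<4x2xf16>:
//
//   %a = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
//       : memref<?x?xf16, 3> -> vector<4x2xf16>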

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
  switch (kind) {
  case TensorMapSwizzleKind::SWIZZLE_32B:
    return 32;
  case TensorMapSwizzleKind::SWIZZLE_64B:
    return 64;
  case TensorMapSwizzleKind::SWIZZLE_128B:
    return 128;
  default:
    return 0;
  }
}

std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // Address space check for shared memory
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensor map descriptor must have a last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No verification if memref type is not provided
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check element type
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be the same";
  }

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the tensor map descriptor and memref must "
                              "have the same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  int lastDimBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimBytes % kTMALastdimByte != 0) {
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";
  }
  return std::nullopt;
}
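
// Illustrative numbers for the swizzle check above (not from the source): a
// descriptor tensor memref<64x64xf16, 3> with SWIZZLE_128B passes, since
// 64 elements * 16 bits / 8 = 128 bytes; memref<64x32xf16, 3> (64 bytes in
// the last dimension) would be rejected.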

LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor.";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncStoreOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor.";
  }

  return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " box dimensions are supported.";
  }

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_WarpgroupGenerateDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaOp
//===----------------------------------------------------------------------===//

LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // F32 += F16 * F16
  // F16 += F16 * F16
  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F32 += TF32 * TF32
  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
    return success();
  // S32 += I8 * I8
  if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32))
    return success();
  // S32 += I1 * I1
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
    return success();
  // F32 += BF16 * BF16
  // F16 += BF16 * BF16
  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F16 += F8 * F8
  // F32 += F8 * F8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();

  return failure();
}

LogicalResult isAllowedSizeM(int sizeM) {
  return success(sizeM % kWgmmaSizeM == 0);
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}
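
// Reading of the tables above (an observation, not from the source): for
// floating-point operand types N may be any multiple of 8 up to 256, while
// for i8/i1 operands only the sparser subset in allowedNshort is accepted.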

LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D; they must be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, f16/bf16 supported
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}

LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype.getDimSize(0) << "]["
                         << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have the same size as results";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaInitAccumulatorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RcpOp
//===----------------------------------------------------------------------===//

LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only `rcp_approx` with `ftz` is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}
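
// Example accepted by RcpOp::verify (a sketch; names are hypothetical):
//
//   %out = nvgpu.rcp %in {rounding = approx, ftz} : vector<32x16xf32>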

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"