22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
28#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
30void NVGPUDialect::initialize() {
32#define GET_TYPEDEF_LIST
33#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
36#define GET_ATTRDEF_LIST
37#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
41#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
45bool NVGPUDialect::isSharedMemoryAddressSpace(
Attribute memorySpace) {
48 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
49 return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
50 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
51 return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
55bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
56 Attribute memorySpace = type.getMemorySpace();
57 return isSharedMemoryAddressSpace(memorySpace);
64LogicalResult DeviceAsyncCopyOp::verify() {
65 auto srcMemref = llvm::cast<MemRefType>(getSrc().
getType());
66 auto dstMemref = llvm::cast<MemRefType>(getDst().
getType());
68 if (!srcMemref.isLastDimUnitStride())
69 return emitError(
"source memref most minor dim must have unit stride");
70 if (!dstMemref.isLastDimUnitStride())
71 return emitError(
"destination memref most minor dim must have unit stride");
72 if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
74 <<
"destination memref must have a memory space attribute of "
76 << NVGPUDialect::kSharedMemoryAddressSpace
77 <<
") or gpu::AddressSpaceAttr(Workgroup)";
78 if (dstMemref.getElementType() != srcMemref.getElementType())
79 return emitError(
"source and destination must have the same element type");
80 if (
size_t(srcMemref.getRank()) != getSrcIndices().size())
81 return emitOpError() <<
"expected " << srcMemref.getRank()
82 <<
" source indices, got " << getSrcIndices().size();
83 if (
size_t(dstMemref.getRank()) != getDstIndices().size())
84 return emitOpError() <<
"expected " << dstMemref.getRank()
85 <<
" destination indices, got "
86 << getDstIndices().size();
87 int64_t dstElements = getDstElements().getZExtValue();
88 int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
89 if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
90 unsigned dstWidth = dstMemref.getElementTypeBitWidth();
92 diag <<
"Requested copy elements is " << dstElements <<
" with width "
93 << dstMemref.getElementTypeBitWidth()
94 <<
". But copy elements could be one of ";
95 if ((32 / dstWidth) > 0)
96 diag << (32 / dstWidth) <<
", ";
97 if ((64 / dstWidth) > 0)
98 diag << (64 / dstWidth) <<
", ";
99 if ((128 / dstWidth) > 0)
100 diag << (128 / dstWidth) <<
".";
103 if (getBypassL1().has_value()) {
104 int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
105 if (getBypassL1().value() && sizeInBytes != 16) {
106 return emitOpError() <<
"bypassL1 does not satify alignment for "
107 << dstMemref <<
" with destination element "
109 <<
". Unset bypassL1, or set "
110 "destination element to "
123 build(odsBuilder, odsState, matrixC.
getType(), matrixA, matrixB, matrixC,
124 mmaShape, UnitAttr());
131 build(odsBuilder, odsState, matrixC.
getType(), matrixA, matrixB, matrixC,
133 tf32Enabled ? odsBuilder.
getUnitAttr() : UnitAttr());
141 const std::array<int64_t, 3> &mmaShape,
142 bool tf32Enabled,
bool sparse =
false) {
160 auto aVector = matrixA.getType();
161 auto bVector = matrixB.getType();
162 auto cVector = matrixC.getType();
170 Type aType = aVector.getElementType();
173 if (sparse && aType.
isF64())
174 return op->
emitError() <<
"f64 is not supported for sparse mode";
185 shapeK = 128 / operandBitwidth;
187 numElementA = 32 / operandBitwidth;
188 numElementB = 32 / operandBitwidth;
191 <<
"expected input data type (i4,i8,f16,bf16,tf32,f64) "
200 if (aShape.size() != 2) {
201 return op->
emitError() <<
"matrixA must be 2 dimensional vector";
204 if (bShape.size() != 2) {
205 return op->
emitError() <<
"matrixB must be 2 dimensional vector";
208 if (cShape.size() != 2) {
209 return op->
emitError() <<
"matrixC must be 2 dimensional vector";
212 auto [m, n, k] = mmaShape;
215 int64_t sparseFactor = sparse ? 2 : 1;
216 if (aShape[0] * aShape[1] *
kWarpSize != m * k / sparseFactor)
218 <<
"expected " << m * k <<
" warp-wide matrix A elements";
221 if (bShape[0] * bShape[1] *
kWarpSize != k * n)
223 <<
"expected " << k * n <<
" warp-wide matrix B elements";
226 if (cShape[0] * cShape[1] *
kWarpSize != m * n)
228 <<
"expected " << m * n <<
" warp-wide matrix C elements";
231 if (tf32Enabled && !(aType.
isF32()))
233 <<
"expected tf32 tensor cores only for F32 operands";
245 if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
246 (aShape[1] != numElementA))
247 return op->
emitOpError() <<
"expected matrix A to be shaped ("
248 << mTile * kTile <<
" x " << numElementA <<
")";
251 if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
252 return op->
emitOpError() <<
"expected matrix B to be shaped ("
253 << kTile * nTile <<
" x " << numElementB <<
")";
256 if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
257 return op->
emitOpError() <<
"expected matrix C to be shaped ("
258 << mTile * nTile <<
" x " << numElementC <<
")";
263LogicalResult MmaSyncOp::verify() {
264 if (getMmaShape().size() != 3)
265 return emitOpError() <<
"mmaShape must have exactly 3 elements";
267 return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
268 getMatrixC(), getMmaShapeAsArray(),
269 getOperation()->hasAttr(getTf32EnabledAttrName()));
279 build(odsBuilder, odsState, matrixC.
getType(), matrixA, matrixB, matrixC,
283LogicalResult MmaSparseSyncOp::verify() {
284 unsigned sparsitySelector = getSparsitySelector();
285 if (sparsitySelector > 1)
286 return emitOpError() <<
"sparsity selector should be 0 or 1";
288 if (getMmaShape().size() != 3)
289 return emitOpError() <<
"mmaShape must have exactly 3 elements";
291 return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
292 getMatrixC(), getMmaShapeAsArray(),
293 getOperation()->hasAttr(getTf32EnabledAttrName()),
300LogicalResult LdMatrixOp::verify() {
302 auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().
getType());
305 auto resVector = llvm::cast<VectorType>(getRes().
getType());
309 Type resType = resVector.getElementType();
313 int64_t numElementsPer32b = 32 / elementBitWidth;
316 int64_t numTiles = getNumTiles();
319 bool isTranspose = getTranspose();
325 if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
327 <<
"expected nvgpu.ldmatrix srcMemref must have a memory space "
328 "attribute of IntegerAttr("
329 << NVGPUDialect::kSharedMemoryAddressSpace
330 <<
") or gpu::AddressSpaceAttr(Workgroup)";
331 if (elementBitWidth > 32)
332 return emitError() <<
"nvgpu.ldmatrix works for 32b or lower";
333 if (isTranspose && !(elementBitWidth == 16))
335 <<
"nvgpu.ldmatrix transpose works only at 16b granularity";
336 if (resShape.size() != 2) {
337 return emitError() <<
"results must be 2 dimensional vector";
339 if (!(resShape[1] == numElementsPer32b))
340 return emitError() <<
"expected vector register shape[1] = "
341 << numElementsPer32b;
342 if (!(resShape[0] == numTiles))
344 <<
"expected vector register shape[0] and numTiles to match";
355 case TensorMapSwizzleKind::SWIZZLE_32B:
357 case TensorMapSwizzleKind::SWIZZLE_64B:
359 case TensorMapSwizzleKind::SWIZZLE_128B:
367 Operation *op, TensorMapDescriptorType descType,
368 std::optional<MemRefType> memrefType = std::nullopt) {
369 MemRefType descMemref = descType.getTensor();
371 if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
372 return op->
emitError() <<
"Interleave options are not supported yet.";
375 if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
376 return op->
emitError() <<
"the tensor map descriptor has incorrect address "
377 "space, it must be shared memory address space.";
380 if (!descMemref.hasStaticShape())
381 return op->
emitError() <<
"the tensor map descriptor must be static shaped";
383 for (
auto dim : descMemref.getShape()) {
385 return op->
emitError() <<
"the tensor map descriptor must have "
386 "dimensions between 1 and "
390 if (descMemref.getRank() > 1 &&
391 descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
392 unsigned lastDimensionByte =
393 descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
395 if (lastDimensionByte != expectByte)
396 return op->
emitError() <<
"the tensormap descriptor must have last "
398 << expectByte <<
" bytes but it is "
399 << lastDimensionByte <<
" bytes";
403 if (!memrefType.has_value())
406 MemRefType dstMemref = memrefType.value();
409 if (descMemref.getElementType() != dstMemref.getElementType()) {
410 return op->
emitError() <<
"the element type of tensor map descriptor and "
411 "memref must be same";
414 if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
415 return op->
emitError() <<
"the destination memref has incorrect address "
416 "space, it must be shared memory address space.";
418 if (!dstMemref.hasStaticShape())
419 return op->
emitError() <<
"the destination memref must be static shaped";
421 if (dstMemref.getRank() != descMemref.getRank()) {
422 return op->
emitError() <<
"the shape of tensor map descriptor and "
423 "memref must have same rank";
425 if (!descMemref.getShape().equals(dstMemref.getShape())) {
426 return op->
emitError() <<
"memref and tensor map shapes mismatch "
427 << descMemref <<
" != " << dstMemref;
431 descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
433 return op->
emitError() <<
"the bytes in the last dimension of the tensor "
434 "map must be a multiple of 16";
439LogicalResult TmaAsyncLoadOp::verify() {
442 if (error.has_value())
443 return error.value();
447 <<
" coordinates are supported.";
450 size_t(getTensorMapDescriptor().
getType().getTensor().getRank())) {
451 return emitError() <<
"number of coordinates do not match with the rank of "
452 "tensor descriptor map.";
462LogicalResult TmaAsyncStoreOp::verify() {
465 if (error.has_value())
466 return error.value();
470 <<
" coordinates are supported.";
473 size_t(getTensorMapDescriptor().
getType().getTensor().getRank())) {
474 return emitError() <<
"number of coordinates do not match with the rank of "
475 "tensor descriptor map.";
481LogicalResult TmaCreateDescriptorOp::verify() {
484 <<
" coordinates are supported.";
487 std::optional<InFlightDiagnostic> error =
489 if (error.has_value())
490 return error.value();
499LogicalResult WarpgroupGenerateDescriptorOp::verify() {
500 std::optional<InFlightDiagnostic> error =
502 if (error.has_value())
503 return error.value();
505 if (getTensorMap().
getType().getSwizzle() !=
506 TensorMapSwizzleKind::SWIZZLE_128B) {
508 << stringifyTensorMapSwizzleKind(
509 TensorMapSwizzleKind::SWIZZLE_128B)
510 <<
" is supported for the time being";
513 if (getTensorMap().
getType().getInterleave() !=
514 TensorMapInterleaveKind::INTERLEAVE_NONE) {
516 << stringifyTensorMapInterleaveKind(
517 TensorMapInterleaveKind::INTERLEAVE_NONE)
518 <<
" is supported for the time being";
548 if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
549 isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
564 72, 80, 88, 96, 104, 112, 120, 128,
565 136, 144, 152, 160, 168, 176, 184, 192,
566 200, 208, 216, 224, 232, 240, 248, 256};
568 80, 96, 112, 128, 144, 160,
569 176, 192, 208, 224, 240, 256};
571 isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
572 if (llvm::is_contained(allowedN, sizeN))
576 if (llvm::is_contained(allowedNshort, sizeN))
581LogicalResult WarpgroupMmaOp::verify() {
582 if (getTransposeA() && !getTransposeB())
584 <<
"supports non-transpose A (Row Major) "
585 "and transpose B (Column Major) for the time being ";
586 MemRefType matrixA = getDescriptorA().
getType().getTensor();
587 MemRefType matrixB = getDescriptorB().
getType().getTensor();
588 VectorType matrixC = getMatrixC().
getType().getFragmented();
589 VectorType matrixD = getMatrixD().getType().getFragmented();
591 if (matrixC != matrixD)
592 return emitOpError() <<
"type of matrix C and matrix D must be the same";
594 if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
595 matrixC.getRank() != 2 || matrixD.getRank() != 2) {
597 <<
"has matrices A, B, C and D, they must be 2 dimensional";
600 if (matrixA.getShape()[1] != matrixB.getShape()[0])
601 return emitOpError() <<
"2nd dim matrix-A (" << matrixA.getShape()[1]
602 <<
")!= 1st dim matrix-B (" << matrixB.getShape()[0]
604 if (matrixA.getShape()[0] != matrixC.getShape()[0])
605 return emitOpError() <<
"1st dim matrix-A ( " << matrixA.getShape()[0]
606 <<
" )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
608 if (matrixB.getShape()[1] != matrixC.getShape()[1])
609 return emitOpError() <<
"2nd dim matrix-B ( " << matrixB.getShape()[1]
610 <<
" ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
614 matrixA.getElementType(),
615 matrixB.getElementType())))
617 <<
" += " << matrixA.getElementType() <<
" * "
618 << matrixB.getElementType()
619 <<
", it is not supported.";
622 return emitOpError() <<
"has input type " << matrixB <<
" n is set to "
623 << matrixB.getDimSize(1) <<
", it is not supported";
627 if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
628 !matrixA.getElementType().isBF16()) {
629 return emitOpError() <<
"hit a limitation: " << matrixC.getElementType()
630 <<
" += " << matrixA.getElementType() <<
" * "
631 << matrixB.getElementType()
632 <<
", it is not supported yet";
638LogicalResult WarpgroupMmaStoreOp::verify() {
639 MemRefType dstMemrefType = getDstMemref().getType();
640 VectorType vtype = getMatrixD().getType().getFragmented();
643 if (!vtype.getElementType().isF32()) {
645 <<
"hit a limitation: only f32 results for the time being";
647 if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
648 vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
649 return emitOpError() <<
"results [" << vtype <<
"][" << vtype.getDimSize(1)
650 <<
"] values. However, destination memref["
651 << dstMemrefType.getDimSize(0) <<
"]["
652 << dstMemrefType.getDimSize(1)
653 <<
"] does not have same size as results";
662LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
663 WarpgroupAccumulatorType accType = getMatrixC().getType();
664 int64_t sizeM = accType.getFragmented().getDimSize(0);
665 int64_t sizeN = accType.getFragmented().getDimSize(1);
666 Type elemType = accType.getFragmented().getElementType();
670 return emitOpError() <<
"has type " << accType.getFragmented()
671 <<
". It does not fit into warp-group "
672 "level (wgmma) matrix multiplication instruction "
673 "(or not supported yet)";
682LogicalResult RcpOp::verify() {
683 RcpRoundingModeAttr rounding = getRoundingAttr();
686 if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
687 return emitOpError() <<
"has a limitation. " << rounding
688 <<
" or non-ftz is not supported yet.";
697#define GET_ATTRDEF_CLASSES
698#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
700#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"
702#define GET_OP_CLASSES
703#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
705#define GET_TYPEDEF_CLASSES
706#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
LogicalResult isAllowedSizeM(int sizeM)
static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue< VectorType > matrixA, TypedValue< VectorType > matrixB, TypedValue< VectorType > matrixC, const std::array< int64_t, 3 > &mmaShape, bool tf32Enabled, bool sparse=false)
Performs verification for MmaSyncOp and MmaSparseSyncOp.
std::optional< InFlightDiagnostic > verifyTmaDescriptorWithMemref(Operation *op, TensorMapDescriptorType descType, std::optional< MemRefType > memrefType=std::nullopt)
LogicalResult isAllowedSizeN(int sizeN, Type typeA)
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB)
static unsigned getSwizzleBytes(TensorMapSwizzleKind kind)
constexpr unsigned kTMALastdimByte
The bytes in the last dimension of the tensor map must be a multiple of 16.
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
constexpr unsigned kMaxTMATensorDimension
Maximum TMA tile dimension (tensorRank) must be non-zero and less than or equal to the maximum supported rank.
constexpr unsigned kMaxTMADimension
Maximum TMA tile size (boxDim), which specifies number of elements to be traversed along each of the ...
Attributes are known-constant values of operations.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
This represents an operation in an abstracted form, suitable for use with the builder APIs.