#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
void NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
  declarePromisedInterfaces<memref::IndexedAccessOpInterface, LdMatrixOp>();
  declarePromisedInterfaces<memref::IndexedMemCopyOpInterface,
                            DeviceAsyncCopyOp>(); // assumption: the op list is
                                                  // elided in the original
}
bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
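// Illustration (not from the original source): both memref types below count
// as shared memory under the check above, assuming
// kSharedMemoryAddressSpace == 3:
//   memref<16x32xf16, 3>                              (IntegerAttr form)
//   memref<16x32xf16, #gpu.address_space<workgroup>>  (gpu attribute form)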
bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
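// The verifier below enforces the operand constraints of the cp.async-style
// copy: unit-stride innermost dimensions, a shared-memory destination, and a
// copy size of 4, 8, or 16 bytes. Worked example (illustrative): copying
// 8 f16 elements moves 8 * 16 / 8 = 16 bytes, the only size allowed to
// bypass the L1 cache.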
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  // Bypassing L1 is only supported for 16-byte copies.
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
void MmaSyncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                      Value matrixA, Value matrixB, Value matrixC,
                      ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                      Value matrixA, Value matrixB, Value matrixC,
                      ArrayRef<int64_t> mmaShape, bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // Verification is based on the "fundamental" tensor core tile:
  // 8-by-8-by-128b for all data types, except f64 which uses 8-by-8-by-256b.
  int64_t shapeM = 8, shapeN = 8, shapeK;
  int64_t numElementA, numElementB, numElementC{2};

  // nvgpu.mma.sync vector operands (per thread) and their shapes.
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();
  Type aType = aVector.getElementType();

  // f64 is not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // f64 is the exception to the 8-by-8-by-128b fundamental tile.
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b-wide K
    numElementA = 32 / operandBitwidth; // 32b-wide operand A
    numElementB = 32 / operandBitwidth; // 32b-wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }
  //
  // Basic verification
  //

  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be a 2-dimensional vector";
  }
  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be a 2-dimensional vector";
  }
  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be a 2-dimensional vector";
  }
  auto [m, n, k] = mmaShape;

  // Verify the warp-wide size of vector A.
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";

  // Verify the warp-wide size of vector B.
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // Verify the warp-wide size of vector C.
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // tf32 tensor cores are allowed only for f32 operands.
  if (tf32Enabled && !aType.isF32())
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";
  //
  // Extended verification (tile-level permissibility)
  //

  // Tiles of fundamental tensor core operations.
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // Verify the shape of aVector.
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // Verify the shape of bVector.
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // Verify the shape of cVector.
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
LogicalResult MmaSyncOp::verify() {
  if (getMmaShape().size() != 3)
    return emitOpError() << "mmaShape must have exactly 3 elements";

  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
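// Example op accepted by the verifier above (illustrative MLIR, m16n8k16 f16):
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} :
//     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>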
// Note: the signature and trailing builder arguments are assumptions; the
// original elides them (the sparsity selector is assumed to default to 0).
void MmaSparseSyncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            Value matrixA, Value matrixB, Value matrixC,
                            Value sparseMetadata, ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}
LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";

  if (getMmaShape().size() != 3)
    return emitOpError() << "mmaShape must have exactly 3 elements";

  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         /*sparse=*/true);
}
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from the source in shared memory.
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to the result in vector registers.
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // Vector register shape, element type, and bitwidth.
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // Number of elements loaded per 32-bit register.
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // Number of 8-by-8 tiles.
  int64_t numTiles = getNumTiles();

  // Transpose elements in vector registers at 16b granularity, if set.
  bool isTranspose = getTranspose();

  //
  // Verification
  //

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be a 2-dimensional vector";
  }
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
  switch (kind) {
  case TensorMapSwizzleKind::SWIZZLE_32B:
    return 32;
  case TensorMapSwizzleKind::SWIZZLE_64B:
    return 64;
  case TensorMapSwizzleKind::SWIZZLE_128B:
    return 128;
  default:
    return 0;
  }
}
static std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The tensor map descriptor must live in shared memory address space.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shapes for the time being.
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (int64_t dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // Nothing more to verify if no memref type is provided.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check the element type.
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be the same";
  }

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the tensor map descriptor and memref must "
                              "have the same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  unsigned lastDimensionBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimensionBytes % kTMALastdimByte != 0)
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";

  return std::nullopt;
}
LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension)
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor descriptor map.";
  }

  return success();
}
LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension)
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor descriptor map.";
  }

  return success();
}
LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension)
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}
LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // f32 += tf32 * tf32
  if (typeD.isF32() && typeA.isTF32() && typeB.isTF32())
    return success();
  // f16 += f16 * f16 and f32 += f16 * f16
  if ((typeD.isF16() || typeD.isF32()) && typeA.isF16() && typeB.isF16())
    return success();
  // f32 += bf16 * bf16
  if (typeD.isF32() && typeA.isBF16() && typeB.isBF16())
    return success();
  // f16/f32 += fp8 * fp8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF16() || typeD.isF32()))
    return success();
  return failure();
}

LogicalResult isAllowedSizeM(int sizeM) {
  // M must be a multiple of the wgmma M size (kWgmmaSizeM).
  return success(sizeM % kWgmmaSizeM == 0);
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}
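// Note (illustrative): wgmma allows any multiple of 8 up to 256 for N with
// floating-point operands, but integer operands skip some sizes (e.g. 40, 56),
// hence the shorter allowedNshort list above.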
LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2-dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  // Check that the data types are jointly supported by wgmma.
  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";

  // Check that N is an allowed wgmma size.
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType())))
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";

  // Currently, f16/bf16 inputs (and f32 accumulation) are supported.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }
  return success();
}
LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype.getDimSize(0) << "]["
                         << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have the same size as results";
  }
  return success();
}
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}
LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only the approximate rounding mode with ftz is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}
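// Note (assumption about the lowering target): APPROX with ftz corresponds to
// the PTX rcp.approx.ftz.f32 instruction, which is why other rounding modes
// are rejected here for now.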
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
LogicalResult isAllowedSizeM(int sizeM)
static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue< VectorType > matrixA, TypedValue< VectorType > matrixB, TypedValue< VectorType > matrixC, const std::array< int64_t, 3 > &mmaShape, bool tf32Enabled, bool sparse=false)
Performs verification for MmaSyncOp and MmaSparseSyncOp.
std::optional< InFlightDiagnostic > verifyTmaDescriptorWithMemref(Operation *op, TensorMapDescriptorType descType, std::optional< MemRefType > memrefType=std::nullopt)
LogicalResult isAllowedSizeN(int sizeN, Type typeA)
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB)
static unsigned getSwizzleBytes(TensorMapSwizzleKind kind)
constexpr unsigned kTMALastdimByte
The bytes in the last dimension of the tensor map must be a multiple of 16.
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
constexpr unsigned kMaxTMATensorDimension
Maximum TMA tile dimension (tensorRank) must be non-zero and less than or equal to the maximum suppor...
constexpr unsigned kMaxTMADimension
Maximum TMA tile size (boxDim), which specifies number of elements to be traversed along each of the ...
Attributes are known-constant values of operations.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
This represents an operation in an abstracted form, suitable for use with the builder APIs.