26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
33 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
35 void nvgpu::NVGPUDialect::initialize() {
37 #define GET_TYPEDEF_LIST
38 #include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
41 #define GET_ATTRDEF_LIST
42 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
46 #include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
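// Example (illustrative, assuming kSharedMemoryAddressSpace == 3): both
//   memref<16x32xf16, 3>
//   memref<16x32xf16, #gpu.address_space<workgroup>>
// are treated as shared-memory memrefs by the helpers above.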
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
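// Example (illustrative) use of the first overload from C++ code holding an
// OpBuilder; `builder`, `loc`, `matA`, `matB`, and `matC` are hypothetical
// names for the builder, location, and vector-typed operands of a warp-level
// m16n8k16 multiply-accumulate:
//   builder.create<nvgpu::MmaSyncOp>(loc, matA, matB, matC,
//                                    builder.getI64ArrayAttr({16, 8, 16}));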
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // The fundamental tensor core mma.sync tile is 8-by-8-by-128b, except F64.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on the data type

  // Elements of A, B, and C held per thread per fundamental tile.
  int64_t numElementA;    // set based on the data type
  int64_t numElementB;    // set based on the data type
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread).
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // Vector shapes and element type.
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // F64 is the exception to the fundamental tile size.
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b-wide shapeK
    numElementA = 32 / operandBitwidth; // 32b-wide operand A
    numElementB = 32 / operandBitwidth; // 32b-wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }
  //
  // Basic verification
  //

  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be 2 dimensional vector";
  }

  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be 2 dimensional vector";
  }

  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be 2 dimensional vector";
  }

  auto [m, n, k] = mmaShape;
  // Verify the warp-wide size of vector A.
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";

  // Verify the warp-wide size of vector B.
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // Verify the warp-wide size of vector C.
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // tf32 tensor cores are allowed only for F32 operands.
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";
  //
  // Extended verification
  //

  // Tiles of fundamental tensor core operations.
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // Verify the shape of vector A.
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // Verify the shape of vector B.
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // Verify the shape of vector C.
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         /*sparse=*/true);
}
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from the source in shared memory.
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to the result in vector registers.
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // Vector register shape, element type, and bitwidth.
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread.
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // Number of 8-by-8 tiles.
  int64_t numTiles = getNumTiles();

  // Transpose elements in vector registers at 16b granularity when true.
  bool isTranspose = getTranspose();

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be 2 dimensional vector";
  }
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
static std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, nvgpu::TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The tensor map descriptor must use the shared memory address space.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being.
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    if (lastDimensionByte != kMaxTMALastdimByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << kMaxTMALastdimByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No further verification if a memref type is not provided.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check element type.
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  return std::nullopt;
}
LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }
  return success();
}
LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }
  return success();
}
LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension)
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();
  return success();
}
LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}
  // f16 or f32 accumulation with fp8 (e4m3/e5m2) operands.
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}
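// Example (illustrative): N = 128 is accepted on either path above, while a
// value outside the tables (such as N = 100, which is not a multiple of 8)
// fails for every operand type.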
LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";

  // Check N.
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, only f32 accumulation with f16/bf16 inputs is handled.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}
LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation.
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have same size as results";
  }
  return success();
}
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}
LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only the approximate rounding mode with ftz is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"