#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}
bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
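// Illustrative examples of what the check above accepts: a memref in numeric
// address space 3 (e.g. memref<16x32xf16, 3>) and a memref carrying the GPU
// dialect attribute #gpu.address_space<workgroup> are both treated as NVGPU
// shared (workgroup) memory; any other or missing memory space is not.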
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
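// Worked example (illustrative): for an f16 destination, dstElements of 2, 4,
// or 8 gives 4, 8, or 16 copied bytes, the only access sizes cp.async supports;
// requesting bypassL1 additionally requires the full 16-byte copy, which is why
// the verifier suggests 16*8/bitwidth destination elements above.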
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
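// Builder usage sketch (illustrative; assumes an OpBuilder `b`, a Location
// `loc`, and values `a`, `bMat`, `c` of suitable per-thread vector types):
//   auto mma = b.create<nvgpu::MmaSyncOp>(
//       loc, a, bMat, c, /*mmaShape=*/ArrayRef<int64_t>{16, 8, 16},
//       /*tf32Enabled=*/false);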
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // The fundamental tensor core tile is 8-by-8-by-128b (F64 is the exception).
  int64_t shapeM = 8, shapeN = 8, shapeK;
  // Elements of A, B, and C held by each thread for one fundamental tile.
  int64_t numElementA, numElementB, numElementC{2};

  // Per-thread vector operands of the op.
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();
  Type aType = aVector.getElementType();

  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // Exception to the 8-by-8-by-128b fundamental tile.
    shapeK = 4;
    numElementA = numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b-wide K per fundamental tile
    numElementA = 32 / operandBitwidth; // 32b-wide A operand per thread
    numElementB = 32 / operandBitwidth; // 32b-wide B operand per thread
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }
  // Basic verification.
  if (aShape.size() != 2)
    return op->emitError() << "matrixA must be 2 dimensional vector";
  if (bShape.size() != 2)
    return op->emitError() << "matrixB must be 2 dimensional vector";
  if (cShape.size() != 2)
    return op->emitError() << "matrixC must be 2 dimensional vector";

  auto [m, n, k] = mmaShape;

  // Verify the warp-wide sizes of the A, B, and C fragments.
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // tf32 tensor cores are only meaningful for F32 operands.
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";
  // Extended verification: number of fundamental tensor core tiles.
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // Verify the shape of the A fragment.
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // Verify the shape of the B fragment.
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // Verify the shape of the C fragment.
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
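// Worked example (illustrative): with mmaShape = [16, 8, 16] and f16 operands,
// the warp (kWarpSize = 32) holds m*k = 256 A, k*n = 128 B, and m*n = 128 C
// elements, i.e. per-thread fragments vector<4x2xf16>, vector<2x2xf16>, and
// vector<2x2xf32>, which is exactly what the shape checks above accept:
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>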
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}
LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         /*sparse=*/true);
}
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from the source in shared memory.
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to result/destination vector registers.
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // Vector register shape, element type, and bitwidth.
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread.
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // Number of 8-by-8 tiles.
  int64_t numTiles = getNumTiles();

  // Transpose elements in vector registers at 16b granularity when true.
  bool isTranspose = getTranspose();

  // Verify types and shapes.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be 2 dimensional vector";
  }
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
  switch (kind) {
  case TensorMapSwizzleKind::SWIZZLE_32B:
    return 32;
  case TensorMapSwizzleKind::SWIZZLE_64B:
    return 64;
  case TensorMapSwizzleKind::SWIZZLE_128B:
    return 128;
  default:
    return 0;
  }
}
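// Note (illustrative): the returned byte counts are the swizzle atom widths;
// verifyTmaDescriptorWithMemref below compares them against the size in bytes
// of the descriptor's innermost dimension when a swizzle mode is requested.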
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();

  // Limitation: interleave is not supported yet.
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The descriptor's tensor must live in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Only static shapes are supported for the time being.
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension)
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension;
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No further verification if the target memref does not exist.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // The element types must match.
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }

  // The destination memref must live in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Only static shapes are supported for the time being.
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }
  // The bytes in the last dimension of the tensor map must be a multiple of 16.
  unsigned lastDimensionBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimensionBytes % kTMALastdimByte != 0)
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";

  return std::nullopt;
}
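// Worked example (illustrative): a descriptor over memref<64x64xf16, 3> with
// SWIZZLE_128B has a last dimension of 64 * 16 / 8 = 128 bytes, which matches
// the expected swizzle width and is a multiple of 16, so both byte checks pass.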
LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }

  return success();
}
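// Illustrative consequence of the checks above: a descriptor built over a
// rank-2 tensor (e.g. memref<128x64xf16, 3>) must be indexed with exactly two
// coordinates, and no operation may supply more than kMaxTMATensorDimension
// coordinates in total.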
LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }

  return success();
}
LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}
LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // Accepted accumulate/operand combinations for wgmma.mma_async.
  if (typeD.isF32() && typeA.isTF32() && typeB.isTF32())
    return success(); // f32 += tf32 * tf32
  if ((typeD.isF16() || typeD.isF32()) && typeA.isF16() && typeB.isF16())
    return success(); // f16/f32 += f16 * f16
  if (typeD.isF32() && typeA.isBF16() && typeB.isBF16())
    return success(); // f32 += bf16 * bf16
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF16() || typeD.isF32()))
    return success(); // f16/f32 += f8 * f8
  if (typeD.isInteger(32) && typeA.isInteger(8) && typeB.isInteger(8))
    return success(); // s32 += s8 * s8
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  // Floating-point operand types may use any N from the full list.
  if (typeA.isF16() || typeA.isBF16() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();
  // Integer operand types are restricted to the shorter list.
  if (llvm::is_contained(allowedNshort, sizeN))
    return success();
  return failure();
}
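// Illustrative check: sizeN = 136 is accepted for f16/bf16/tf32/f8 operands
// (it appears in allowedN) but would be rejected for integer operands, since
// allowedNshort skips 136.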
LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being ";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
                         << " ) != 1st dim matrix-C ( " << matrixC.getShape()[0]
                         << " )";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1]
                         << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
                         << " )";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check the N dimension of the wgmma instruction.
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType())))
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";

  // Currently, only f16/bf16 operands with f32 accumulation are supported.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }
  return success();
}
LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation: only f32 accumulator results are supported for now.
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have same size as results";
  }
  return success();
}
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (sizeM % kWgmmaSizeM != 0 || failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}
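// Worked example (illustrative, assuming kWgmmaSizeM is 64): an accumulator of
// type vector<128x128xf32> is accepted because 128 is a multiple of 64 and
// N = 128 is in the allowed-N list; the same shapes also satisfy the rank,
// shape, and f32 += f16 * f16 data-type checks in WarpgroupMmaOp::verify().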
LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only the approximate, flush-to-zero variant is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"