#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
void nvgpu::NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}
bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
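// For example, both memref<4x8xf16, 3> and
// memref<4x8xf16, #gpu.address_space<workgroup>> are treated as shared memory
// by the helpers above (kSharedMemoryAddressSpace is the NVVM shared space, 3).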
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
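  // nvgpu.device_async_copy lowers to cp.async, which copies 4, 8, or 16 bytes
  // per thread; the requested element count times the element width must
  // therefore land on one of those sizes.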
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
void MmaSyncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                      Value matrixA, Value matrixB, Value matrixC,
                      ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}
void MmaSyncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                      Value matrixA, Value matrixB, Value matrixC,
                      ArrayRef<int64_t> mmaShape, bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult
verifyMmaSyncOp(Operation *op, TypedValue<VectorType> matrixA,
                TypedValue<VectorType> matrixB, TypedValue<VectorType> matrixC,
                const std::array<int64_t, 3> &mmaShape, bool tf32Enabled,
                bool sparse = false) {
  // The fundamental tensor core tile is 8-by-8-by-128b (F64 is the exception).
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK;      // set below from the element type
  int64_t numElementA; // per-thread A elements per fundamental tile
  int64_t numElementB; // per-thread B elements per fundamental tile
  int64_t numElementC{2};
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();
  Type aType = aVector.getElementType();
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // F64 uses an 8-by-8-by-256b fundamental tile.
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isF16() || aType.isBF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b-wide K
    numElementA = 32 / operandBitwidth; // 32b-wide operand A per thread
    numElementB = 32 / operandBitwidth; // 32b-wide operand B per thread
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by " << op->getName();
  }
  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be 2 dimensional vector";
  }
  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be 2 dimensional vector";
  }
  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be 2 dimensional vector";
  }
  auto [m, n, k] = mmaShape;

  // Verify the warp-wide element counts of A, B, and C.
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";
  // tf32 is only allowed with f32 operands.
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";
  // Number of fundamental tiles along M, N, and K.
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;
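  // Worked example: dense m16n8k16 with f16 gives mTile = 2, nTile = 1,
  // kTile = 2, so A must be vector<4x2xf16>, B vector<2x2xf16>, and C
  // vector<2x2xf16> (or vector<2x2xf32> with f32 accumulation).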
  // Verify the per-thread register tile shapes.
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError()
           << "expected matrix A to be shaped (" << mTile * kTile << " x "
           << numElementA << ")";
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError()
           << "expected matrix B to be shaped (" << kTile * nTile << " x "
           << numElementB << ")";
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError()
           << "expected matrix C to be shaped (" << mTile * nTile << " x "
           << numElementC << ")";

  return success();
}
LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
void MmaSparseSyncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            Value matrixA, Value matrixB, Value matrixC,
                            Value sparseMetadata, ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}
LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         /*sparse=*/true);
}
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads from shared memory and writes into vector registers.
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits per thread for each 8-by-8 tile.
  int64_t numElementsPer32b = 32 / elementBitWidth;
  int64_t numTiles = getNumTiles();
  bool isTranspose = getTranspose();
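  // Worked example: numTiles = 4 with f16 elements yields a result of
  // vector<4x2xf16>, i.e. one 32-bit register (two halves) per tile per thread.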
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be 2 dimensional vector";
  }
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
  switch (kind) {
  case TensorMapSwizzleKind::SWIZZLE_32B: return 32;
  case TensorMapSwizzleKind::SWIZZLE_64B: return 64;
  case TensorMapSwizzleKind::SWIZZLE_128B: return 128;
  default: return 0;
  }
}
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, nvgpu::TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The descriptor's tensor must live in shared memory and be static shaped.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";
  // Each box dimension must be in [1, kMaxTMADimension].
  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension;
    }
  }
  // With swizzling enabled, the last dimension must match the swizzle width.
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }
  // The remaining checks compare the descriptor against a concrete memref.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";
  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }
  // TMA requires the contiguous (last) dimension to be 16-byte aligned.
  int lastDimBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimBytes % 16 != 0) {
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";
  }
  return std::nullopt;
}
LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }
  return success();
}
LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }
  return success();
}
LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}
LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // Common combinations: f16 accumulating into f16/f32, bf16 and tf32
  // accumulating into f32, and FP8 (e4m3/e5m2) accumulating into f32.
  if (typeA.isF16() && typeB.isF16() && (typeD.isF16() || typeD.isF32()))
    return success();
  if (typeA.isBF16() && typeB.isBF16() && typeD.isF32())
    return success();
  if (typeA.isTF32() && typeB.isTF32() && typeD.isF32())
    return success();
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) && typeD.isF32())
    return success();
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isF16() || typeA.isBF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();
  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}
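// Note: the long allowedN list (every multiple of 8 up to 256) applies to the
// floating-point element types, while integer element types are restricted to
// the shorter allowedNshort list.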
LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being ";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();
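  // A and B arrive as wgmma descriptors over shared-memory tiles, while the C
  // and D accumulators are per-warpgroup register fragments, so the shape
  // checks below mix MemRefType (A, B) and VectorType (C, D).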
  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }
  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";
  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";

  // N must be one of the sizes wgmma supports for this element type.
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType())))
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  // Currently limited: reject cases where the accumulator is not f32 and the
  // operands are neither f16 nor bf16.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }
  return success();
}
LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have same size as results";
  }
  return success();
}
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();
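  // wgmma.mma_async tiles have a fixed M of 64 (kWgmmaSizeM); only N varies,
  // within the per-type lists enforced by isAllowedSizeN.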
  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}
LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently only the approximate, flush-to-zero variant is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"