26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
33 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
void nvgpu::NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}
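
// Note: the GET_TYPEDEF_LIST / GET_ATTRDEF_LIST / GET_OP_LIST includes above
// expand to the comma-separated lists of classes generated by TableGen, so a
// single initialize() registers every NVGPU type, attribute, and operation.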
bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}
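
// Both encodings of the workgroup (shared) memory space are recognized, e.g.
// (illustrative types, assuming kSharedMemoryAddressSpace == 3):
//   memref<16x16xf16, 3>                              // IntegerAttr form
//   memref<16x16xf16, #gpu.address_space<workgroup>>  // gpu::AddressSpaceAttr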
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // Fundamental tensor core mma.sync tile is 8-by-8-by-128b; f64 is the
  // exception (8-by-8-by-256b).
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on the element type

  // Per-thread element counts per fundamental tile.
  int64_t numElementA;    // set based on the element type
  int64_t numElementB;    // set based on the element type
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread).
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // Vector shapes and element type.
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // Exception to the 8-by-8-128b fundamental tile size.
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size.
    int64_t operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }

  //
  // Basic verification
  //

  auto [m, n, k] = mmaShape;

  // Verify warp-wide size for vector a.
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";

  // Verify warp-wide size for vector b.
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // Verify warp-wide size for vector c.
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // Verify tf32 tensor cores are enabled only for the F32 datatype.
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification (after basic verification is satisfied)
  //

  // Tiles of fundamental tensor core operations.
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // Verify the shape of aVector.
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // Verify the shape of bVector.
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // Verify the shape of cVector.
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}
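
// Sketch of a 16x8x16 f16 case that satisfies verifyMmaSyncOp (illustrative):
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// Warp-wide element counts: A 4*2*32 = 256 = 16*16, B 2*2*32 = 128 = 16*8,
// and C 2*2*32 = 128 = 16*8, matching the checks above.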
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}

LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         /*sparse=*/true);
}
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from the source in shared memory.
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to the result/destination in vector registers.
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // Vector register shape, element type, and bitwidth.
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread.
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // Number of 8-by-8 tiles.
  int64_t numTiles = getNumTiles();

  // Transpose elements in vector registers at 16b granularity when true.
  bool isTranspose = getTranspose();

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be 2 dimensional vector";
  }
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}
static std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, nvgpu::TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();

  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // Address space check: the descriptor tensor must live in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being.
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  // Check the last dimension of a 2D+ swizzled tensor.
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    if (lastDimensionByte != kMaxTMALastdimByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << kMaxTMALastdimByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No further verification if a memref type is not provided.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check the element type.
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }

  // Address space check: the destination memref must live in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being.
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  return std::nullopt;
}
LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }
  return success();
}
LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates do not match with the rank of "
                          "tensor descriptor map.";
  }
  return success();
}
LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension)
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();
  return success();
}
LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}
LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
    if (llvm::is_contained(allowedN, sizeN))
      return success();
  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}
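
// In other words (sketch of the rule encoded above): floating-point operand
// types accept any N that is a multiple of 8 up to 256, while i8/i1 operands
// are restricted to the shorter list.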
LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being ";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N.
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, only f16/bf16 inputs with f32 accumulation are supported.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}
LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have same size as results";
  }
  return success();
}
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}
LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only the approximate rounding mode with ftz is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"