MLIR  19.0.0git
NVGPUDialect.cpp
Go to the documentation of this file.
1 //===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the NVGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Verifier.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::nvgpu;
32 
33 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
34 
// Registers the dialect's types, attributes and operations. Each list is
// expanded from the TableGen-generated .inc file selected by the GET_*_LIST
// macro defined just before its #include.
35 void nvgpu::NVGPUDialect::initialize() {
36  addTypes<
37 #define GET_TYPEDEF_LIST
38 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"
39  >();
40  addAttributes<
41 #define GET_ATTRDEF_LIST
42 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
43  >();
44  addOperations<
45 #define GET_OP_LIST
46 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
47  >();
48 }
49 
50 bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
51  if (!memorySpace)
52  return false;
53  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
54  return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
55  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
56  return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
57  return false;
58 }
59 
60 bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
61  Attribute memorySpace = type.getMemorySpace();
62  return isSharedMemoryAddressSpace(memorySpace);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // NVGPU_DeviceAsyncCopyOp
67 //===----------------------------------------------------------------------===//
68 
// Verifier body for the device async copy op (NVGPU_DeviceAsyncCopyOp section).
// NOTE(review): this listing omits source line 69, presumably the
// `LogicalResult DeviceAsyncCopyOp::verify() {` signature -- confirm.
//
// Checks, in order:
//   - both memrefs have a unit-stride innermost dimension;
//   - the destination is in shared memory;
//   - the element types match;
//   - the index counts match the memref ranks;
//   - the copy size is 4, 8 or 16 bytes;
//   - bypassL1 implies a 16-byte copy.
70  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
71  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
72 
73  if (!isLastMemrefDimUnitStride(srcMemref))
74  return emitError("source memref most minor dim must have unit stride");
75  if (!isLastMemrefDimUnitStride(dstMemref))
76  return emitError("destination memref most minor dim must have unit stride");
77  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
78  return emitError()
79  << "destination memref must have a memory space attribute of "
80  "IntegerAttr("
81  << NVGPUDialect::kSharedMemoryAddressSpace
82  << ") or gpu::AddressSpaceAttr(Workgroup)";
83  if (dstMemref.getElementType() != srcMemref.getElementType())
84  return emitError("source and destination must have the same element type");
85  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
86  return emitOpError() << "expected " << srcMemref.getRank()
87  << " source indices, got " << getSrcIndices().size();
88  if (size_t(dstMemref.getRank()) != getDstIndices().size())
89  return emitOpError() << "expected " << dstMemref.getRank()
90  << " destination indices, got "
91  << getDstIndices().size();
// Total bytes moved = element bitwidth * dstElements / 8; only 4, 8 and 16
// bytes are accepted.
92  int64_t dstElements = getDstElements().getZExtValue();
93  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
94  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
95  unsigned dstWidth = dstMemref.getElementTypeBitWidth();
// NOTE(review): source line 96 is missing from this listing; it presumably
// declares `diag` (e.g. `InFlightDiagnostic diag = emitError();`) -- confirm.
97  diag << "Requested copy elements is " << dstElements << " with width "
98  << dstMemref.getElementTypeBitWidth()
99  << ". But copy elements could be one of ";
// Suggest the element counts that give 32-, 64- and 128-bit copies; a
// suggestion is skipped when the element is wider than that copy size.
100  if ((32 / dstWidth) > 0)
101  diag << (32 / dstWidth) << ", ";
102  if ((64 / dstWidth) > 0)
103  diag << (64 / dstWidth) << ", ";
104  if ((128 / dstWidth) > 0)
105  diag << (128 / dstWidth) << ".";
106  return diag;
107  }
// When bypassL1 is present and true, the copy must be exactly 16 bytes;
// `req` is the element count that would achieve that.
// NOTE(review): the message below misspells "satisfy" as "satify".
108  if (getBypassL1().has_value()) {
109  int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
110  if (getBypassL1().value() && sizeInBytes != 16) {
111  return emitOpError() << "bypassL1 does not satify alignment for "
112  << dstMemref << " with destination element "
113  << dstElements
114  << ". Unset bypassL1, or set "
115  "destination element to "
116  << req;
117  }
118  }
119  return success();
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // NVGPU_MmaSyncOp
124 //===----------------------------------------------------------------------===//
125 void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
126  ::mlir::OperationState &odsState, Value matrixA,
127  Value matrixB, Value matrixC, ArrayAttr mmaShape) {
128  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
129  mmaShape, UnitAttr());
130 }
131 
132 void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
133  ::mlir::OperationState &odsState, Value matrixA,
134  Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
135  bool tf32Enabled) {
136  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
137  odsBuilder.getI64ArrayAttr(mmaShape),
138  tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
139 }
140 
141 /// Performs verification for MmaSyncOp and MmaSparseSyncOp.
143  TypedValue<VectorType> matrixA,
144  TypedValue<VectorType> matrixB,
145  TypedValue<VectorType> matrixC,
146  const std::array<int64_t, 3> &mmaShape,
147  bool tf32Enabled, bool sparse = false) {
148 
149  // The verification for mma.sync covering various shapes and data types is
150  // based on the fundamental tensor core shape.
151 
152  // "Fundamental" tensor core shapes:
153  // - For F32 (TF32), F16, S8, and S4 data
154  // types the fundamental tensor core operation is of shape 8-by-8-by-128b.
155  // - F64 is an exception and is of shape 8-by-8-by-256b.
156  int64_t shapeM = 8;
157  int64_t shapeN = 8;
158  int64_t shapeK; // set based on data type (128b for all data types except F64)
159 
160  // Number of elements A, B, and C per thread per fundamental tensor core tile
161  int64_t numElementA; // set based on data type (32b except F64)
162  int64_t numElementB; // set based on data type (32b except F64)
163  int64_t numElementC{2}; // two accumulator elements per fundamental tile
164 
165  // nvgpu.mma.sync vector operands (per thread)
166  auto aVector = matrixA.getType();
167  auto bVector = matrixB.getType();
168  auto cVector = matrixC.getType();
169 
170  // vector shapes
171  ArrayRef<int64_t> aShape = aVector.getShape();
172  ArrayRef<int64_t> bShape = bVector.getShape();
173  ArrayRef<int64_t> cShape = cVector.getShape();
174 
175  // vector element type
176  Type aType = aVector.getElementType();
177 
178  // Certain data types are not allowed in sparse mode.
179  if (sparse && aType.isF64())
180  return op->emitError() << "f64 is not supported for sparse mode";
181 
182  if (aType.isF64()) {
183  // exception to 8-by-8-128b fundamental tensor core tile size
184  shapeK = 4;
185  numElementA = 1;
186  numElementB = 1;
187  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
188  aType.isInteger(8) || aType.isInteger(4)) {
189  // 8-by-8-128b fundamental tensor core tile size
190  int operandBitwidth = aType.getIntOrFloatBitWidth();
191  shapeK = 128 / operandBitwidth; // 128b wide shapeK
192 
193  numElementA = 32 / operandBitwidth; // 32b wide operand A
194  numElementB = 32 / operandBitwidth; // 32b wide operand B
195  } else {
196  return op->emitError()
197  << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
198  "supported by "
199  << op->getName();
200  }
201 
202  //
203  // Basic verification
204  //
205 
206  auto [m, n, k] = mmaShape;
207 
208  // verify warp-wide size for vector a
209  int64_t sparseFactor = sparse ? 2 : 1;
210  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
211  return op->emitOpError()
212  << "expected " << m * k << " warp-wide matrix A elements";
213 
214  // verify warp-wide size for vector b
215  if (bShape[0] * bShape[1] * kWarpSize != k * n)
216  return op->emitOpError()
217  << "expected " << k * n << " warp-wide matrix B elements";
218 
219  // verify warp-wide size for vector c
220  if (cShape[0] * cShape[1] * kWarpSize != m * n)
221  return op->emitOpError()
222  << "expected " << m * n << " warp-wide matrix C elements";
223 
224  // verify tf32 tensor cores are enabled for only F32 datatype
225  if (tf32Enabled && !(aType.isF32()))
226  return op->emitOpError()
227  << "expected tf32 tensor cores only for F32 operands";
228 
229  //
230  // Extended verification
231  //
232 
233  // tiles of fundamental tensor core operations
234  int64_t mTile = m / shapeM;
235  int64_t nTile = n / shapeN;
236  int64_t kTile = k / shapeK;
237 
238  // verify shape of aVector
239  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
240  (aShape[1] != numElementA))
241  return op->emitOpError() << "expected matrix A to be shaped ("
242  << mTile * kTile << " x " << numElementA << ")";
243 
244  // verify shape of bVector
245  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
246  return op->emitOpError() << "expected matrix B to be shaped ("
247  << kTile * nTile << " x " << numElementB << ")";
248 
249  // verify shape of cVector
250  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
251  return op->emitOpError() << "expected matrix C to be shaped ("
252  << mTile * nTile << " x " << numElementC << ")";
253 
254  return success();
255 }
256 
// MmaSyncOp verifier body. NOTE(review): the signature line (source line 257)
// is missing from this listing; presumably `LogicalResult MmaSyncOp::verify() {`.
// Delegates to verifyMmaSyncOp in dense mode (sparse defaults to false).
// TF32 counts as enabled when the tf32Enabled attribute is present on the op.
258  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
259  getMatrixC(), getMmaShapeAsArray(),
260  getOperation()->hasAttr(getTf32EnabledAttrName()));
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // NVGPU_MmaSparseSyncOp
265 //===----------------------------------------------------------------------===//
266 void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
267  ::mlir::OperationState &odsState, Value matrixA,
268  Value matrixB, Value matrixC, Value sparseMetadata,
269  ArrayRef<int64_t> mmaShape) {
270  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
271  sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
272 }
273 
// MmaSparseSyncOp verifier body. NOTE(review): the signature line (source
// line 274) is missing from this listing; presumably
// `LogicalResult MmaSparseSyncOp::verify() {`.
// Rejects sparsity selectors other than 0 and 1, then runs the shared
// mma.sync checks with `sparse` = true.
275  unsigned sparsitySelector = getSparsitySelector();
276  if (sparsitySelector > 1)
277  return emitOpError() << "sparsity selector should be 0 or 1";
278  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
279  getMatrixC(), getMmaShapeAsArray(),
280  getOperation()->hasAttr(getTf32EnabledAttrName()),
281  true);
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // NVGPU_LdMatrixOp
286 //===----------------------------------------------------------------------===//
// LdMatrixOp verifier body. NOTE(review): the signature line (source line 287)
// is missing from this listing; presumably `LogicalResult LdMatrixOp::verify() {`.
// Requires:
//   - a shared-memory source;
//   - an element width of at most 32 bits (exactly 16 bits when transposing);
//   - a 2-D result of shape [numTiles, 32 / elementBitWidth].
288 
289  // ldmatrix reads data from source in shared memory
290  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
291 
292  // ldmatrix writes data to result/destination in vector registers
293  auto resVector = llvm::cast<VectorType>(getRes().getType());
294 
295  // vector register shape, element type, and bitwidth
296  ArrayRef<int64_t> resShape = resVector.getShape();
297  Type resType = resVector.getElementType();
298  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();
299 
300  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
// NOTE(review): this is computed before the `elementBitWidth > 32` check below.
// For wider elements it evaluates to 0, but that check returns first.
301  int64_t numElementsPer32b = 32 / elementBitWidth;
302 
303  // number of 8-by-8 tiles
304  int64_t numTiles = getNumTiles();
305 
306  // transpose elements in vector registers at 16b granularity when true
307  bool isTranspose = getTranspose();
308 
309  //
310  // verification
311  //
312 
313  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
314  return emitError()
315  << "expected nvgpu.ldmatrix srcMemref must have a memory space "
316  "attribute of IntegerAttr("
317  << NVGPUDialect::kSharedMemoryAddressSpace
318  << ") or gpu::AddressSpaceAttr(Workgroup)";
319  if (elementBitWidth > 32)
320  return emitError() << "nvgpu.ldmatrix works for 32b or lower";
321  if (isTranspose && !(elementBitWidth == 16))
322  return emitError()
323  << "nvgpu.ldmatrix transpose works only at 16b granularity";
// The rank check precedes the resShape[0]/resShape[1] accesses below.
324  if (resShape.size() != 2) {
325  return emitError() << "results must be 2 dimensional vector";
326  }
327  if (!(resShape[1] == numElementsPer32b))
328  return emitError() << "expected vector register shape[1] = "
329  << numElementsPer32b;
330  if (!(resShape[0] == numTiles))
331  return emitError()
332  << "expected vector register shape[0] and numTiles to match";
333 
334  return success();
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // NVGPU_TmaAsyncLoadOp
339 //===----------------------------------------------------------------------===//
340 
341 std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
342  Operation *op, nvgpu::TensorMapDescriptorType descType,
343  std::optional<MemRefType> memrefType = std::nullopt) {
344  MemRefType descMemref = descType.getTensor();
345  // Limitation
346  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
347  return op->emitError() << "Interleave options are not supported yet.";
348 
349  // Address space check for shared memory check
350  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
351  return op->emitError() << "the tensor map descriptor has incorrect address "
352  "space, it must be shared memory address space.";
353  }
354  // Support only static shape for the time being
355  if (!descMemref.hasStaticShape())
356  return op->emitError() << "the tensor map descriptor must be static shaped";
357 
358  for (auto dim : descMemref.getShape()) {
359  if (dim <= 0 || dim > kMaxTMADimension) {
360  return op->emitError() << "the tensor map descriptor must have "
361  "dimensions between 1 and "
362  << kMaxTMADimension << " but it is " << dim;
363  }
364  }
365  if (descMemref.getRank() > 1 &&
366  descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
367  unsigned lastDimensionByte =
368  descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
369  if (lastDimensionByte != kMaxTMALastdimByte)
370  return op->emitError() << "the tensormap descriptor must have last "
371  "dimension of "
372  << kMaxTMALastdimByte << " bytes but it is "
373  << lastDimensionByte << " bytes";
374  }
375 
376  // No verification if memref type is not provided
377  if (!memrefType.has_value())
378  return std::nullopt;
379 
380  MemRefType dstMemref = memrefType.value();
381 
382  // Check element type
383  if (descMemref.getElementType() != dstMemref.getElementType()) {
384  return op->emitError() << "the element type of tensor map descriptor and "
385  "memref must be same";
386  }
387 
388  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
389  return op->emitError() << "the destination memref has incorrect address "
390  "space, it must be shared memory address space.";
391  }
392  if (!dstMemref.hasStaticShape())
393  return op->emitError() << "the destination memref must be static shaped";
394 
395  if (dstMemref.getRank() != descMemref.getRank()) {
396  return op->emitError() << "the shape of tensor map descriptor and "
397  "memref must have same rank";
398  }
399  if (!descMemref.getShape().equals(dstMemref.getShape())) {
400  return op->emitError() << "memref and tensor map shapes mismatch "
401  << descMemref << " != " << dstMemref;
402  }
403 
404  return std::nullopt;
405 }
406 
// TmaAsyncLoadOp verifier body. NOTE(review): the signature line (source line
// 407) is missing from this listing; presumably
// `LogicalResult TmaAsyncLoadOp::verify() {`.
// Validates the descriptor against the destination memref. Then requires at
// most kMaxTMATensorDimension coordinates, exactly one per rank of the
// descriptor's tensor.
408  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
409  *this, getTensorMapDescriptor().getType(), getDst().getType());
410  if (error.has_value())
411  return error.value();
412 
413  if (getCoordinates().size() > kMaxTMATensorDimension) {
414  return emitError() << "Maximum " << kMaxTMATensorDimension
415  << " coordinates are supported.";
416  }
417  if (getCoordinates().size() !=
418  size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
419  return emitError() << "number of coordinates do not match with the rank of "
420  "tensor descriptor map.";
421  }
422 
423  return success();
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // NVGPU_TmaAsyncStoreOp
428 //===----------------------------------------------------------------------===//
429 
// TmaAsyncStoreOp verifier body. NOTE(review): the signature line (source line
// 430) is missing from this listing; presumably
// `LogicalResult TmaAsyncStoreOp::verify() {`.
// Validates the descriptor against the *source* memref. Any diagnostics from
// verifyTmaDescriptorWithMemref still call it "the destination memref". Then
// applies the same coordinate-count checks as the load.
431  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
432  *this, getTensorMapDescriptor().getType(), getSrc().getType());
433  if (error.has_value())
434  return error.value();
435 
436  if (getCoordinates().size() > kMaxTMATensorDimension) {
437  return emitError() << "Maximum " << kMaxTMATensorDimension
438  << " coordinates are supported.";
439  }
440  if (getCoordinates().size() !=
441  size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
442  return emitError() << "number of coordinates do not match with the rank of "
443  "tensor descriptor map.";
444  }
445 
446  return success();
447 }
448 
// Verifier body for the op exposing getBoxDimensions()/getTensorMap(),
// presumably the TMA descriptor-creation op. NOTE(review): its signature line
// (source line 449) is missing from this listing -- confirm the owning op.
// Limits the number of box dimensions, then validates the descriptor type on
// its own (no memref to compare against).
// NOTE(review): the diagnostic says "coordinates" although the check is on
// box dimensions.
450  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
451  return emitError() << "Maximum " << kMaxTMATensorDimension
452  << " coordinates are supported.";
453  }
454 
455  std::optional<InFlightDiagnostic> error =
456  verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
457  if (error.has_value())
458  return error.value();
459 
460  return success();
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // NVGPU_WarpgroupGenerateDescriptorOp
465 //===----------------------------------------------------------------------===//
466 
// WarpgroupGenerateDescriptorOp verifier body. NOTE(review): the signature line
// (source line 467) is missing from this listing; presumably
// `LogicalResult WarpgroupGenerateDescriptorOp::verify() {`.
// Validates the tensor map descriptor, then requires SWIZZLE_128B and
// INTERLEAVE_NONE.
468  std::optional<InFlightDiagnostic> error =
469  verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
470  if (error.has_value())
471  return error.value();
472 
473  if (getTensorMap().getType().getSwizzle() !=
474  TensorMapSwizzleKind::SWIZZLE_128B) {
475  return emitError() << "supports only "
476  << stringifyTensorMapSwizzleKind(
477  TensorMapSwizzleKind::SWIZZLE_128B)
478  << " is supported for the time being";
479  }
480 
// NOTE(review): this branch appears unreachable. verifyTmaDescriptorWithMemref
// above already rejects any interleave other than INTERLEAVE_NONE.
481  if (getTensorMap().getType().getInterleave() !=
482  TensorMapInterleaveKind::INTERLEAVE_NONE) {
483  return emitError() << "supports only "
484  << stringifyTensorMapInterleaveKind(
485  TensorMapInterleaveKind::INTERLEAVE_NONE)
486  << " is supported for the time being";
487  }
488 
489  return success();
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // WarpgroupMmaOp
494 //===----------------------------------------------------------------------===//
495 
497  // F32 += F16 + F16
498  // F16 += F16 + F16
499  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
500  return success();
501  // F32 += TF32 + TF32
502  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
503  return success();
504  // s32 += i8 + i8
505  if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
506  return success();
507  // s32 += i1 + i1
508  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
509  return success();
510  // F32 += BF16 + BF16
511  // F16 += BF16 + BF16
512  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
513  return success();
514  // F16 += f8 + f8
515  // F32 += f8 + f8
516  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
517  (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
518  (typeD.isF32() || typeD.isF16()))
519  return success();
520 
521  return failure();
522 }
523 
525  if (sizeM % kWgmmaSizeM)
526  return failure();
527  return success();
528 }
529 
530 LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
531  SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
532  72, 80, 88, 96, 104, 112, 120, 128,
533  136, 144, 152, 160, 168, 176, 184, 192,
534  200, 208, 216, 224, 232, 240, 248, 256};
535  SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
536  80, 96, 112, 128, 144, 160,
537  176, 192, 208, 224, 240, 256};
538  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
539  typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
540  if (llvm::is_contained(allowedN, sizeN))
541  return success();
542 
543  if (typeA.isInteger(8) || typeA.isInteger(1))
544  if (llvm::is_contained(allowedNshort, sizeN))
545  return success();
546  return failure();
547 }
548 
// WarpgroupMmaOp verifier body. NOTE(review): the signature line (source line
// 549) is missing from this listing; presumably
// `LogicalResult WarpgroupMmaOp::verify() {`.
// Checks, in order:
//   - transpose flags;
//   - that C and D have the same type;
//   - 2-D operands;
//   - matching M/N/K dimensions across A, B and C;
//   - the element-type combination (isAllowedWGMMADataType);
//   - N (isAllowedSizeN, keyed on A's element type);
//   - an extra element-type limitation.
// NOTE(review): the condition rejects only (transposeA && !transposeB), while
// the message says only non-transposed A with transposed B is supported --
// confirm which is intended.
550  if (getTransposeA() && !getTransposeB())
551  return emitOpError()
552  << "supports non-transpose A (Row Major) "
553  "and transpose B (Column Major) for the time being ";
554  MemRefType matrixA = getDescriptorA().getType().getTensor();
555  MemRefType matrixB = getDescriptorB().getType().getTensor();
556  VectorType matrixC = getMatrixC().getType().getFragmented();
557  VectorType matrixD = getMatrixD().getType().getFragmented();
558 
559  if (matrixC != matrixD)
560  return emitOpError() << "type of matrix C and matrix D must be the same";
561 
// The rank check precedes the getShape()[0]/[1] accesses below.
562  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
563  matrixC.getRank() != 2 || matrixD.getRank() != 2) {
564  return emitOpError()
565  << "has matrices A, B, C and D, they must be 2 dimensional";
566  }
567 
// A is [M x K], B is [K x N], C is [M x N].
568  if (matrixA.getShape()[1] != matrixB.getShape()[0])
569  return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
570  << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
571  << " )";
572  if (matrixA.getShape()[0] != matrixC.getShape()[0])
573  return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
574  << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
575  << " )";
576  if (matrixB.getShape()[1] != matrixC.getShape()[1])
577  return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1]
578  << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
579  << " )";
580 
581  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
582  matrixA.getElementType(),
583  matrixB.getElementType())))
584  return emitOpError() << matrixC.getElementType()
585  << " += " << matrixA.getElementType() << " * "
586  << matrixB.getElementType()
587  << ", it is not supported.";
588  // Check N
589  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
590  return emitOpError() << "has input type " << matrixB << " n is set to "
591  << matrixB.getDimSize(1) << ", it is not supported";
592  }
593 
594  // Currently, f16/bf16 supported
// NOTE(review): this fires only when C is not f32 AND A is neither f16 nor
// bf16. Any combination with an f32 accumulator passes regardless of A --
// confirm this matches the intended limitation.
595  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
596  !matrixA.getElementType().isBF16()) {
597  return emitOpError() << "hit a limitation: " << matrixC.getElementType()
598  << " += " << matrixA.getElementType() << " * "
599  << matrixB.getElementType()
600  << ", it is not supported yet";
601  }
602 
603  return success();
604 }
605 
// Verifier body for the op storing a warp-group accumulator (reads
// getDstMemref()/getMatrixD()). NOTE(review): its signature line (source line
// 606) is missing from this listing; presumably
// `LogicalResult WarpgroupMmaStoreOp::verify() {` -- confirm.
// Requires f32 results, and the first two dimensions of the result fragment
// must equal those of the destination memref.
607  MemRefType dstMemrefType = getDstMemref().getType();
608  VectorType vtype = getMatrixD().getType().getFragmented();
609 
610  // Limitation
611  if (!vtype.getElementType().isF32()) {
612  return emitOpError()
613  << "hit a limitation: only f32 results for the time being";
614  }
// NOTE(review): no rank check precedes getDimSize(0)/getDimSize(1) here.
// Also, the message below prints the whole `vtype` in the first bracket where
// `vtype.getDimSize(0)` appears intended.
615  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
616  vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
617  return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
618  << "] values. However, destination memref["
619  << dstMemrefType.getDimSize(0) << "]["
620  << dstMemrefType.getDimSize(1)
621  << "] does not have same size as results";
622  }
623  return success();
624 }
625 
626 //===----------------------------------------------------------------------===//
627 // WarpgroupMmaInitAccumulatorOp
628 //===----------------------------------------------------------------------===//
629 
// WarpgroupMmaInitAccumulatorOp verifier body. NOTE(review): the signature
// line (source line 630) is missing from this listing; presumably
// `LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {`.
// Requires the accumulator fragment's [M, N] to be valid wgmma sizes.
// NOTE(review): sizeM/sizeN are int64_t, but isAllowedSizeM/isAllowedSizeN
// take `int` (implicit narrowing).
// NOTE(review): isAllowedSizeN is keyed here on the *accumulator* element
// type, while WarpgroupMmaOp passes A's element type. isAllowedSizeN accepts
// only float types and i8/i1, so an i32 accumulator can never pass.
631 
632  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
633  int64_t sizeM = accType.getFragmented().getDimSize(0);
634  int64_t sizeN = accType.getFragmented().getDimSize(1);
635  Type elemType = accType.getFragmented().getElementType();
636 
637  if (failed(isAllowedSizeM(sizeM)) ||
638  failed(isAllowedSizeN(sizeN, elemType))) {
639  return emitOpError() << "has type " << accType.getFragmented()
640  << ". It does not fit into warp-group "
641  "level (wgmma) matrix multiplication instruction "
642  "(or not supported yet)";
643  }
644  return success();
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // TableGen'd dialect, type, and op definitions
649 //===----------------------------------------------------------------------===//
650 
651 #define GET_ATTRDEF_CLASSES
652 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
653 
654 #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"
655 
656 #define GET_OP_CLASSES
657 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
658 
659 #define GET_TYPEDEF_CLASSES
660 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"
static std::string diag(const llvm::Value &value)
LogicalResult isAllowedSizeM(int sizeM)
static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue< VectorType > matrixA, TypedValue< VectorType > matrixB, TypedValue< VectorType > matrixC, const std::array< int64_t, 3 > &mmaShape, bool tf32Enabled, bool sparse=false)
Performs verification for MmaSyncOp and MmaSparseSyncOp.
LogicalResult isAllowedSizeN(int sizeN, Type typeA)
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB)
std::optional< InFlightDiagnostic > verifyTmaDescriptorWithMemref(Operation *op, nvgpu::TensorMapDescriptorType descType, std::optional< MemRefType > memrefType=std::nullopt)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
Definition: NVGPUDialect.h:27
constexpr int kWarpSize
Definition: NVGPUDialect.h:24
constexpr unsigned kMaxTMATensorDimension
Maximum TMA tile dimension (tensorRank) must be non-zero and less than or equal to the maximum suppor...
Definition: NVGPUDialect.h:31
constexpr unsigned kMaxTMADimension
Maximum TMA tile size (boxDim), which specifies number of elements to be traversed along each of the ...
Definition: NVGPUDialect.h:35
constexpr unsigned kMaxTMALastdimByte
Last dimension of 2D+ TMA must be 128 bytes.
Definition: NVGPUDialect.h:37
Attributes are known-constant values of operations.
Definition: Attributes.h:25
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
bool isTF32() const
Definition: Types.cpp:50
bool isFloat8E4M3FN() const
Definition: Types.cpp:38
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isFloat8E5M2() const
Definition: Types.cpp:37
bool isF16() const
Definition: Types.cpp:49
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isBF16() const
Definition: Types.cpp:48
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.