//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}

bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
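
// Illustrative note: assuming kSharedMemoryAddressSpace is 3 (its value in
// NVGPUDialect.h), both of the following memref memory spaces are recognized
// as shared memory by isSharedMemoryAddressSpace:
//   memref<16x32xf32, 3>                              (IntegerAttr form)
//   memref<16x32xf32, #gpu.address_space<workgroup>>  (gpu::AddressSpaceAttr)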

bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

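// Illustrative example (a sketch; the SSA names and memref shapes are made
// up). With 32-bit f32 elements, copying 4 elements per thread moves 16
// bytes, which satisfies the 4/8/16-byte size check below:
//
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 4
//       : memref<128x128xf32> to memref<48x32xf32, #gpu.address_space<workgroup>>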
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}

/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {
  // The verification for mma.sync covering various shapes and data types is
  // based on the fundamental tensor core shape.

  // "Fundamental" tensor core shapes:
  //   - For F32 (TF32), F16, S8, and S4 data types the fundamental tensor core
  //     operation is of shape 8-by-8-by-128b.
  //   - F64 is an exception and is of shape 8-by-8-by-256b.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // exception to the 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth; // 128b wide shapeK

    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4, i8, f16, bf16, tf32, f64) "
              "supported by "
           << op->getName();
  }

  //
  // Basic verification
  //

  if (aShape.size() != 2)
    return op->emitError() << "matrixA must be a 2-dimensional vector";

  if (bShape.size() != 2)
    return op->emitError() << "matrixB must be a 2-dimensional vector";

  if (cShape.size() != 2)
    return op->emitError() << "matrixC must be a 2-dimensional vector";

  auto [m, n, k] = mmaShape;

  // verify warp-wide size for vector a
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError() << "expected " << m * k / sparseFactor
                             << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // verify tf32 tensor cores are enabled only for the F32 datatype
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if ((aShape[0] != mTile * kTile / sparseFactor) ||
      (aShape[1] != numElementA))
    return op->emitOpError()
           << "expected matrix A to be shaped ("
           << mTile * kTile / sparseFactor << " x " << numElementA << ")";

  // verify shape of bVector
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // verify shape of cVector
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
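
// Worked example (illustrative): for f16 operands with mmaShape = [16, 8, 16],
// shapeK = 128/16 = 8 and numElementA = numElementB = 32/16 = 2, so the
// per-thread operands are expected to be vector<4x2xf16> (A), vector<2x2xf16>
// (B), and vector<2x2xf16> (C); e.g. 4 * 2 * kWarpSize(32) = 256 = 16 * 16
// warp-wide matrix A elements.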

LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSparseSyncOp
//===----------------------------------------------------------------------===//
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}

LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         true);
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
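// Illustrative example (a sketch; names and shapes are made up). Loading
// numTiles = 4 tiles of 16-bit elements yields 32/16 = 2 elements per 32b
// register, i.e. a vector<4x2xf16> result:
//
//   %frag = nvgpu.ldmatrix %shm[%i, %j] {numTiles = 4 : i32, transpose = false}
//       : memref<128x64xf16, #gpu.address_space<workgroup>> -> vector<4x2xf16>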
LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from source in shared memory
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  //
  // verification
  //

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "nvgpu.ldmatrix srcMemref must have a memory space attribute "
              "of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2)
    return emitError() << "results must be a 2-dimensional vector";
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
  switch (kind) {
  case TensorMapSwizzleKind::SWIZZLE_32B:
    return 32;
  case TensorMapSwizzleKind::SWIZZLE_64B:
    return 64;
  case TensorMapSwizzleKind::SWIZZLE_128B:
    return 128;
  default:
    return 0;
  }
}

std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The descriptor tensor must live in the shared memory address space.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has an incorrect "
                              "address space; it must be the shared memory "
                              "address space";
  }
  // Support only static shapes for the time being
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
    if (lastDimensionByte != expectByte)
      return op->emitError() << "the tensor map descriptor must have a last "
                                "dimension of "
                             << expectByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No verification if memref type is not provided
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check element type
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of the tensor map descriptor "
                              "and the memref must be the same";
  }

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has an incorrect "
                              "address space; it must be the shared memory "
                              "address space";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the tensor map descriptor and the memref must "
                              "have the same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  int lastDimBytes =
      descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
  if (lastDimBytes % kTMALastdimByte != 0) {
    return op->emitError() << "the bytes in the last dimension of the tensor "
                              "map must be a multiple of 16";
  }
  return std::nullopt;
}
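
// Worked example (illustrative): with SWIZZLE_128B and f16 (16-bit) elements,
// the innermost dimension of the descriptor's tensor must span
// getSwizzleBytes(SWIZZLE_128B) = 128 bytes, i.e. 128 * 8 / 16 = 64 elements.
// A descriptor built around memref<128x64xf16, #gpu.address_space<workgroup>>
// therefore passes the swizzle check, and its 128-byte last dimension is also
// a multiple of kTMALastdimByte.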

LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncStoreOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor";
  }

  return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " box dimensions are supported.";
  }

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_WarpgroupGenerateDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaOp
//===----------------------------------------------------------------------===//

LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // F32 += F16 * F16
  // F16 += F16 * F16
  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F32 += TF32 * TF32
  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
    return success();
  // S32 += S8 * S8
  if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32))
    return success();
  // S32 += B1 * B1
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
    return success();
  // F32 += BF16 * BF16
  // F16 += BF16 * BF16
  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F16 += F8 * F8
  // F32 += F8 * F8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();

  return failure();
}

LogicalResult isAllowedSizeM(int sizeM) {
  if (sizeM % kWgmmaSizeM)
    return failure();
  return success();
}
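
// For example (illustrative; assuming kWgmmaSizeM is 64, the M size of a
// single wgmma.mma_async instruction), sizeM values of 64, 128, or 192 are
// accepted, while 32 or 100 are rejected.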

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}

LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError() << "matrices A, B, C, and D must be 2-dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim of matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim of matrix-B ("
                         << matrixB.getShape()[0] << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim of matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim of matrix-C ("
                         << matrixC.getShape()[0] << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim of matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim of matrix-C ("
                         << matrixC.getShape()[1] << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, f16/bf16 are supported
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}
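
// Illustrative shapes (a sketch, not taken from the op documentation): a
// 64x16 f16 matrix A, a 16x128 f16 matrix B, and a 64x128 f32 accumulator
// satisfy the checks above: the inner dimensions match, f32 += f16 * f16 is
// accepted by isAllowedWGMMADataType, and N = 128 appears in the allowed-N
// list of isAllowedSizeN.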

LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype.getDimSize(0) << "]["
                         << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have the same size as results";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaInitAccumulatorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
  WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RcpOp
//===----------------------------------------------------------------------===//

LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only `rcp_approx` with `ftz` is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"