1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 
22 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/Types.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/FormatVariadic.h"
36 #include "llvm/Support/NVPTXAddrSpace.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cassert>
39 #include <optional>
40 #include <string>
41 
42 using namespace mlir;
43 using namespace NVVM;
44 
45 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
46 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
47 
48 static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
49 
50 //===----------------------------------------------------------------------===//
51 // Verifier methods
52 //===----------------------------------------------------------------------===//
53 
54 // This verifier is shared among the following Ops:
55 // CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
56 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
57 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
58  bool isIm2Col,
59  size_t numIm2ColOffsets,
60  Location loc) {
61  if (tensorDims < 1 || tensorDims > 5)
62  return emitError(loc, "expects coordinates between 1 to 5 dimension");
63 
64  // For Im2Col mode, there are two constraints:
65  if (isIm2Col) {
66  // 1. Tensor must always be at least 3-d.
67  if (tensorDims < 3)
68  return emitError(
69  loc,
70  "to use im2col mode, the tensor has to be at least 3-dimensional");
71  // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
72  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
73  return emitError(
74  loc, "im2col offsets must be 2 less than number of coordinates");
75  }
76  return success();
77 }
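// Illustrative examples of the constraints checked above (not part of the
// original source): when any im2col offsets are given, they must number
// exactly (dims - 2), and im2col tensors must be at least 3-d.
//   dims = 4, im2col, offsets = 2  -> ok
//   dims = 2, im2col               -> error: tensor must be at least 3-d
//   dims = 5, im2col, offsets = 2  -> error: expects (5 - 2) = 3 offsets
//   dims = 6                       -> error: coordinates must be 1..5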
78 
79 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
80  TMAStoreMode mode = getMode();
81  // We lower through inline-ptx when getPredicate() is true.
82  // a) Only TILE mode is supported
83  // b) Cache-hint is not supported
84  if (getPredicate()) {
85  if (mode != TMAStoreMode::TILE)
86  return emitError("Inline-ptx lowering supported only for Tile mode.");
87  if (getL2CacheHint())
88  return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
89  }
90 
91  size_t dims = getCoordinates().size();
92  switch (mode) {
93  case TMAStoreMode::TILE:
94  return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
95  case TMAStoreMode::IM2COL:
96  return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
97  case TMAStoreMode::TILE_SCATTER4:
98  if (dims != 5)
99  return emitError("Scatter4 mode expects 5 coordinates");
100  }
101  return success();
102 }
103 
104 LogicalResult CpAsyncOp::verify() {
105  if (getModifier() != LoadCacheModifierKind::CG &&
106  getModifier() != LoadCacheModifierKind::CA)
107  return emitError("Only CG and CA cache modifiers are supported.");
108  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
109  return emitError("expected byte size to be either 4, 8 or 16.");
110  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
111  return emitError("CG cache modifier is only support for 16 bytes copy.");
112  return success();
113 }
114 
115 // These verification checks are shared across the TMA Load and Prefetch Ops.
116 static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
117  TMALoadMode mode, Location loc) {
118  if (tensorDims < 1 || tensorDims > 5)
119  return emitError(loc, "expects coordinates between 1 to 5 dimension");
120 
121  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
122  size_t expectedIm2colOff) -> LogicalResult {
123  if (isIm2col && (tensorDims < 3))
124  return emitError(loc)
125  << "to use " << stringifyEnum(mode)
126  << " mode, the tensor has to be at least 3-dimensional";
127 
128  if (numIm2colOff != expectedIm2colOff)
129  return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
130  << " (provided " << numIm2colOff << ")";
131 
132  return success();
133  };
134 
135  switch (mode) {
136  case TMALoadMode::TILE:
137  return checkTMALoadParams(mode, false, 0);
138  case TMALoadMode::IM2COL:
139  return checkTMALoadParams(mode, true, tensorDims - 2);
140  case TMALoadMode::IM2COL_W:
141  case TMALoadMode::IM2COL_W_128:
142  return checkTMALoadParams(mode, true, 2);
143  case TMALoadMode::TILE_GATHER4:
144  return (tensorDims == 5)
145  ? checkTMALoadParams(mode, false, 0)
146  : emitError(loc, "Gather4 mode expects 5 coordinates");
147  }
148  return success();
149 }
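// Summary of the per-mode expectations enforced above (for illustration only):
//   TILE          -> 0 im2col offsets
//   IM2COL        -> (dims - 2) offsets, tensor at least 3-d
//   IM2COL_W      -> exactly 2 offsets, tensor at least 3-d
//   IM2COL_W_128  -> exactly 2 offsets, tensor at least 3-d
//   TILE_GATHER4  -> exactly 5 coordinates, 0 offsets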
150 
151 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
152  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
153  getMode(), getLoc());
154 }
155 
156 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
157  TMALoadMode mode = getMode();
158  bool isCTAOnly = getIsCTAOnly();
159  if (getPredicate()) { // Inline-asm based lowering
160  if (isCTAOnly)
161  return emitError("Predicate is supported only for shared::cluster mode.");
162  if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
163  return emitError(
164  "Predicate is supported only for Tile and Im2col modes.");
165  } else { // Intrinsics-based lowering
166  NVVMMemorySpace expectedAS =
167  isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
168  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
169  .getAddressSpace();
170  if (AS != expectedAS)
171  return emitError()
172  << (isCTAOnly
173  ? "Shared::cta destination requires address-space 3."
174  : "Shared::cluster destination requires address-space 7.");
175  // Checks specific to shared::cta mode
176  if (isCTAOnly) {
177  if (getMulticastMask())
178  return emitError("Multicast is not supported with shared::cta mode.");
179  if (getGroup())
180  return emitError("CTAGroup is not supported with shared::cta mode.");
181  }
182  }
183 
184  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
185  getMode(), getLoc());
186 }
187 
188 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
189  TMAStoreMode mode = getMode();
190  size_t dims = getCoordinates().size();
191  switch (mode) {
192  case TMAStoreMode::TILE:
193  return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
194  case TMAStoreMode::IM2COL:
195  return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
196  case TMAStoreMode::TILE_SCATTER4:
197  return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
198  }
199  return success();
200 }
201 
202 LogicalResult ConvertFloatToTF32Op::verify() {
203  using RndMode = NVVM::FPRoundingMode;
204  switch (getRnd()) {
205  case RndMode::RNA:
206  if (getRelu())
207  return emitError("Relu not supported with rna rounding mode.");
208  break;
209  case RndMode::RN:
210  case RndMode::RZ:
211  break;
212  default:
213  return emitError(
214  "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
215  }
216  return success();
217 }
218 
219 LogicalResult ConvertF32x2ToF6x2Op::verify() {
220  mlir::MLIRContext *ctx = getContext();
221 
222  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
223  return emitOpError("Only ")
224  << mlir::Float6E2M3FNType::get(ctx) << " and "
225  << mlir::Float6E3M2FNType::get(ctx)
226  << " types are supported for conversions from f32x2 to f6x2.";
227  }
228  return success();
229 }
230 
231 LogicalResult ConvertF32x2ToF8x2Op::verify() {
232  using RndMode = NVVM::FPRoundingMode;
233  using SatMode = NVVM::SaturationMode;
234 
235  bool isRoundingModeRN = getRnd() == RndMode::RN;
236  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
237  bool isRoundingModeRP = getRnd() == RndMode::RP;
238  bool isSatFinite = getSat() == SatMode::SATFINITE;
239 
240  bool hasRelu = getRelu();
241 
242  mlir::MLIRContext *ctx = getContext();
243 
244  return llvm::TypeSwitch<Type, LogicalResult>(getDstTy())
245  .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
246  [&](mlir::Type) -> LogicalResult {
247  if (!isRoundingModeRN) {
248  return emitOpError("Only RN rounding mode is supported for "
249  "conversions from f32x2 to ")
250  << mlir::Float8E4M3FNType::get(ctx) << " and "
251  << mlir::Float8E5M2Type::get(ctx) << " types";
252  }
253  if (!isSatFinite) {
254  return emitOpError("Only SATFINITE saturation mode is supported "
255  "for conversions "
256  "from f32x2 to ")
257  << mlir::Float8E4M3FNType::get(ctx) << " and "
258  << mlir::Float8E5M2Type::get(ctx) << " types";
259  }
260  return success();
261  })
262  .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
263  if (!(isRoundingModeRZ || isRoundingModeRP)) {
264  return emitOpError("Only RZ and RP rounding modes are supported for "
265  "conversions from f32x2 to ")
266  << mlir::Float8E8M0FNUType::get(ctx) << " type";
267  }
268  if (hasRelu) {
269  return emitOpError("relu not supported for conversions to ")
270  << mlir::Float8E8M0FNUType::get(ctx) << " type";
271  }
272  return success();
273  })
274  .Default([&](mlir::Type) {
275  return emitOpError("Only ")
276  << mlir::Float8E4M3FNType::get(ctx) << ", "
277  << mlir::Float8E5M2Type::get(ctx) << ", and "
278  << mlir::Float8E8M0FNUType::get(ctx)
279  << " types are "
280  "supported for conversions from f32x2 to f8x2";
281  });
282 }
283 
284 LogicalResult ConvertF16x2ToF8x2Op::verify() {
285  mlir::MLIRContext *ctx = getContext();
286 
287  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
288  return emitOpError("Only ")
289  << mlir::Float8E4M3FNType::get(ctx) << " and "
290  << mlir::Float8E5M2Type::get(ctx)
291  << " types are supported for conversions from f16x2 to f8x2.";
292  }
293  return success();
294 }
295 
296 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
297  using RndMode = NVVM::FPRoundingMode;
298 
299  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
300  return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
301  << " type is supported for conversions from "
302  "bf16x2 to f8x2.";
303 
304  auto rnd = getRnd();
305  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
306  return emitOpError("Only RZ and RP rounding modes are supported for "
307  "conversions from bf16x2 to f8x2.");
308 
309  return success();
310 }
311 
312 LogicalResult BulkStoreOp::verify() {
313  if (getInitVal() != 0)
314  return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
315  return success();
316 }
317 
318 LogicalResult PMEventOp::verify() {
319  auto eventId = getEventId();
320  auto maskedEventId = getMaskedEventId();
321  if (!maskedEventId && !eventId) {
322  return emitOpError() << "either `id` or `mask` must be set";
323  }
324 
325  if (maskedEventId && eventId) {
326  return emitOpError() << "`id` and `mask` cannot be set at the same time";
327  }
328 
329  if (eventId) {
330  if (eventId < 0 || eventId > 15) {
331  return emitOpError() << "`id` must be between 0 and 15";
332  }
333  }
334 
335  return llvm::success();
336 }
337 
338 // Given the element type of an operand and whether or not it is an accumulator,
339 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
340 // operand's element type.
341 std::optional<mlir::NVVM::MMATypes>
342 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
343  auto half2Type =
344  VectorType::get(2, Float16Type::get(operandElType.getContext()));
345  if (operandElType.isF64())
346  return NVVM::MMATypes::f64;
347  if (operandElType.isF16() || operandElType == half2Type)
348  return NVVM::MMATypes::f16;
349  if (operandElType.isF32() && isAccumulator)
350  return NVVM::MMATypes::f32;
351  if (operandElType.isF32() && !isAccumulator)
352  return NVVM::MMATypes::tf32;
353  if (llvm::isa<IntegerType>(operandElType)) {
354  if (isAccumulator)
355  return NVVM::MMATypes::s32;
356  return std::nullopt;
357  }
358 
359  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
360  if (structType.getBody().empty())
361  return std::nullopt;
362  return inferOperandMMAType(structType.getBody()[0], isAccumulator);
363  }
364 
365  return std::nullopt;
366 }
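// Illustrative mapping implied by the rules above (not part of the source):
//   f64                          -> f64
//   f16 or vector<2xf16>         -> f16
//   f32 as accumulator           -> f32
//   f32 as multiplicand          -> tf32
//   integer type as accumulator  -> s32
//   struct type                  -> inferred from its first element type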
367 
368 static bool isInt4PtxType(MMATypes type) {
369  return (type == MMATypes::u4 || type == MMATypes::s4);
370 }
371 
372 static bool isInt8PtxType(MMATypes type) {
373  return (type == MMATypes::u8 || type == MMATypes::s8);
374 }
375 
376 static bool isIntegerPtxType(MMATypes type) {
377  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
378  type == MMATypes::s32;
379 }
380 
381 MMATypes MmaOp::accumPtxType() {
382  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
383  getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
384  assert(val.has_value() && "accumulator PTX type should always be inferrable");
385  return val.value();
386 }
387 
388 MMATypes MmaOp::resultPtxType() {
389  std::optional<mlir::NVVM::MMATypes> val =
390  inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
391  assert(val.has_value() && "result PTX type should always be inferrable");
392  return val.value();
393 }
394 
395 void MmaOp::print(OpAsmPrinter &p) {
396  SmallVector<Type, 4> regTypes;
397  struct OperandFragment {
398  StringRef operandName;
399  StringRef ptxTypeAttr;
400  SmallVector<Value, 4> regs;
401  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
402  : operandName(name), ptxTypeAttr(ptxTypeName) {}
403  };
404 
405  std::array<OperandFragment, 3> frags{
406  OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
407  OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
408  OperandFragment("C", "")};
409  SmallVector<StringRef, 4> ignoreAttrNames{
410  mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
411 
412  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
413  auto &frag = frags[fragIdx];
414  auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
415  for (auto operandIdx = varOperandSpec.first;
416  operandIdx < varOperandSpec.first + varOperandSpec.second;
417  operandIdx++) {
418  frag.regs.push_back(this->getOperand(operandIdx));
419  if (operandIdx == 0) {
420  regTypes.push_back(this->getOperand(operandIdx).getType());
421  }
422  }
423  std::optional<MMATypes> inferredType =
424  inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
425  if (inferredType)
426  ignoreAttrNames.push_back(frag.ptxTypeAttr);
427  }
428 
429  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
430  p << " " << frag.operandName;
431  p << "[";
432  p.printOperands(frag.regs);
433  p << "] ";
434  };
435 
436  for (const auto &frag : frags) {
437  printMmaOperand(frag);
438  }
439 
440  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
441 
442  // Print the types of the operands and result.
443  p << " : "
444  << "(";
445  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
446  frags[1].regs[0].getType(),
447  frags[2].regs[0].getType()},
448  p);
449  p << ")";
450  p.printArrowTypeList(TypeRange{this->getRes().getType()});
451 }
452 
453 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
454  ValueRange operandA, ValueRange operandB, ValueRange operandC,
455  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
456  std::optional<MMAIntOverflow> intOverflow,
457  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
458  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
459 
460  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
461  MLIRContext *ctx = builder.getContext();
462  result.addAttribute(
463  "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
464 
465  result.addOperands(operandA);
466  result.addOperands(operandB);
467  result.addOperands(operandC);
468 
469  if (multiplicandPtxTypes) {
470  result.addAttribute("multiplicandAPtxType",
471  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
472  result.addAttribute("multiplicandBPtxType",
473  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
474  } else {
475  if (auto res = inferOperandMMAType(operandA[0].getType(), false))
476  result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
477  if (auto res = inferOperandMMAType(operandB[0].getType(), false))
478  result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
479  }
480 
481  if (multiplicandLayouts) {
482  result.addAttribute("layoutA",
483  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
484  result.addAttribute("layoutB",
485  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
486  } else {
487  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
488  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
489  }
490 
491  if (intOverflow.has_value())
492  result.addAttribute("intOverflowBehavior",
493  MMAIntOverflowAttr::get(ctx, *intOverflow));
494  if (b1Op.has_value())
495  result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
496 
497  result.addTypes(resultType);
498  result.addAttribute(
499  MmaOp::getOperandSegmentSizeAttr(),
500  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
501  static_cast<int32_t>(operandB.size()),
502  static_cast<int32_t>(operandC.size())}));
503 }
504 
505 // <operation> :=
506 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
507 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
508 // `->` type($res)
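// For illustration (one valid shape/type combination; operand names are
// placeholders, not from the source), an m16n8k16 f16 MMA in this custom
// format reads roughly as:
//   %d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 16>}
//          : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>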
509 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
510  struct OperandFragment {
511  std::optional<MMATypes> elemtype;
512  SmallVector<OpAsmParser::UnresolvedOperand> regs;
513  SmallVector<Type> regTypes;
514  };
515 
516  Builder &builder = parser.getBuilder();
517  std::array<OperandFragment, 4> frags;
518 
519  NamedAttrList namedAttributes;
520 
521  // A helper to parse the operand segments.
522  auto parseMmaOperand = [&](StringRef operandName,
523  OperandFragment &frag) -> LogicalResult {
524  if (parser.parseKeyword(operandName).failed())
525  return failure();
526  if (parser
527  .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
528  .failed())
529  return failure();
530  return success();
531  };
532 
533  // Parse the operand segments.
534  if (parseMmaOperand("A", frags[0]).failed())
535  return failure();
536  if (parseMmaOperand("B", frags[1]).failed())
537  return failure();
538  if (parseMmaOperand("C", frags[2]).failed())
539  return failure();
540 
541  if (parser.parseOptionalAttrDict(namedAttributes).failed())
542  return failure();
543 
544  // Parse the type specification and resolve operands.
545  SmallVector<Type, 3> operandTypes;
546  if (failed(parser.parseColon()))
547  return failure();
548  if (failed(parser.parseLParen()))
549  return failure();
550  if (failed(parser.parseTypeList(operandTypes)))
551  return failure();
552  if (failed(parser.parseRParen()))
   return failure();
553  if (operandTypes.size() != 3)
554  return parser.emitError(
555  parser.getNameLoc(),
556  "expected one type for each operand segment but got " +
557  Twine(operandTypes.size()) + " types");
558  for (const auto &iter : llvm::enumerate(operandTypes)) {
559  auto &frag = frags[iter.index()];
560  frag.regTypes.resize(frag.regs.size(), iter.value());
561  if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
562  parser.getNameLoc(), result.operands)))
563  return failure();
564  frag.elemtype = inferOperandMMAType(frag.regTypes[0],
565  /*isAccumulator*/ iter.index() < 2);
566  }
567 
568  Type resultType;
569  if (parser.parseArrow() || parser.parseType(resultType))
570  return failure();
571  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
572 
573  std::array<StringRef, 2> names{"multiplicandAPtxType",
574  "multiplicandBPtxType"};
575  for (unsigned idx = 0; idx < names.size(); idx++) {
576  const auto &frag = frags[idx];
577  std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
578  if (!frag.elemtype.has_value() && !attr.has_value()) {
579  return parser.emitError(
580  parser.getNameLoc(),
581  "attribute " + names[idx] +
582  " is not provided explicitly and cannot be inferred");
583  }
584  if (!attr.has_value())
585  result.addAttribute(
586  names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
587  }
588 
589  result.addTypes(resultType);
590  if (!namedAttributes.empty())
591  result.addAttributes(namedAttributes);
592  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
593  builder.getDenseI32ArrayAttr({
594  static_cast<int32_t>(frags[0].regs.size()),
595  static_cast<int32_t>(frags[1].regs.size()),
596  static_cast<int32_t>(frags[2].regs.size()),
597  }));
598  return success();
599 }
600 
601 LogicalResult MmaOp::verify() {
602  MLIRContext *context = getContext();
603  auto f16Ty = Float16Type::get(context);
604  auto i32Ty = IntegerType::get(context, 32);
605  auto f16x2Ty = VectorType::get(2, f16Ty);
606  auto f32Ty = Float32Type::get(context);
607  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
608  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
609 
610  auto s32x4StructTy =
611  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
612  auto f32x8StructTy =
613  LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
614  auto f16x2x2StructTy =
615  LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
616  auto f32x4StructTy =
617  LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
618  auto s32x2StructTy =
619  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
620 
621  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
622  getShapeAttr().getK()};
623 
624  // These variables define the set of allowed data types for matrices A, B, C,
625  // and result.
626  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
627  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
628  AllowedShapes allowedShapes;
629  AllowedTypes expectedA;
630  AllowedTypes expectedB;
631  AllowedTypes expectedC;
632  SmallVector<Type> expectedResult;
633 
634  // When M = 16, we just need to calculate the number of 8xk tiles, where
635  // k is a factor that depends on the data type.
636  if (mmaShape[0] == 16) {
637  int64_t kFactor;
638  Type multiplicandFragType;
639  switch (*getMultiplicandAPtxType()) {
640  case MMATypes::tf32:
641  kFactor = 4;
642  multiplicandFragType = i32Ty;
643  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
644  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
645  break;
646  case MMATypes::bf16:
647  kFactor = 8;
648  multiplicandFragType = i32Ty;
649  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
650  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
651  break;
652  case MMATypes::f16:
653  kFactor = 8;
654  multiplicandFragType = f16x2Ty;
655  expectedResult.push_back(f16x2x2StructTy);
656  expectedResult.push_back(f32x4StructTy);
657  break;
658  case MMATypes::s4:
659  case MMATypes::u4:
660  kFactor = 32;
661  break;
662  case MMATypes::b1:
663  kFactor = 128;
664  break;
665  case MMATypes::s8:
666  case MMATypes::u8:
667  kFactor = 16;
668  break;
669  default:
670  return emitError("invalid shape or multiplicand type: " +
671  stringifyEnum(getMultiplicandAPtxType().value()));
672  }
673 
674  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
675  expectedResult.push_back(s32x4StructTy);
676  expectedC.emplace_back(4, i32Ty);
677  multiplicandFragType = i32Ty;
678  } else {
679  expectedC.emplace_back(2, f16x2Ty);
680  expectedC.emplace_back(4, f32Ty);
681  }
682 
683  int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
684  int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
685  expectedA.emplace_back(unitA, multiplicandFragType);
686  expectedB.emplace_back(unitB, multiplicandFragType);
687  allowedShapes.push_back({16, 8, kFactor});
688  allowedShapes.push_back({16, 8, kFactor * 2});
689 
690  if (resultPtxType() != accumPtxType())
691  return emitOpError("ctype does not match dtype");
692  }
693 
694  // In the M=8 case, there is only 1 possible case per data type.
695  if (mmaShape[0] == 8) {
696  if (*getMultiplicandAPtxType() == MMATypes::f16) {
697  expectedA.emplace_back(2, f16x2Ty);
698  expectedB.emplace_back(2, f16x2Ty);
699  expectedResult.push_back(f16x2x4StructTy);
700  expectedResult.push_back(f32x8StructTy);
701  expectedC.emplace_back(4, f16x2Ty);
702  expectedC.emplace_back(8, f32Ty);
703  allowedShapes.push_back({8, 8, 4});
704  }
705  if (*getMultiplicandAPtxType() == MMATypes::f64) {
706  Type f64Ty = Float64Type::get(context);
707  expectedA.emplace_back(1, f64Ty);
708  expectedB.emplace_back(1, f64Ty);
709  expectedC.emplace_back(2, f64Ty);
710  expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
711  context, SmallVector<Type>(2, f64Ty)));
712  allowedShapes.push_back({8, 8, 4});
713  }
714  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
715  expectedA.push_back({i32Ty});
716  expectedB.push_back({i32Ty});
717  expectedC.push_back({i32Ty, i32Ty});
718  expectedResult.push_back(s32x2StructTy);
719  if (isInt4PtxType(getMultiplicandAPtxType().value()))
720  allowedShapes.push_back({8, 8, 32});
721  if (isInt8PtxType(getMultiplicandAPtxType().value()))
722  allowedShapes.push_back({8, 8, 16});
723  if (getMultiplicandAPtxType().value() == MMATypes::b1)
724  allowedShapes.push_back({8, 8, 128});
725  }
726  }
727 
728  std::string errorMessage;
729  llvm::raw_string_ostream errorStream(errorMessage);
730 
731  // Check that we matched an existing shape/dtype combination.
732  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
733  !llvm::is_contained(allowedShapes, mmaShape)) {
734  errorStream << "unimplemented variant for MMA shape <";
735  llvm::interleaveComma(mmaShape, errorStream);
736  errorStream << ">";
737  return emitOpError(errorMessage);
738  }
739 
740  // Verify the operand types for segments of A, B, and C operands.
741  std::array<StringRef, 3> operandNames{"A", "B", "C"};
742  for (const auto &iter : llvm::enumerate(
743  SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
744  auto spec = this->getODSOperandIndexAndLength(iter.index());
745  SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
746  operand_type_begin() + spec.first +
747  spec.second);
748  bool match = llvm::is_contained(iter.value(), operandTySeg);
749 
750  if (!match) {
751  errorStream << "Could not match types for the "
752  << operandNames[iter.index()]
753  << " operands; expected one of ";
754  for (const auto &x : iter.value()) {
755  errorStream << x.size() << "x" << x[0] << " ";
756  }
757  errorStream << "but got ";
758  llvm::interleaveComma(operandTySeg, errorStream);
759  return emitOpError(errorMessage);
760  }
761  }
762 
763  // Check the result type
764  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
765  return expectedResultType == getResult().getType();
766  })) {
767  errorStream
768  << "Could not match allowed types for the result; expected one of ";
769  llvm::interleaveComma(expectedResult, errorStream);
770  errorStream << " but got " << getResult().getType();
771  return emitOpError(errorMessage);
772  }
773 
774  // Ensure that binary MMA variants have a b1 MMA operation defined.
775  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
776  return emitOpError("op requires " + getB1OpAttrName().strref() +
777  " attribute");
778  }
779 
780  // Ensure int4/int8 MMA variants specify the accum overflow behavior
781  // attribute.
782  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
783  isInt8PtxType(*getMultiplicandAPtxType())) {
784  if (!getIntOverflowBehavior())
785  return emitOpError("op requires " +
786  getIntOverflowBehaviorAttrName().strref() +
787  " attribute");
788  }
789 
790  return success();
791 }
792 
793 LogicalResult ShflOp::verify() {
794  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
795  return success();
796  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
797  auto elementType = (type && type.getBody().size() == 2)
798  ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
799  : nullptr;
800  if (!elementType || elementType.getWidth() != 1)
801  return emitError("expected return type to be a two-element struct with "
802  "i1 as the second element");
803  return success();
804 }
805 
806 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
807  NVVM::MMAFrag frag, int nRow,
808  int nCol,
809  MLIRContext *context) {
810  unsigned numberElements = 0;
811  Type elementType;
812  OpBuilder builder(context);
813  Type f16x2 = VectorType::get(2, builder.getF16Type());
814  if (type == NVVM::MMATypes::f16) {
815  elementType = f16x2;
816  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
817  numberElements = 8;
818  else
819  numberElements = 4;
820  } else if (type == NVVM::MMATypes::f32) {
821  elementType = builder.getF32Type();
822  numberElements = 8;
823  } else if (type == NVVM::MMATypes::tf32) {
824  elementType = builder.getI32Type();
825  numberElements = 4;
826  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
827  elementType = builder.getI32Type();
828  int parallelSize = 0;
829  if (frag == NVVM::MMAFrag::a)
830  parallelSize = nRow;
831  if (frag == NVVM::MMAFrag::b)
832  parallelSize = nCol;
833 
834  // m == 16 && n == 16 && k == 16
835  if (parallelSize == 16)
836  numberElements = 2;
837  // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
838  else if (parallelSize == 8)
839  numberElements = 1;
840  else if (parallelSize == 32)
841  numberElements = 4;
842  } else if (type == NVVM::MMATypes::s32) {
843  elementType = builder.getI32Type();
844  numberElements = 8;
845  }
846  assert(numberElements != 0 && elementType != nullptr);
847  return std::make_pair(elementType, numberElements);
848 }
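// Illustrative results of the mapping above (not part of the source):
//   (f16, frag a/b)            -> 8 x vector<2xf16>
//   (f16, frag c)              -> 4 x vector<2xf16>
//   (f32, any frag)            -> 8 x f32
//   (tf32, any frag)           -> 4 x i32
//   (s8/u8, frag a, nRow = 16) -> 2 x i32
//   (s32, frag c)              -> 8 x i32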
849 
850 static std::pair<mlir::Type, unsigned>
851 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
852  int k, MLIRContext *context) {
853  int nRow, nCol;
854  if (frag == NVVM::MMAFrag::a) {
855  nRow = m;
856  nCol = k;
857  } else if (frag == NVVM::MMAFrag::b) {
858  nRow = k;
859  nCol = n;
860  } else {
861  nRow = m;
862  nCol = n;
863  }
864  assert(nRow && nCol);
865  return inferMMAType(type, frag, nRow, nCol, context);
866 }
867 
868 LogicalResult NVVM::WMMALoadOp::verify() {
869  unsigned addressSpace =
870  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
871  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
872  addressSpace != NVVMMemorySpace::Shared)
873  return emitOpError("expected source pointer in memory "
874  "space 0, 1, 3");
875 
876  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
877  getEltype(), getFrag()) == 0)
878  return emitOpError() << "invalid attribute combination";
879  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
880  getEltype(), getFrag(), getM(), getN(), getK(), getContext());
881  Type dstType = LLVM::LLVMStructType::getLiteral(
882  getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
883  if (getType() != dstType)
884  return emitOpError("expected destination type is a structure of ")
885  << typeInfo.second << " elements of type " << typeInfo.first;
886  return success();
887 }
888 
889 LogicalResult NVVM::WMMAStoreOp::verify() {
890  unsigned addressSpace =
891  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
892  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
893  addressSpace != NVVMMemorySpace::Shared)
894  return emitOpError("expected operands to be a source pointer in memory "
895  "space 0, 1, 3");
896 
897  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
898  getEltype()) == 0)
899  return emitOpError() << "invalid attribute combination";
900  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
901  getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
902  if (getArgs().size() != typeInfo.second)
903  return emitOpError() << "expected " << typeInfo.second << " data operands";
904  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
905  return operands.getType() != typeInfo.first;
906  }))
907  return emitOpError() << "expected data operands of type " << typeInfo.first;
908  return success();
909 }
910 
911 LogicalResult NVVM::WMMAMmaOp::verify() {
912  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
913  getLayoutB(), getEltypeA(),
914  getEltypeB()) == 0)
915  return emitOpError() << "invalid attribute combination";
916  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
917  getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
918  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
919  getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
920  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
921  getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
922  SmallVector<Type, 32> arguments;
923  arguments.append(typeInfoA.second, typeInfoA.first);
924  arguments.append(typeInfoB.second, typeInfoB.first);
925  arguments.append(typeInfoC.second, typeInfoC.first);
926  unsigned numArgs = arguments.size();
927  if (getArgs().size() != numArgs)
928  return emitOpError() << "expected " << numArgs << " arguments";
929  for (unsigned i = 0; i < numArgs; i++) {
930  if (getArgs()[i].getType() != arguments[i])
931  return emitOpError() << "expected argument " << i << " to be of type "
932  << arguments[i];
933  }
934  Type dstType = LLVM::LLVMStructType::getLiteral(
935  getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
936  if (getType() != dstType)
937  return emitOpError("expected destination type is a structure of ")
938  << typeInfoC.second << " elements of type " << typeInfoC.first;
939  return success();
940 }
941 
942 LogicalResult NVVM::LdMatrixOp::verify() {
943  uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
944  if (m == 8 && n == 8) {
945  if (num != 1 && num != 2 && num != 4) {
946  return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
947  "matrix");
948  }
949  if (getEltType() != LdStMatrixEltType::B16) {
950  return emitOpError("expected element type to be b16 for 8x8 matrix");
951  }
952  } else if (m == 8 && n == 16) {
953  if (num != 1 && num != 2 && num != 4) {
954  return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
955  "matrix");
956  }
957  if (getLayout() != MMALayout::row) {
958  return emitOpError("expected layout to be row for 8x16 matrix");
959  }
960  if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
961  getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
962  return emitOpError("expected element type to be b8x16.b4x16_p64 or "
963  "b8x16.b6x16_p32 for 8x16 matrix");
964  }
965  } else if (m == 16 && n == 16) {
966  if (num != 1 && num != 2) {
967  return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
968  "matrix");
969  }
970  if (getLayout() != MMALayout::col) {
971  return emitOpError("expected layout to be col for 16x16 matrix");
972  }
973  if (getEltType() != LdStMatrixEltType::B8 &&
974  getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
975  getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
976  return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
977  "b8x16.b6x16_p32 for 16x16 matrix");
978  }
979  } else {
980  return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
981  }
982 
983  Type i32 = IntegerType::get(getContext(), 32);
984  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
985  if (numElements == 1 && getType() != i32)
986  return emitOpError("expected destination type is i32");
987  if (numElements == 2 || numElements == 4) {
988  Type dstType = LLVM::LLVMStructType::getLiteral(
989  getContext(), SmallVector<Type>(numElements, i32));
990  if (getType() != dstType)
991  return emitOpError("expected destination type is a structure of ")
992  << numElements << " elements of type i32";
993  }
994 
995  return success();
996 }
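// Summary of the combinations accepted above (for illustration only):
//   8x8   : num in {1, 2, 4}, element type b16
//   8x16  : num in {1, 2, 4}, row layout, b8x16.b4x16_p64 or b8x16.b6x16_p32
//   16x16 : num in {1, 2},    col layout, b8, b8x16.b4x16_p64 or b8x16.b6x16_p32
// The result is i32 for a single register, otherwise a struct of i32 values
// (num registers for 8x8 and 8x16, 2 * num for 16x16).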
997 
998 LogicalResult NVVM::StMatrixOp::verify() {
999  int numMatrix = getSources().size();
1000  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
1001  return emitOpError("expected num attribute to be 1, 2 or 4");
1002 
1003  int m = getShape().getM(), n = getShape().getN();
1004  if (m == 8 && n == 8) {
1005  if (getEltType() != NVVM::LdStMatrixEltType::B16) {
1006  return emitOpError("expected element type to be B16 for 8x8 matrix");
1007  }
1008  } else if (m == 16 && n == 8) {
1009  if (getEltType() != NVVM::LdStMatrixEltType::B8) {
1010  return emitOpError("expected element type to be B8 for 16x8 matrix");
1011  }
1012  if (getLayout() != NVVM::MMALayout::col) {
1013  return emitOpError("expected layout to be col for 16x8 matrix");
1014  }
1015  } else {
1016  return emitOpError("expected shape to be 8x8 or 16x8");
1017  }
1018 
1019  return success();
1020 }
1021 
1022 static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
1023  if (typeA == NVVM::WGMMATypes::tf32)
1024  return 8;
1025  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
1026  return 16;
1027  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
1028  return 32;
1029  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
1030  return 32;
1031  if (typeA == NVVM::WGMMATypes::b1)
1032  return 256;
1033  return failure();
1034 }
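// For reference, the K sizes returned above (derived from the code, for
// illustration): tf32 -> 8, f16/bf16 -> 16, s8/u8/e4m3/e5m2 -> 32, b1 -> 256.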
1035 
1036 static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
1037  NVVM::WGMMATypes typeA,
1038  NVVM::WGMMATypes typeB) {
1039  switch (typeA) {
1040  case NVVM::WGMMATypes::f16:
1041  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1042  typeB == NVVM::WGMMATypes::f16)
1043  return success();
1044  break;
1045  case NVVM::WGMMATypes::tf32:
1046  if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
1047  return success();
1048  break;
1049  case NVVM::WGMMATypes::u8:
1050  case NVVM::WGMMATypes::s8:
1051  if (typeD == NVVM::WGMMATypes::s32 &&
1052  (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
1053  return success();
1054  break;
1055  case NVVM::WGMMATypes::b1:
1056  if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
1057  return success();
1058  break;
1059  case NVVM::WGMMATypes::bf16:
1060  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1061  typeB == NVVM::WGMMATypes::bf16)
1062  return success();
1063  break;
1064  case NVVM::WGMMATypes::e4m3:
1065  case NVVM::WGMMATypes::e5m2:
1066  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1067  (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
1068  return success();
1069  break;
1070  case WGMMATypes::f32:
1071  case WGMMATypes::s32:
1072  llvm_unreachable("unsupported input types");
1073  break;
1074  }
1075  return failure();
1076 }
1077 
1078 static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
1079  SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
1080  72, 80, 88, 96, 104, 112, 120, 128,
1081  136, 144, 152, 160, 168, 176, 184, 192,
1082  200, 208, 216, 224, 232, 240, 248, 256};
1083  SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
1084  80, 96, 112, 128, 144, 160,
1085  176, 192, 208, 224, 240, 256};
1086  switch (typeA) {
1087  case WGMMATypes::f16:
1088  case WGMMATypes::tf32:
1089  case WGMMATypes::bf16:
1090  case WGMMATypes::e4m3:
1091  case WGMMATypes::e5m2:
1092  if (llvm::is_contained(allowedN, sizeN))
1093  return success();
1094  break;
1095  case WGMMATypes::u8:
1096  case WGMMATypes::s8:
1097  case WGMMATypes::b1:
1098  if (llvm::is_contained(allowedNshort, sizeN))
1099  return success();
1100  break;
1101  case WGMMATypes::f32:
1102  case WGMMATypes::s32:
1103  llvm_unreachable("unsupported input types");
1104  break;
1105  }
1106  return failure();
1107 }
1108 
1109 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
1110  Value outValue = getResults();
1111  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
1112  if (!stype)
1113  return emitOpError() << "expected results to be struct";
1114  int outputSize = stype.getBody().size();
1115  WGMMATypes typeD = getTypeD();
1116  WGMMATypes typeA = getTypeA();
1117  WGMMATypes typeB = getTypeB();
1118 
1119  for (Type t : stype.getBody()) {
1120  if (t != stype.getBody().front())
1121  return emitOpError()
1122  << "all elements in struct must be same type but there is " << t;
1123  }
1124 
1125  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
1126  typeD != WGMMATypes::s32) {
1127  return emitOpError() << "does not support the given output type "
1128  << NVVM::stringifyWGMMATypes(typeD);
1129  }
1130  if (typeD == WGMMATypes::s32 &&
1131  (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
1132  return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
1133  }
1134 
1135  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
1136  return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
1137  << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
1138  << NVVM::stringifyWGMMATypes(typeB)
1139  << ", it is not supported.";
1140  }
1141 
1142  // Check M
1143  if (getShape().getM() != 64)
1144  return emitOpError() << "shape 'm' must be 64";
1145 
1146  // Check K
1147  FailureOr<int> allowedK = getAllowedSizeK(typeA);
1148  if (failed(allowedK) || allowedK.value() != getShape().getK())
1149  return emitOpError() << "shape 'k' must be " << allowedK.value()
1150  << " for input type "
1151  << NVVM::stringifyWGMMATypes(typeA);
1152 
1153  // Check N
1154  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
1155  return emitOpError() << "has input type "
1156  << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
1157  << getShape().getN() << ", it is not supported.";
1158  }
1159 
1160  // Check transpose (only available for f16/bf16)
1161  // Matrices A should be stored in row-major and B in column-major.
1162  // Only f16/bf16 matrices can be stored in either column-major or row-major
1163  // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
1164  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
1165  (getLayoutA() == mlir::NVVM::MMALayout::col ||
1166  getLayoutB() == mlir::NVVM::MMALayout::row)) {
1167  return emitOpError()
1168  << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
1169  << " and layout_b = " << stringifyMMALayout(getLayoutB())
1170  << " for input types " << stringifyWGMMATypes(typeA) << " and "
1171  << stringifyWGMMATypes(typeB)
1172  << " requires transpose. However, this is only supported for: "
1173  << stringifyMMATypes(MMATypes::f16) << " and "
1174  << stringifyMMATypes(MMATypes::bf16);
1175  }
1176 
1177  // Check result registers
1178  int expectedOutput = 0;
1179  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
1180  expectedOutput = getShape().getN() / 2;
1181  if (typeD == WGMMATypes::f16)
1182  expectedOutput = getShape().getN() / 4;
1183  if (outputSize != expectedOutput) {
1184  return emitOpError() << "results " << expectedOutput
1185  << ", however output struct has " << outputSize
1186  << " elements";
1187  }
1188  // Check satfinite (only available for s32 accumulator)
1189  if (typeD != WGMMATypes::s32 &&
1190  getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1191  NVVM::MMAIntOverflow::satfinite) {
1192  return emitOpError()
1193  << " `satfinite` can be only used with s32 accumulator, however "
1194  "the current accumulator is "
1195  << NVVM::stringifyWGMMATypes(typeD);
1196  }
1197 
1198  return success();
1199 }
1200 
1201 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1202 
1203  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1204  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1205 
1206  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1207 
1208  int expectedOutputRegisters = 0;
1209  if (getTypeD() == WGMMATypes::f16)
1210  expectedOutputRegisters = getShape().getN() / 4;
1211  else
1212  expectedOutputRegisters = getShape().getN() / 2;
1213 
1214  std::string ptx;
1215  llvm::raw_string_ostream ss(ptx);
1216 
1217  ss << "{\n"
1218  ".reg .pred p;\n"
1219  "setp.ne.b32 p, $"
1220  << ((expectedOutputRegisters * 2) + 2)
1221  << ", 0;\n"
1222  "wgmma.mma_async.sync.aligned.m"
1223  << m << "n" << n << "k" << k << "." << outputTypeName << "."
1224  << stringifyWGMMATypes(getTypeA()) << "."
1225  << stringifyWGMMATypes(getTypeB());
1226  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1227  NVVM::MMAIntOverflow::satfinite)
1228  ss << ".satfinite";
1229  ss << " {";
1230  int regCnt = 0;
1231  for (; regCnt < expectedOutputRegisters; ++regCnt) {
1232  ss << "$" << regCnt;
1233  if (regCnt != expectedOutputRegisters - 1)
1234  ss << ", ";
1235  }
1236 
1237  ss << "},";
1238  // Need to map read/write registers correctly.
1239  regCnt = (regCnt * 2);
1240  ss << " $" << (regCnt) << ","
1241  << " $" << (regCnt + 1) << ","
1242  << " p";
1243  if (getTypeD() != WGMMATypes::s32) {
1244  ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1245  }
1246  // Don't add transpose parameters unless needed.
1247  if (isF16) {
1248  ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1249  }
1250  ss << ";\n"
1251  << "}\n";
1252  return ptx;
1253 }
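// For illustration (not part of the source): for an m64n64k16 f32 += f16 * f16
// WGMMA, expectedOutputRegisters is 32 and the emitted string looks roughly
// like:
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $66, 0;
//   wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16
//     {$0, $1, ..., $31}, $64, $65, p, $67, $68, $69, $70;
//   }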
1254 
1255 bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
1256  RewriterBase &rewriter,
1257  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1258  &asmValues) {
1259  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1260  if (getResults())
1261  asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1262  if (getInouts())
1263  asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1264  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1265  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1266  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1267  mlir::NVVM::PTXRegisterMod::Read});
1268  if (getTypeD() != WGMMATypes::s32) {
1269  asmValues.push_back(
1270  {makeConstantI32(rewriter,
1271  getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1272  mlir::NVVM::PTXRegisterMod::Read});
1273  asmValues.push_back(
1274  {makeConstantI32(rewriter,
1275  getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1276  mlir::NVVM::PTXRegisterMod::Read});
1277  }
1278  if (isF16) {
1279  asmValues.push_back(
1280  {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1281  mlir::NVVM::PTXRegisterMod::Read});
1282  asmValues.push_back(
1283  {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1284  mlir::NVVM::PTXRegisterMod::Read});
1285  }
1286  return true; // Has manual mapping
1287 }
1288 
1289 LogicalResult NVVM::FenceProxyOp::verify() {
1290  if (getKind() == NVVM::ProxyKind::TENSORMAP)
1291  return emitOpError() << "tensormap proxy is not a supported proxy kind";
1292  if (getKind() == NVVM::ProxyKind::GENERIC)
1293  return emitOpError() << "generic proxy not a supported proxy kind";
1294  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1295  return emitOpError() << "async_shared fence requires space attribute";
1296  }
1297  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1298  return emitOpError() << "only async_shared fence can have space attribute";
1299  }
1300  return success();
1301 }
1302 
1303 LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1304  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1305  return emitOpError("uni-directional proxies only support generic for "
1306  "from_proxy attribute");
1307 
1308  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1309  return emitOpError("uni-directional proxies only support tensormap "
1310  "for to_proxy attribute");
1311 
1312  return success();
1313 }
1314 
1315 LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1316  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1317  return emitOpError("uni-directional proxies only support generic for "
1318  "from_proxy attribute");
1319 
1320  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1321  return emitOpError("uni-directional proxies only support tensormap "
1322  "for to_proxy attribute");
1323 
1324  return success();
1325 }
1326 
1327 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1328  if (getRegCount() % 8)
1329  return emitOpError("new register size must be multiple of 8");
1330  if (getRegCount() < 24 || getRegCount() > 256)
1331  return emitOpError("new register size must be in between 24 to 256");
1332  return success();
1333 }
1334 
1335 LogicalResult NVVM::BarrierOp::verify() {
1336  if (getNumberOfThreads() && !getBarrierId())
1337  return emitOpError(
1338  "barrier id is missing, it should be set between 0 to 15");
1339  return success();
1340 }
1341 
1342 LogicalResult NVVM::Tcgen05CpOp::verify() {
1343  auto mc = getMulticast();
1344 
1345  using SH = Tcgen05CpShape;
1346  using MC = Tcgen05CpMulticast;
1347  switch (getShape()) {
1348  case SH::SHAPE_128x256b:
1349  case SH::SHAPE_128x128b:
1350  case SH::SHAPE_4x256b:
1351  if (mc != MC::NONE)
1352  return emitError("Invalid multicast type for tcgen05.cp Op");
1353  break;
1354  case SH::SHAPE_64x128b:
1355  if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1356  return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1357  "warpx2_02_13 for tcgen05.cp Op");
1358  break;
1359  case SH::SHAPE_32x128b:
1360  if (mc != MC::WARPX4)
1361  return emitError(
1362  "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1363  break;
1364  }
1365  return success();
1366 }
1367 
1368 LogicalResult NVVM::MatchSyncOp::verify() {
1369  if (getKind() == NVVM::MatchSyncKind::all) {
1370  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1371  if (!type || type.getBody().size() != 2 ||
1372  !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1373  return emitOpError("match.sync 'all' returns a two element struct with "
1374  "first element as i32 and second element as i1");
1375  }
1376  } else {
1377  if (!getType().isInteger(32)) {
1378  return emitOpError("match.sync 'any' returns an i32");
1379  }
1380  }
1381  return success();
1382 }
1383 
1384 LogicalResult NVVM::VoteSyncOp::verify() {
1385  if (getKind() == NVVM::VoteSyncKind::ballot) {
1386  if (!getType().isInteger(32)) {
1387  return emitOpError("vote.sync 'ballot' returns an i32");
1388  }
1389  } else {
1390  if (!getType().isInteger(1)) {
1391  return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1392  }
1393  }
1394  return success();
1395 }
1396 
1397 LogicalResult NVVM::PrefetchOp::verify() {
1398  using MemSpace = NVVM::NVVMMemorySpace;
1399  using CacheLevel = NVVM::PrefetchCacheLevel;
1400 
1401  unsigned addressSpace =
1402  llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1403  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1404  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
1405 
1406  if (getTensormap() && cacheLevel)
1407  return emitOpError("cannot specify both tensormap and cache level");
1408 
1409  if (getTensormap()) {
1410  if (addressSpace != MemSpace::Generic &&
1411  addressSpace != MemSpace::Constant) {
1412  return emitOpError(
1413  "prefetch tensormap requires a generic or constant pointer");
1414  }
1415 
1416  if (evictPriority) {
1417  return emitOpError(
1418  "prefetch tensormap does not support eviction priority");
1419  }
1420 
1421  if (getInParamSpace() && addressSpace != MemSpace::Generic) {
1422  return emitOpError(
1423  "in_param_space can only be specified for a generic pointer");
1424  }
1425 
1426  } else if (cacheLevel) {
1427  if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
1428  addressSpace != MemSpace::Local) {
1429  return emitOpError("prefetch to cache level requires a generic, global, "
1430  "or local pointer");
1431  }
1432 
1433  if (getUniform()) {
1434  if (*cacheLevel != CacheLevel::L1) {
1435  return emitOpError(
1436  "unsupported cache level, the only supported uniform "
1437  "cache level is L1");
1438  }
1439 
1440  if (addressSpace != MemSpace::Generic) {
1441  return emitOpError(
1442  "prefetch to uniform cache requires a generic pointer");
1443  }
1444  }
1445 
1446  if (evictPriority) {
1447  if (*cacheLevel != CacheLevel::L2)
1448  return emitOpError(
1449  "cache eviction priority supported only for cache level L2");
1450 
1451  if (addressSpace != MemSpace::Global)
1452  return emitOpError("cache eviction priority requires a global pointer");
1453 
1454  if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1455  *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1456  return emitOpError(
1457  "unsupported cache eviction priority, only evict_last and "
1458  "evict_normal are supported");
1459  }
1460 
1461  if (getPredicate())
1462  return emitOpError("predicate supported only on prefetch tensormap");
1463 
1464  } else {
1465  return emitOpError(
1466  "requires specification of either cache level or tensormap");
1467  }
1468 
1469  return success();
1470 }
1471 
1472 LogicalResult ClusterLaunchControlQueryCancelOp::verify() {
1473  switch (getQueryType()) {
1474  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
1475  if (!getType().isInteger(1))
1476  return emitOpError("is_canceled query type returns an i1");
1477  break;
1478  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
1479  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
1480  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
1481  if (!getType().isInteger(32)) {
1482  return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
1483  "get_first_cta_id_z query types return an i32");
1484  }
1485  break;
1486  }
1487  return success();
1488 }
1489 
1490 /// Packs the given `field` into the `result`.
1491 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
1492 static llvm::Value *
1493 packValInto64Bits(llvm::IRBuilderBase &builder,
1494  llvm::Value *result, // the `result` (unset bits are zero)
1495  llvm::Value *field, // `field` to pack into `result`
1496  unsigned sizeInBits, // Size of `field` in bits
1497  unsigned start) { // Starting bit within `result`
1498  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1499 
1500  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1501  if (mask != 0xffffffffu)
1502  field = builder.CreateAnd(field, builder.getInt32(mask));
1503 
1504  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1505  field = builder.CreateShl(field, start);
1506 
1507  return builder.CreateOr(result, field);
1508 }
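// Worked example (for illustration): packing a 14-bit field at start = 16
// masks the value with (1u << 14) - 1 = 0x3fff, zero-extends it to i64,
// shifts it left by 16, and ORs it into `result`, leaving all other bits
// of `result` untouched.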
1509 
1510 void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1511  LLVM::ModuleTranslation &mt,
1512  llvm::IRBuilderBase &builder) {
1513  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1514  llvm::Value *smemDesc = builder.getInt64(0);
1515 
1516  smemDesc = packValInto64Bits(builder, smemDesc,
1517  mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1518  smemDesc = packValInto64Bits(
1519  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1520  smemDesc = packValInto64Bits(
1521  builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1522 
1523  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1524  smemDesc = packValInto64Bits(builder, smemDesc,
1525  mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1526  smemDesc = packValInto64Bits(
1527  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1528  smemDesc = packValInto64Bits(builder, smemDesc,
1529  mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1530 
1531  mt.mapValue(thisOp.getRes()) = smemDesc;
1532 }
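// Descriptor bit layout produced above (for reference, derived from the
// pack calls):
//   [13:0]  start address       [29:16] leading-dim offset
//   [45:32] stride-dim offset   [48:46] constant 0b001
//   [51:49] base offset         [52]    leading-dim mode
//   [63:61] swizzle mode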
1533 
1534 //===----------------------------------------------------------------------===//
1535 // getIntrinsicID/getIntrinsicIDAndArgs methods
1536 //===----------------------------------------------------------------------===//
1537 
1538 #define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1539  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1540 
1541 #define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1542  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
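// For example, GET_CP_ASYNC_ID(ca, 4, hasCpSize) selects
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s when hasCpSize is true
// and llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4 otherwise.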
1543 
1545 CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1548 
1549  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1550  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1551  switch (cpAsyncOp.getSize()) {
1552  case 4:
1553  id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1554  break;
1555  case 8:
1556  id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1557  break;
1558  case 16:
1559  id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1560  ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1561  : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1562  break;
1563  default:
1564  llvm_unreachable("Invalid copy size in CpAsyncOp.");
1565  }
1566 
1567  // Fill the Intrinsic Args
1568  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1569  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1570  if (hasCpSize)
1571  args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1572 
1573  return id;
1574 }
1575 
1576 mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1577  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1578  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1580  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1581 
1582  // Fill the Intrinsic Args
1583  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1584  args.push_back(mt.lookupValue(thisOp.getSize()));
1585 
1586  mlir::Value cacheHint = thisOp.getL2CacheHint();
1587  const bool hasCacheHint = static_cast<bool>(cacheHint);
1588  llvm::Value *i64Unused =
1589  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1590  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1591  args.push_back(builder.getInt1(hasCacheHint));
1592 
1593  return {id, std::move(args)};
1594 }
1595 
1596 mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1597  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1598  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
1600 
1601  // Fill the Intrinsic Args: dst, mbar, src, size.
1602  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1603  args.push_back(mt.lookupValue(thisOp.getMbar()));
1604  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1605  args.push_back(mt.lookupValue(thisOp.getSize()));
1606 
1607  // Multicast mask, if available.
1608  mlir::Value multicastMask = thisOp.getMulticastMask();
1609  const bool hasMulticastMask = static_cast<bool>(multicastMask);
1610  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
1611  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
1612 
1613  // Cache hint, if available.
1614  mlir::Value cacheHint = thisOp.getL2CacheHint();
1615  const bool hasCacheHint = static_cast<bool>(cacheHint);
1616  llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
1617  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1618 
1619  // Flag arguments for multicast and cachehint.
1620  args.push_back(builder.getInt1(hasMulticastMask));
1621  args.push_back(builder.getInt1(hasCacheHint));
1622 
1623  llvm::Intrinsic::ID id =
1624  llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
1625 
1626  return {id, std::move(args)};
1627 }
1628 
1629 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1630  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1631  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
1633  llvm::Intrinsic::ID id =
1634  llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1635 
1636  // Fill the Intrinsic Args
1637  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1638  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1639  args.push_back(mt.lookupValue(thisOp.getSize()));
1640 
1641  mlir::Value cacheHint = thisOp.getL2CacheHint();
1642  const bool hasCacheHint = static_cast<bool>(cacheHint);
1643  llvm::Value *i64Unused =
1644  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1645  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1646  args.push_back(builder.getInt1(hasCacheHint));
1647 
1648  // Choose the bytemask variant
1649  if (mlir::Value byteMask = thisOp.getByteMask()) {
1650  args.push_back(mt.lookupValue(byteMask));
1651  id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1652  }
1653 
1654  return {id, std::move(args)};
1655 }
1656 
1657 bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
1658  RewriterBase &rewriter,
1659  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1660  &asmValues) {
1661  // Add all the operands, but not the attributes, to the asmValues list.
1662  // The attributes are only used to select the right intrinsic variants
1663  // during intrinsics-lowering, so they are ignored when generating inline PTX.
1664  for (auto val : getOperands())
1665  asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
1666 
1667  return false;
1668 }
1669 
1671 CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1672  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1673  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
1674  const bool isCTAOnly = thisOp.getIsCTAOnly();
1676 
1677  // Fill the Intrinsic Args
1678  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1679  args.push_back(mt.lookupValue(thisOp.getMbar()));
1680  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1681 
1682  // Coordinates and im2col-offsets
1683  for (mlir::Value v : thisOp.getCoordinates())
1684  args.push_back(mt.lookupValue(v));
1685  for (mlir::Value v : thisOp.getIm2colOffsets())
1686  args.push_back(mt.lookupValue(v));
1687 
1688  // MulticastMask, if available
1689  mlir::Value mcMask = thisOp.getMulticastMask();
1690  const bool hasMC = static_cast<bool>(mcMask);
1691  llvm::Value *i16Zero =
1692  llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
1693 
1694  // CacheHint, if available
1695  mlir::Value cacheHint = thisOp.getL2CacheHint();
1696  const bool hasCacheHint = static_cast<bool>(cacheHint);
1697  llvm::Value *i64Zero =
1698  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1699 
1700  // Flag argument CTAGroup
1701  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
1702  // Hence, the +1 to getGroup().
1703  const int32_t val =
1704  thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
1705  llvm::Value *cg =
1706  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
1707 
1708  if (!isCTAOnly) {
1709  // For shared::cluster, all the arguments that we build are applicable.
1710  args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
1711  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
1712  args.push_back(builder.getInt1(hasMC));
1713  args.push_back(builder.getInt1(hasCacheHint));
1714  args.push_back(cg);
1715  } else {
1716  // For shared::cta, only cache-hint is applicable.
1717  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
1718  args.push_back(builder.getInt1(hasCacheHint));
1719  }
1720 
1721  constexpr size_t numDims = 5; // 1D to 5D
1722  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
1723  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
1724  using TableTy = std::array<rowTy, numModes>;
1725  static constexpr TableTy IDTable{
1726  {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
1727  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
1728  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
1729  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
1730  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
1732  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
1733  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
1734  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
1736  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
1737  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
1738  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
1740  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
1741  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
1742  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
1744  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
1745 
1746  static constexpr TableTy IDTableCTA{
1747  {{notIntrinsic,
1748  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
1749  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
1750  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
1751  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
1752  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
1754  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
1755  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
1756  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
1758  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
1759  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
1760  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
1762  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
1763  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
1764  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
1766  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
1767 
1768  static_assert(
1769  (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
1770  (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
1771  "TMALoadModes must match number of rows in IDTable and IDTableCTA");
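// Both tables are indexed as [mode][dim]. For example, tile mode with two
// coordinates selects nvvm_cp_async_bulk_tensor_g2s_tile_2d, or the
// nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d variant when isCTAOnly is set.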
1772  size_t mode = static_cast<size_t>(thisOp.getMode());
1773  size_t dim = thisOp.getCoordinates().size();
1774  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
1775  assert(id != notIntrinsic &&
1776  "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
1777 
1778  return {id, std::move(args)};
1779 }
1780 
1781 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
1782  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1783  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
1785 
1786  // Fill the Intrinsic Args
1787  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1788 
1789  for (auto v : thisOp.getCoordinates())
1790  args.push_back(mt.lookupValue(v));
1791  for (auto v : thisOp.getIm2colOffsets())
1792  args.push_back(mt.lookupValue(v));
1793 
1794  mlir::Value cacheHint = thisOp.getL2CacheHint();
1795  const bool hasCacheHint = static_cast<bool>(cacheHint);
1796  llvm::Value *i64Unused =
1797  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1798  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1799  args.push_back(builder.getInt1(hasCacheHint));
1800 
1801  const unsigned NI = llvm::Intrinsic::not_intrinsic;
1802  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1803  {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
1804  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
1805  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
1806  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
1807  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
1808  {NI, NI, NI,
1809  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
1810  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
1811  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
1812  {NI, NI, NI,
1813  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
1814  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
1815  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
1816  {NI, NI, NI,
1817  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
1818  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
1819  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
1820  {NI, NI, NI, NI, NI,
1821  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
1822 
1823  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
1824  "TMALoadModes must match number of rows in IDTable");
1825  size_t mode = static_cast<size_t>(thisOp.getMode());
1826  size_t dim = thisOp.getCoordinates().size();
1827  llvm::Intrinsic::ID id = IDTable[mode][dim];
1828  if (id == llvm::Intrinsic::not_intrinsic)
1829  llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
1830 
1831  return {id, std::move(args)};
1832 }
1833 
1835 CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1836  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1837  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
1839 
1840  // Fill the Intrinsic Args
1841  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1842  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1843 
1844  for (auto v : thisOp.getCoordinates())
1845  args.push_back(mt.lookupValue(v));
1846 
1847  mlir::Value cacheHint = thisOp.getL2CacheHint();
1848  const bool hasCacheHint = static_cast<bool>(cacheHint);
1849  llvm::Value *i64Unused =
1850  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1851  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1852  args.push_back(builder.getInt1(hasCacheHint));
1853 
1854  const unsigned NI = llvm::Intrinsic::not_intrinsic;
1855  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1856  {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
1857  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
1858  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
1859  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
1860  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
1861  {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
1862  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
1863  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
1864  {NI, NI, NI, NI, NI,
1865  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
1866 
1867  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
1868  "TMAStoreModes must match number of rows in IDTable");
1869  size_t mode = static_cast<size_t>(thisOp.getMode());
1870  size_t dim = thisOp.getCoordinates().size();
1871  llvm::Intrinsic::ID id = IDTable[mode][dim];
1872  if (id == llvm::Intrinsic::not_intrinsic)
1873  llvm_unreachable(
1874  "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
1875 
1876  return {id, std::move(args)};
1877 }
1878 
1879 NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
1880  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1881  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
1882  llvm::LLVMContext &ctx = mt.getLLVMContext();
1883 
1885 
1886  // Arguments to the intrinsic:
1887  // shared_mem_ptr, tmaDesc, tensorDims,
1888  // cache_hint (if applicable) and a boolean flag
1889  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1890  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1891 
1892  for (Value v : thisOp.getCoordinates())
1893  args.push_back(mt.lookupValue(v));
1894 
1895  mlir::Value cacheHint = thisOp.getL2CacheHint();
1896  const bool hasCacheHint = static_cast<bool>(cacheHint);
1897  llvm::Value *i64ZeroValue =
1898  llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
1899  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
1900  args.push_back(builder.getInt1(hasCacheHint));
1901 
1902  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
1903 
1904  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
1905  constexpr unsigned numLayouts = 2; // TILE, IM2COL
1906  constexpr unsigned maxDim = 5; // 1D to 5D
1907  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
1908  using layoutTable = std::array<row, numLayouts>;
1909  using fullTable = std::array<layoutTable, numRedKinds>;
1910  static constexpr fullTable IDTable{
1911  {// RedTy::ADD
1912  {{{{notIntrinsic,
1913  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
1914  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
1915  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
1916  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
1917  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
1919  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
1920  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
1921  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
1922  // RedTy::MIN
1923  {{{{notIntrinsic,
1924  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
1925  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
1926  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
1927  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
1928  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
1930  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
1931  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
1932  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
1933  // RedTy::MAX
1934  {{{{notIntrinsic,
1935  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
1936  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
1937  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
1938  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
1939  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
1941  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
1942  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
1943  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
1944  // RedTy::INC
1945  {{{{notIntrinsic,
1946  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
1947  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
1948  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
1949  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
1950  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
1952  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
1953  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
1954  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
1955  // RedTy::DEC
1956  {{{{notIntrinsic,
1957  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
1958  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
1959  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
1960  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
1961  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
1963  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
1964  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
1965  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
1966  // RedTy::AND
1967  {{{{notIntrinsic,
1968  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
1969  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
1970  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
1971  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
1972  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
1974  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
1975  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
1976  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
1977  // RedTy::OR
1978  {{{{notIntrinsic,
1979  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
1980  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
1981  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
1982  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
1983  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
1985  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
1986  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
1987  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
1988  // RedTy::XOR
1989  {{{{notIntrinsic,
1990  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
1991  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
1992  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
1993  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
1994  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
1996  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
1997  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
1998  llvm::Intrinsic::
1999  nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
2000 
2001  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
2002  "TMAReduxKinds must match number of rows in IDTable");
2003 
2004  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
2005  size_t mode = static_cast<size_t>(thisOp.getMode());
2006  size_t dim = thisOp.getCoordinates().size();
2007 
2008  assert(redKind < IDTable.size() &&
2009  "Invalid redKind for CpAsyncBulkTensorReduceOp");
2010  assert(mode < IDTable[redKind].size() &&
2011  "Invalid mode for CpAsyncBulkTensorReduceOp");
2012  assert(dim < IDTable[redKind][mode].size() &&
2013  "Invalid dim for CpAsyncBulkTensorReduceOp");
2014 
2015  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
2016 
2017  assert(intrinsicID != notIntrinsic &&
2018  "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
2019 
2020  return {intrinsicID, std::move(args)};
2021 }
2022 
2023 #define _none
2024 
2025 #define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2026  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
2027  : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
2028 
2029 #define GET_CVT_F2TF32_ID(rnd, relu, sf) \
2030  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2031  : CVT_F2TF32_ID_IMPL(rnd, relu, )
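// For example, GET_CVT_F2TF32_ID(rn, _relu, _satfinite) selects
// llvm::Intrinsic::nvvm_f2tf32_rn_relu when hasRelu is true and hasSatFinite
// is false, and llvm::Intrinsic::nvvm_f2tf32_rn_satfinite when hasSatFinite
// is true and hasRelu is false.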
2032 
2034 ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2035  NVVM::SaturationMode sat, bool hasRelu) {
2036  using RndMode = NVVM::FPRoundingMode;
2037  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2038  switch (rnd) {
2039  case RndMode::RN:
2040  return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
2041  case RndMode::RZ:
2042  return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
2043  case RndMode::RNA:
2044  return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
2045  default:
2046  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
2047  }
2048 }
2049 
2050 #define GET_F32x2_TO_F6x2_ID(type, has_relu) \
2051  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
2052  : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
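// For example, GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu) selects
// llvm::Intrinsic::nvvm_ff_to_e2m3x2_rn_relu_satfinite when hasRelu is true
// and llvm::Intrinsic::nvvm_ff_to_e2m3x2_rn_satfinite otherwise.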
2053 
2054 llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
2055  bool hasRelu) {
2057  .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2058  return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
2059  })
2060  .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2061  return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
2062  })
2063  .Default([](mlir::Type) {
2064  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
2065  return llvm::Intrinsic::not_intrinsic;
2066  });
2067 }
2068 
2069 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
2070  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
2071  : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
2072 
2073 #define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
2074  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
2075  : llvm::Intrinsic::nvvm_ff_to_##type##_rn
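// For example, GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite) selects
// llvm::Intrinsic::nvvm_ff_to_ue8m0x2_rz_satfinite when hasSatFinite is true,
// and GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu) selects
// llvm::Intrinsic::nvvm_ff_to_e5m2x2_rn when hasRelu is false.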
2076 
2078 ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
2079  NVVM::SaturationMode sat, bool hasRelu) {
2080  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2081  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
2082  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
2083 
2085  .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2086  return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
2087  })
2088  .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2089  return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
2090  })
2091  .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2092  if (hasRoundingModeRZ)
2093  return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
2094  else if (hasRoundingModeRP)
2095  return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
2096 
2097  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2098  })
2099  .Default([](mlir::Type) {
2100  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2101  return llvm::Intrinsic::not_intrinsic;
2102  });
2103 }
2104 
2105 #define GET_F16x2_TO_F8X2_ID(type, has_relu) \
2106  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
2107  : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
2108 
2109 llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
2110  bool hasRelu) {
2112  .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2113  return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
2114  })
2115  .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2116  return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
2117  })
2118  .Default([](mlir::Type) {
2119  llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
2120  return llvm::Intrinsic::not_intrinsic;
2121  });
2122 }
2123 
2124 #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
2125  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
2126  : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
2127 
2129 ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2130  NVVM::SaturationMode sat) {
2131  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2132  switch (rnd) {
2133  case NVVM::FPRoundingMode::RZ:
2134  return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
2135  case NVVM::FPRoundingMode::RP:
2136  return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
2137  default:
2138  llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
2139  }
2140 }
2141 
2143 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
2144  LLVM::ModuleTranslation &mt,
2146  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
2147  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2148  .getAddressSpace();
2149  bool isShared = as == NVVMMemorySpace::Shared;
2150  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2151 
2153  if (isShared) {
2154  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
2155  : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
2156  } else {
2157  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
2158  : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
2159  }
2160 
2161  // Fill the Intrinsic Args
2162  args.push_back(mt.lookupValue(curOp.getAddr()));
2163  args.push_back(mt.lookupValue(curOp.getNCols()));
2164 
2165  return id;
2166 }
2167 
2168 llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
2169  Operation &op, LLVM::ModuleTranslation &mt,
2171  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
2172  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
2173  ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
2174  : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
2175 
2176  // Fill the Intrinsic Args
2177  args.push_back(mt.lookupValue(curOp.getTaddr()));
2178  args.push_back(mt.lookupValue(curOp.getNCols()));
2179 
2180  return id;
2181 }
2182 
2183 #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
2184  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
2185  : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
2186 
2187 #define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
2188  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
2189  : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
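// For example, GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast) selects
// llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg1 when both isShared and
// hasMulticast are true, and llvm::Intrinsic::nvvm_tcgen05_commit_cg1 when
// both are false.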
2190 
2192 Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
2193  LLVM::ModuleTranslation &mt,
2195  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
2196  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2197  .getAddressSpace();
2198  bool isShared = as == NVVMMemorySpace::Shared;
2199  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
2200  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2201 
2202  llvm::Intrinsic::ID id =
2203  is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
2204  : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
2205 
2206  // Fill the Intrinsic Args
2207  args.push_back(mt.lookupValue(curOp.getAddr()));
2208  if (hasMulticast)
2209  args.push_back(mt.lookupValue(curOp.getMulticastMask()));
2210 
2211  return id;
2212 }
2213 
2214 #define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
2215  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
2216 
2217 #define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
2218  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
2219  : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
2220 
2221 #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
2222  [&]() -> auto { \
2223  if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
2224  return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
2225  if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
2226  return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
2227  return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
2228  }()
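// For example, GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA) selects
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg1 when srcFmt is
// B6x16_P32 and is2CTA is false.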
2229 
2230 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
2231  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
2232  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
2233  auto srcFmt = curOp.getSrcFormat();
2234  auto mc = curOp.getMulticast();
2235 
2236  switch (curOp.getShape()) {
2237  case Tcgen05CpShape::SHAPE_128x256b:
2238  return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
2239  case Tcgen05CpShape::SHAPE_128x128b:
2240  return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
2241  case Tcgen05CpShape::SHAPE_4x256b:
2242  return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
2243  case Tcgen05CpShape::SHAPE_32x128b:
2244  return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
2245  case Tcgen05CpShape::SHAPE_64x128b:
2246  return (mc == Tcgen05CpMulticast::WARPX2_01_23)
2247  ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
2248  : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
2249  }
2250  llvm_unreachable("Invalid shape in tcgen05 cp Op");
2251 }
2252 
2253 // Returns whether the given vector length is valid for the given shape; the
2254 // function models the table mentioned in the tcgen05.{ld, st} Op description.
2255 static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
2256  unsigned vecLen) {
2257  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
2258  return vecLen >= 2;
2259  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
2260  return vecLen >= 4;
2261  return true;
2262 }
2263 
2264 LogicalResult Tcgen05LdOp::verify() {
2265  LogicalResult result = success();
2266  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2267  result = emitError("shape 16x32bx2 requires offset argument");
2268 
2269  auto resTy = getRes().getType();
2270  unsigned resLen = isa<VectorType>(resTy)
2271  ? llvm::cast<VectorType>(resTy).getNumElements()
2272  : 1;
2273  if (!isValidVectorLength(getShape(), resLen))
2274  result = emitError(llvm::formatv("invalid result type length {0} for shape "
2275  "{1} in tcgen05.ld Op",
2276  resLen, stringifyEnum(getShape())));
2277 
2278  return result;
2279 }
2280 
2281 LogicalResult Tcgen05StOp::verify() {
2282  LogicalResult result = success();
2283  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2284  result = emitError("shape 16x32bx2 requires offset argument");
2285 
2286  auto valTy = getVal().getType();
2287  unsigned valLen = isa<VectorType>(valTy)
2288  ? llvm::cast<VectorType>(valTy).getNumElements()
2289  : 1;
2290  if (!isValidVectorLength(getShape(), valLen))
2291  result = emitError(llvm::formatv("invalid input length {0} for shape "
2292  "{1} in tcgen05.st Op",
2293  valLen, stringifyEnum(getShape())));
2294 
2295  return result;
2296 }
2297 
2298 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
2299 /// have ConstantRangeAttr.
2300 static void nvvmInferResultRanges(Operation *op, Value result,
2301  ArrayRef<::mlir::ConstantIntRanges> argRanges,
2302  SetIntRangeFn setResultRanges) {
2303  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
2304  setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
2305  rangeAttr.getLower(), rangeAttr.getUpper()});
2306  }
2307 }
2308 
2309 static llvm::Value *getAsPackedI32(llvm::Value *arg,
2310  llvm::IRBuilderBase &builder) {
2311  return builder.CreateBitCast(arg,
2312  llvm::Type::getInt32Ty(builder.getContext()));
2313 }
2314 
2315 NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
2316  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2317  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
2318 
2320  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2321  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2322  args.push_back(mt.lookupValue(curOp.getC()));
2323 
2324  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2325  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2326  unsigned type = (isASigned << 1) | isBSigned;
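// `type` is a 2-bit index formed from the signedness of `a` (high bit) and
// `b` (low bit), matching the order of `ids` below:
// 0b00 -> u_u, 0b01 -> u_s, 0b10 -> s_u, 0b11 -> s_s.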
2327  const llvm::Intrinsic::ID ids[] = {
2328  llvm::Intrinsic::nvvm_idp4a_u_u,
2329  llvm::Intrinsic::nvvm_idp4a_u_s,
2330  llvm::Intrinsic::nvvm_idp4a_s_u,
2331  llvm::Intrinsic::nvvm_idp4a_s_s,
2332  };
2333  return {ids[type], args};
2334 }
2335 
2336 NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
2337  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2338  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
2339 
2341  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2342  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2343  args.push_back(builder.getInt1(curOp.getBHi()));
2344  args.push_back(mt.lookupValue(curOp.getC()));
2345 
2346  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2347  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2348  unsigned type = (isASigned << 1) | isBSigned;
2349  const llvm::Intrinsic::ID ids[] = {
2350  llvm::Intrinsic::nvvm_idp2a_u_u,
2351  llvm::Intrinsic::nvvm_idp2a_u_s,
2352  llvm::Intrinsic::nvvm_idp2a_s_u,
2353  llvm::Intrinsic::nvvm_idp2a_s_s,
2354  };
2355  return {ids[type], args};
2356 }
2357 
2358 static llvm::Value *getParamCastedAddr(llvm::Value *addr,
2359  llvm::IRBuilderBase &builder) {
2360  return builder.CreateAddrSpaceCast(
2361  addr,
2362  llvm::PointerType::get(builder.getContext(),
2363  llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
2364 }
2365 
2367 PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
2368  LLVM::ModuleTranslation &mt,
2369  llvm::IRBuilderBase &builder) {
2370  using MemSpace = NVVM::NVVMMemorySpace;
2371  using CacheLevel = NVVM::PrefetchCacheLevel;
2372 
2373  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
2374  std::optional<NVVM::CacheEvictionPriority> evictPriority =
2375  op.getEvictPriority();
2376  unsigned addressSpace =
2377  llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
2378  .getAddressSpace();
2379 
2381  llvm::Value *addr = mt.lookupValue(op.getAddr());
2382  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
2383  : addr);
2384 
2385  if (op.getTensormap())
2386  return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
2387 
2388  assert(cacheLevel && "expected cache level for non-tensormap prefetch");
2389 
2390  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
2391  return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
2392 
2393  if (evictPriority && *cacheLevel == CacheLevel::L2) {
2394  switch (*evictPriority) {
2395  case NVVM::CacheEvictionPriority::EvictLast:
2396  return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
2397  case NVVM::CacheEvictionPriority::EvictNormal:
2398  return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
2399  default:
2400  llvm_unreachable("Invalid cache eviction priority");
2401  }
2402  }
2403 
2404  switch (static_cast<MemSpace>(addressSpace)) {
2405  case MemSpace::Generic:
2406  return *cacheLevel == CacheLevel::L1
2407  ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
2408  : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
2409  case MemSpace::Global:
2410  return *cacheLevel == CacheLevel::L1
2411  ? NVVM::IDArgPair(
2412  {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
2413  : NVVM::IDArgPair(
2414  {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
2415  case MemSpace::Local:
2416  return *cacheLevel == CacheLevel::L1
2417  ? NVVM::IDArgPair(
2418  {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
2419  : NVVM::IDArgPair(
2420  {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
2421  default:
2422  llvm_unreachable("Invalid pointer address space");
2423  }
2424 }
2425 
2426 bool NVVM::InlinePtxOp::getAsmValues(
2427  RewriterBase &rewriter,
2428  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2429  &asmValues) {
2430  for (auto arg : getReadWriteArgs())
2431  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
2432  for (auto arg : getResults())
2433  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
2434  for (auto arg : getReadOnlyArgs())
2435  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
2436  if (getPredicate())
2437  asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
2438  return false; // No manual mapping needed
2439 }
2440 
2441 NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
2442  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2443  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
2445  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
2446  args.push_back(mt.lookupValue(curOp.getMbarrier()));
2447 
2448  llvm::Intrinsic::ID intrinsicID =
2449  curOp.getMulticast()
2450  ? llvm::Intrinsic::
2451  nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
2452  : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
2453 
2454  return {intrinsicID, args};
2455 }
2456 
2457 NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
2458  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2459  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
2461  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
2462 
2463  llvm::Intrinsic::ID intrinsicID;
2464 
2465  switch (curOp.getQueryType()) {
2466  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2467  intrinsicID =
2468  llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
2469  break;
2470  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2471  intrinsicID = llvm::Intrinsic::
2472  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
2473  break;
2474  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2475  intrinsicID = llvm::Intrinsic::
2476  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
2477  break;
2478  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2479  intrinsicID = llvm::Intrinsic::
2480  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
2481  break;
2482  }
2483  return {intrinsicID, args};
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // NVVMDialect initialization, type parsing, and registration.
2488 //===----------------------------------------------------------------------===//
2489 
2490 // TODO: This should be the llvm.nvvm dialect once this is supported.
2491 void NVVMDialect::initialize() {
2492  addOperations<
2493 #define GET_OP_LIST
2494 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
2495  >();
2496  addAttributes<
2497 #define GET_ATTRDEF_LIST
2498 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
2499  >();
2500 
2501  // Support unknown operations because not all NVVM operations are
2502  // registered.
2503  allowUnknownOperations();
2504  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
2505  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
2506 }
2507 
2508 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
2509  NamedAttribute attr) {
2510  StringAttr attrName = attr.getName();
2511  // Kernel function attribute should be attached to functions.
2512  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
2513  if (!isa<LLVM::LLVMFuncOp>(op)) {
2514  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
2515  << "' attribute attached to unexpected op";
2516  }
2517  }
2518  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
2519  // with at most 3 elements.
2520  if (attrName == NVVMDialect::getMaxntidAttrName() ||
2521  attrName == NVVMDialect::getReqntidAttrName() ||
2522  attrName == NVVMDialect::getClusterDimAttrName()) {
2523  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
2524  if (!values || values.empty() || values.size() > 3) {
2525  return op->emitError()
2526  << "'" << attrName
2527  << "' attribute must be integer array with maximum 3 index";
2528  }
2529  }
2530  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
2531  // attributes.
2532  if (attrName == NVVMDialect::getMinctasmAttrName() ||
2533  attrName == NVVMDialect::getMaxnregAttrName() ||
2534  attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
2535  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
2536  return op->emitError()
2537  << "'" << attrName << "' attribute must be integer constant";
2538  }
2539  }
2540  // blocksareclusters must be used along with reqntid and cluster_dim
2541  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
2542  if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
2543  !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
2544  return op->emitError()
2545  << "'" << attrName << "' attribute must be used along with "
2546  << "'" << NVVMDialect::getReqntidAttrName() << "' and "
2547  << "'" << NVVMDialect::getClusterDimAttrName() << "'";
2548  }
2549  }
2550 
2551  return success();
2552 }
2553 
2554 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
2555  unsigned regionIndex,
2556  unsigned argIndex,
2557  NamedAttribute argAttr) {
2558  auto funcOp = dyn_cast<FunctionOpInterface>(op);
2559  if (!funcOp)
2560  return success();
2561 
2562  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
2563  StringAttr attrName = argAttr.getName();
2564  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
2565  if (!isKernel) {
2566  return op->emitError()
2567  << "'" << attrName
2568  << "' attribute must be present only on kernel arguments";
2569  }
2570  if (!isa<UnitAttr>(argAttr.getValue()))
2571  return op->emitError() << "'" << attrName << "' must be a unit attribute";
2572  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
2573  return op->emitError()
2574  << "'" << attrName
2575  << "' attribute requires the argument to also have attribute '"
2576  << LLVM::LLVMDialect::getByValAttrName() << "'";
2577  }
2578  }
2579 
2580  return success();
2581 }
2582 
2583 //===----------------------------------------------------------------------===//
2584 // NVVM Address Space Attr
2585 //===----------------------------------------------------------------------===//
2586 
2587 unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
2588  return static_cast<unsigned>(getValue());
2589 }
2590 
2591 bool NVVMMemorySpaceAttr::isValidLoad(
2592  Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
2593  const ::mlir::DataLayout *dataLayout,
2594  function_ref<InFlightDiagnostic()> emitError) const {
2595  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
2596  dataLayout, emitError);
2597 }
2598 
2599 bool NVVMMemorySpaceAttr::isValidStore(
2600  Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
2601  const ::mlir::DataLayout *dataLayout,
2602  function_ref<InFlightDiagnostic()> emitError) const {
2603  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
2604  dataLayout, emitError);
2605 }
2606 
2607 bool NVVMMemorySpaceAttr::isValidAtomicOp(
2608  ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
2609  std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
2610  function_ref<InFlightDiagnostic()> emitError) const {
2611  // TODO: update this method once `ptr.atomic_rmw` is implemented.
2612  assert(false && "unimplemented, see TODO in the source.");
2613  return false;
2614 }
2615 
2616 bool NVVMMemorySpaceAttr::isValidAtomicXchg(
2617  Type type, ptr::AtomicOrdering successOrdering,
2618  ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
2619  const ::mlir::DataLayout *dataLayout,
2621  // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
2622  assert(false && "unimplemented, see TODO in the source.");
2623  return false;
2624 }
2625 
2626 bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
2627  Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
2628  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
2629  // dialect.
2630  assert(false && "unimplemented, see TODO in the source.");
2631  return false;
2632 }
2633 
2634 bool NVVMMemorySpaceAttr::isValidPtrIntCast(
2635  Type intLikeTy, Type ptrLikeTy,
2636  function_ref<InFlightDiagnostic()> emitError) const {
2637  // TODO: update this method once the int-cast ops are added to the `ptr`
2638  // dialect.
2639  assert(false && "unimplemented, see TODO in the source.");
2640  return false;
2641 }
2642 
2643 //===----------------------------------------------------------------------===//
2644 // NVVM target attribute.
2645 //===----------------------------------------------------------------------===//
2646 LogicalResult
2647 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2648  int optLevel, StringRef triple, StringRef chip,
2649  StringRef features, DictionaryAttr flags,
2650  ArrayAttr files, bool verifyTarget) {
2651  if (optLevel < 0 || optLevel > 3) {
2652  emitError() << "The optimization level must be a number between 0 and 3.";
2653  return failure();
2654  }
2655  if (triple.empty()) {
2656  emitError() << "The target triple cannot be empty.";
2657  return failure();
2658  }
2659  if (chip.empty()) {
2660  emitError() << "The target chip cannot be empty.";
2661  return failure();
2662  }
2663  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
2664  return mlir::isa_and_nonnull<StringAttr>(attr);
2665  })) {
2666  emitError() << "All the elements in the `link` array must be strings.";
2667  return failure();
2668  }
2669  return success();
2670 }
2671 
2672 LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
2673  if (!getVerifyTarget())
2674  return success();
2675 
2676  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
2677  if (!gpuModuleOp) {
2678  return emitError(gpuModule->getLoc(),
2679  "NVVM target attribute must be attached to a GPU module");
2680  }
2681 
2682  const NVVMCheckSMVersion targetSMVersion =
2683  NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
2684  if (!targetSMVersion.isMinimumSMVersion()) {
2685  return emitError(gpuModule->getLoc(),
2686  "Minimum NVVM target SM version is sm_20");
2687  }
2688 
2689  gpuModuleOp->walk([&](Operation *op) {
2690  if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
2691  const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
2692  if (!requirement.isCompatibleWith(targetSMVersion)) {
2693  op->emitOpError() << "is not supported on " << getChip();
2694  return WalkResult::interrupt();
2695  }
2696  }
2697  return WalkResult::advance();
2698  });
2699 
2700  return success();
2701 }
2702 
2703 #define GET_OP_CLASSES
2704 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
2705 
2706 #define GET_ATTRDEF_CLASSES
2707 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"