//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than the number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  // We lower through inline-ptx when getPredicate() is true.
  // a) Only TILE mode is supported
  // b) Cache-hint is not supported
  if (getPredicate()) {
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering is supported only for Tile mode.");
    if (getL2CacheHint())
      return emitError("Inline-ptx lowering is unsupported with L2 cache-hint.");
  }

  size_t dims = getCoordinates().size();
  switch (mode) {
  case TMAStoreMode::TILE:
    return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
  case TMAStoreMode::IM2COL:
    return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
  case TMAStoreMode::TILE_SCATTER4:
    if (dims != 5)
      return emitError("Scatter4 mode expects 5 coordinates");
  }
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}

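// As an illustrative sketch (operand names hypothetical), the one combination
// that exercises the last check above is a CG-modified copy, which must move
// exactly 16 bytes:
//
//   nvvm.cp.async.shared.global %dst, %src, 16, cache = cg
//       : !llvm.ptr<3>, !llvm.ptr<1>
//
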
// This verifier is shared across the TMA load and prefetch Ops.
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
                                         TMALoadMode mode, Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError(loc)
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << "im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

    return success();
  };

  switch (mode) {
  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
  }
  return success();
}
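
// For instance: an IM2COL load of a 5-d tensor must supply 5 - 2 = 3 im2col
// offsets, while IM2COL_W and IM2COL_W_128 always take exactly 2, independent
// of the tensor rank.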

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
                             getMode(), getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  size_t dims = getCoordinates().size();
  switch (mode) {
  case TMAStoreMode::TILE:
    return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
  case TMAStoreMode::IM2COL:
    return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
  }
  return success();
}

LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
  }
  return success();
}

LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  switch (getType()) {
  case ConvertFP8Type::E4M3:
  case ConvertFP8Type::E5M2:
    if (!isRoundingModeRN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (!isSatFinite)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  case ConvertFP8Type::UE8M0:
    if (!(isRoundingModeRZ || isRoundingModeRP))
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (hasRelu)
      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  return success();
}

LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (getType() == ConvertFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");

  return success();
}

LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (getType() != ConvertFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  auto rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

  return success();
}

LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
  return success();
}

LogicalResult PMEventOp::verify() {
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";
  }

  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";
  }

  if (eventId) {
    if (eventId < 0 || eventId > 15) {
      return emitOpError() << "`id` must be between 0 and 15";
    }
  }

  return success();
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//   `->` type($res)
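//
// For example, a minimal sketch of the custom assembly this grammar accepts
// (the shape/layout attribute values are illustrative):
//
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//            {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//             shape = #nvvm.shape<m = 16, n = 8, k = 8>}
//            : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//               -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>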
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() >= 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
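
// Illustration: with `return_value_and_is_valid` set, a well-formed result
// type is, e.g., !llvm.struct<(i32, i1)>; the trailing i1 reports whether the
// shuffled value came from a valid lane (the i32 payload type is just an
// example).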

std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
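
// Worked examples of the mapping above (for illustration): an f16 `a` or `b`
// fragment maps to 8 x vector<2xf16>, and an s8 `a` fragment with
// m = n = k = 16 maps to 2 x i32 (parallelSize == 16).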

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected destination pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
                         "matrix");
    }
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
    }
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
    }
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
  }

  Type i32 = IntegerType::get(getContext(), 32);
  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(numElements, i32));
    if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
             << numElements << " elements of type i32";
  }

  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  int m = getShape().getM(), n = getShape().getN();
  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
    }
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    }
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8 or 16x8");
  }

  return success();
}

static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                            NVVM::WGMMATypes typeA,
                                            NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be a struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in the struct must be the same type, but found "
             << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output; scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << " is not supported";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " with n = "
                         << getShape().getN() << ", which is not supported";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored in row-major and B in column-major layout.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " require transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " result elements, however the output struct has "
                         << outputSize << " elements";
  }
  // Check satfinite (only available for s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << "`satfinite` can only be used with an s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
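
// A worked example of the string built above (illustrative), for
// m64n16k16.f32.f16.f16 without satfinite: N/2 = 8 output registers, so the
// predicate operand is $(8 * 2 + 2) = $18 and the emitted PTX is roughly:
//
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $18, 0;
//   wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16
//       {$0, $1, $2, $3, $4, $5, $6, $7}, $16, $17, p, $19, $20, $21, $22;
//   }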

bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  return true; // Has manual mapping
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}

LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}

LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two-element struct with "
                         "the first element as i32 and the second as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}

LogicalResult NVVM::VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' return an i1");
    }
  }
  return success();
}

LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    }

    if (evictPriority) {
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    }

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    }

  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
    }

    if (getUniform()) {
      if (*cacheLevel != CacheLevel::L1) {
        return emitOpError(
            "unsupported cache level, the only supported uniform "
            "cache level is L1");
      }

      if (addressSpace != MemSpace::Generic) {
        return emitOpError(
            "prefetch to uniform cache requires a generic pointer");
      }
    }

    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
        return emitOpError(
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
        return emitOpError(
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    }

    if (getPredicate())
      return emitOpError("predicate supported only on prefetch tensormap");

  } else {
    return emitOpError(
        "requires specification of either cache level or tensormap");
  }

  return success();
}
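
// Summarizing the checks above: `tensormap` prefetches require a generic or
// constant pointer and no eviction priority; cache-level prefetches require a
// generic, global, or local pointer, with `uniform` implying L1 + generic and
// an eviction priority implying L2 + global.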

LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    if (!getType().isInteger(1))
      return emitOpError("is_canceled query type returns an i1");
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
    }
    break;
  }
  return success();
}

/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
packValInto64Bits(llvm::IRBuilderBase &builder,
                  llvm::Value *result, // the `result` (unset bits are zero)
                  llvm::Value *field,  // `field` to pack into `result`
                  unsigned sizeInBits, // Size of `field` in bits
                  unsigned start) {    // Starting bit within `result`
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
}
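
// A quick illustration (values hypothetical): with sizeInBits = 3 and
// start = 49, the low 3 bits of `field` are masked out, zero-extended to i64,
// shifted to bits [49, 52), and OR-ed into `result`.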

void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getStartAddr()), 14, 0);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);

  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
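
// Putting the pack calls above together, the resulting 64-bit shared-memory
// descriptor is laid out as follows (half-open bit ranges, derived directly
// from the calls):
//   [0, 14)   start address
//   [16, 30)  leading dimension offset
//   [32, 46)  stride dimension offset
//   [46, 49)  fixed constant 1
//   [49, 52)  base offset
//   [52, 53)  leading dimension mode
//   [61, 64)  swizzle mode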

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//

#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return id;
}

mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};
}

mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Choose the bytemask variant
  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
  }

  return {id, std::move(args)};
}
1555 
1556 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
1557  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1558  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
1559  llvm::SmallVector<llvm::Value *> args;
1560 
1561  // Fill the Intrinsic Args
1562  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1563 
1564  for (auto v : thisOp.getCoordinates())
1565  args.push_back(mt.lookupValue(v));
1566  for (auto v : thisOp.getIm2colOffsets())
1567  args.push_back(mt.lookupValue(v));
1568 
1569  mlir::Value cacheHint = thisOp.getL2CacheHint();
1570  const bool hasCacheHint = static_cast<bool>(cacheHint);
1571  llvm::Value *i64Unused =
1572  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1573  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1574  args.push_back(builder.getInt1(hasCacheHint));
1575 
1576  const unsigned NI = llvm::Intrinsic::not_intrinsic;
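  // Rows are indexed by TMALoadMode (tile, im2col, im2col_w, im2col_w_128,
  // tile_gather4); columns by the number of tensor coordinates (1 to 5).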
1577  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1578  {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
1579  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
1580  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
1581  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
1582  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
1583  {NI, NI, NI,
1584  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
1585  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
1586  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
1587  {NI, NI, NI,
1588  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
1589  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
1590  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
1591  {NI, NI, NI,
1592  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
1593  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
1594  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
1595  {NI, NI, NI, NI, NI,
1596  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
1597 
1598  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
1599  "TMALoadModes must match number of rows in IDTable");
1600  size_t mode = static_cast<size_t>(thisOp.getMode());
1601  size_t dim = thisOp.getCoordinates().size();
1602  llvm::Intrinsic::ID id = IDTable[mode][dim];
1603  if (id == llvm::Intrinsic::not_intrinsic)
1604  llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
1605 
1606  return {id, std::move(args)};
1607 }
1608 
1609 mlir::NVVM::IDArgPair
1610 CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1611  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1612  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
1613  llvm::SmallVector<llvm::Value *> args;
1614 
1615  // Fill the Intrinsic Args
1616  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1617  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1618 
1619  for (auto v : thisOp.getCoordinates())
1620  args.push_back(mt.lookupValue(v));
1621 
1622  mlir::Value cacheHint = thisOp.getL2CacheHint();
1623  const bool hasCacheHint = static_cast<bool>(cacheHint);
1624  llvm::Value *i64Unused =
1625  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1626  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1627  args.push_back(builder.getInt1(hasCacheHint));
1628 
1629  const unsigned NI = llvm::Intrinsic::not_intrinsic;
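  // Rows are indexed by TMAStoreMode (tile, im2col, tile_scatter4); columns
  // by the number of tensor coordinates (1 to 5).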
1630  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1631  {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
1632  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
1633  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
1634  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
1635  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
1636  {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
1637  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
1638  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
1639  {NI, NI, NI, NI, NI,
1640  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
1641 
1642  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
1643  "TMAStoreModes must match number of rows in IDTable");
1644  size_t mode = static_cast<size_t>(thisOp.getMode());
1645  size_t dim = thisOp.getCoordinates().size();
1646  llvm::Intrinsic::ID id = IDTable[mode][dim];
1647  if (id == llvm::Intrinsic::not_intrinsic)
1648  llvm_unreachable(
1649  "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
1650 
1651  return {id, std::move(args)};
1652 }
1653 
1654 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
1655  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1656 
1657 #define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \
1658  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1659  : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1660 
1661 #define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1662  [&]() -> auto { \
1663  switch (dims) { \
1664  case 1: \
1665  return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1666  case 2: \
1667  return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1668  case 3: \
1669  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1670  case 4: \
1671  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1672  case 5: \
1673  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1674  default: \
1675  llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1676  } \
1677  }()
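// For illustration, GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, 3, true) selects
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d; 1-d and 2-d
// tensors always take the tile variant, since im2col needs at least 3 dims.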
1678 
1679 llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
1680  int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
1681  using RedTy = NVVM::TMAReduxKind;
1682  switch (kind) {
1683  case RedTy::ADD:
1684  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1685  case RedTy::MIN:
1686  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1687  case RedTy::MAX:
1688  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1689  case RedTy::INC:
1690  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1691  case RedTy::DEC:
1692  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1693  case RedTy::AND:
1694  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1695  case RedTy::OR:
1696  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1697  case RedTy::XOR:
1698  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1699  }
1700  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
1701 }
1702 
1703 #define _none
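// _none expands to nothing. Because GET_CVT_F2TF32_ID forwards its arguments
// without token-pasting, _none is macro-expanded away before the paste inside
// CVT_F2TF32_ID_IMPL, so rounding modes without a relu variant (RNA below)
// resolve to the plain intrinsic name on both branches of hasRelu.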
1704 
1705 #define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1706  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
1707  : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
1708 
1709 #define GET_CVT_F2TF32_ID(rnd, relu, sf) \
1710  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1711  : CVT_F2TF32_ID_IMPL(rnd, relu, )
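// For illustration, with hasRelu and hasSatFinite both true,
// GET_CVT_F2TF32_ID(rn, _relu, _satfinite) selects
// llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite.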
1712 
1713 llvm::Intrinsic::ID
1714 ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1715  NVVM::SaturationMode sat, bool hasRelu) {
1716  using RndMode = NVVM::FPRoundingMode;
1717  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1718  switch (rnd) {
1719  case RndMode::RN:
1720  return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
1721  case RndMode::RZ:
1722  return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
1723  case RndMode::RNA:
1724  return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
1725  default:
1726  llvm_unreachable("Invalid RoundingMode for ConvertFloatToTF32Op");
1727  }
1728 }
1729 
1730 #define GET_F32x2_TO_F6x2_ID(type, has_relu) \
1731  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1732  : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1733 
1734 llvm::Intrinsic::ID
1735 ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
1736  switch (type) {
1737  case NVVM::ConvertFP6Type::E2M3:
1738  return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
1739  case NVVM::ConvertFP6Type::E3M2:
1740  return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1741  }
1742  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
1743 }
1744 
1745 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
1746  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1747  : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1748 
1749 #define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
1750  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1751  : llvm::Intrinsic::nvvm_ff_to_##type##_rn
1752 
1753 llvm::Intrinsic::ID
1754 ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
1755  NVVM::FPRoundingMode rnd,
1756  NVVM::SaturationMode sat, bool hasRelu) {
1757  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1758  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1759  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1760 
1761  switch (type) {
1762  case NVVM::ConvertFP8Type::E4M3:
1763  return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1764  case NVVM::ConvertFP8Type::E5M2:
1765  return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1766  case NVVM::ConvertFP8Type::UE8M0:
1767  if (hasRoundingModeRZ)
1768  return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1769  else if (hasRoundingModeRP)
1770  return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1771  }
1772  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
1773 }
1774 
1775 #define GET_F16x2_TO_F8X2_ID(type, has_relu) \
1776  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
1777  : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
1778 
1779 llvm::Intrinsic::ID
1780 ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
1781  switch (type) {
1782  case NVVM::ConvertFP8Type::E4M3:
1783  return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1784  case NVVM::ConvertFP8Type::E5M2:
1785  return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1786  default:
1787  llvm_unreachable("Invalid ConvertFP8Type for ConvertF16x2ToF8x2Op");
1788  }
1789 }
1790 
1791 #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
1792  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
1793  : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
1794 
1795 llvm::Intrinsic::ID
1796 ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1797  NVVM::SaturationMode sat) {
1798  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1799  switch (rnd) {
1800  case NVVM::FPRoundingMode::RZ:
1801  return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
1802  case NVVM::FPRoundingMode::RP:
1803  return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
1804  default:
1805  llvm_unreachable("Invalid rounding mode for ConvertBF16x2ToF8x2Op");
1806  }
1807 }
1808 
1809 llvm::Intrinsic::ID
1810 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
1811  LLVM::ModuleTranslation &mt,
1812  llvm::SmallVector<llvm::Value *> &args) {
1813  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1814  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1815  .getAddressSpace();
1816  bool isShared = as == NVVMMemorySpace::Shared;
1817  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
1818 
1819  llvm::Intrinsic::ID id;
1820  if (isShared) {
1821  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
1822  : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
1823  } else {
1824  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
1825  : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
1826  }
1827 
1828  // Fill the Intrinsic Args
1829  args.push_back(mt.lookupValue(curOp.getAddr()));
1830  args.push_back(mt.lookupValue(curOp.getNCols()));
1831 
1832  return id;
1833 }
1834 
1835 llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
1836  Operation &op, LLVM::ModuleTranslation &mt,
1837  llvm::SmallVector<llvm::Value *> &args) {
1838  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
1839  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
1840  ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
1841  : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
1842 
1843  // Fill the Intrinsic Args
1844  args.push_back(mt.lookupValue(curOp.getTaddr()));
1845  args.push_back(mt.lookupValue(curOp.getNCols()));
1846 
1847  return id;
1848 }
1849 
1850 #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
1851  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
1852  : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
1853 
1854 #define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
1855  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
1856  : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
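// For illustration, GET_TCGEN05_COMMIT_ID(cg2, /*is_shared=*/true,
// /*has_mc=*/true) selects llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2.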
1857 
1858 llvm::Intrinsic::ID
1859 Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
1860  LLVM::ModuleTranslation &mt,
1861  llvm::SmallVector<llvm::Value *> &args) {
1862  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1863  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1864  .getAddressSpace();
1865  bool isShared = as == NVVMMemorySpace::Shared;
1866  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
1867  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
1868 
1869  llvm::Intrinsic::ID id =
1870  is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
1871  : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
1872 
1873  // Fill the Intrinsic Args
1874  args.push_back(mt.lookupValue(curOp.getAddr()));
1875  if (hasMulticast)
1876  args.push_back(mt.lookupValue(curOp.getMulticastMask()));
1877 
1878  return id;
1879 }
1880 
1881 #define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
1882  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
1883 
1884 #define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
1885  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
1886  : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
1887 
1888 #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
1889  [&]() -> auto { \
1890  if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
1891  return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
1892  if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
1893  return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
1894  return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
1895  }()
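// For illustration, GET_TCGEN05_CP_ID(_128x256b, Tcgen05CpSrcFormat::B6x16_P32,
// /*is_2cta=*/true) selects
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2.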
1896 
1897 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
1898  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
1899  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
1900  auto srcFmt = curOp.getSrcFormat();
1901  auto mc = curOp.getMulticast();
1902 
1903  switch (curOp.getShape()) {
1904  case Tcgen05CpShape::SHAPE_128x256b:
1905  return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
1906  case Tcgen05CpShape::SHAPE_128x128b:
1907  return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
1908  case Tcgen05CpShape::SHAPE_4x256b:
1909  return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
1910  case Tcgen05CpShape::SHAPE_32x128b:
1911  return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
1912  case Tcgen05CpShape::SHAPE_64x128b:
1913  return (mc == Tcgen05CpMulticast::WARPX2_01_23)
1914  ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
1915  : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
1916  }
1917  llvm_unreachable("Invalid shape in tcgen05 cp Op");
1918 }
1919 
1920 // Returns true when the given vector length is valid for the given shape;
1921 // this models the table in the tcgen05.{ld, st} Op description.
1922 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
1923  unsigned vecLen) {
1924  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1925  return vecLen >= 2;
1926  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1927  return vecLen >= 4;
1928  return true;
1929 }
1930 
1931 LogicalResult Tcgen05LdOp::verify() {
1932  LogicalResult result = success();
1933  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1934  result = emitError("shape 16x32bx2 requires offset argument");
1935 
1936  auto resTy = getRes().getType();
1937  unsigned resLen = isa<VectorType>(resTy)
1938  ? llvm::cast<VectorType>(resTy).getNumElements()
1939  : 1;
1940  if (!isValidVectorLength(getShape(), resLen))
1941  result = emitError(llvm::formatv("invalid result type length {0} for shape "
1942  "{1} in tcgen05.ld Op",
1943  resLen, stringifyEnum(getShape())));
1944 
1945  return result;
1946 }
1947 
1948 LogicalResult Tcgen05StOp::verify() {
1949  LogicalResult result = success();
1950  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1951  result = emitError("shape 16x32bx2 requires offset argument");
1952 
1953  auto valTy = getVal().getType();
1954  unsigned valLen = isa<VectorType>(valTy)
1955  ? llvm::cast<VectorType>(valTy).getNumElements()
1956  : 1;
1957  if (!isValidVectorLength(getShape(), valLen))
1958  result = emitError(llvm::formatv("invalid input length {0} for shape "
1959  "{1} in tcgen05.st Op",
1960  valLen, stringifyEnum(getShape())));
1961 
1962  return result;
1963 }
1964 
1965 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
1966 /// have ConstantRangeAttr.
1967 static void nvvmInferResultRanges(Operation *op, Value result,
1968  ArrayRef<::mlir::ConstantIntRanges> argRanges,
1969  SetIntRangeFn setResultRanges) {
1970  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
1971  setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
1972  rangeAttr.getLower(), rangeAttr.getUpper()});
1973  }
1974 }
1975 
1976 static llvm::Value *getAsPackedI32(llvm::Value *arg,
1977  llvm::IRBuilderBase &builder) {
1978  return builder.CreateBitCast(arg,
1979  llvm::Type::getInt32Ty(builder.getContext()));
1980 }
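// The idp4a/idp2a intrinsics take their packed operands as plain i32 values,
// so the small-integer vector operands (e.g. a vector of four i8) are bitcast
// to i32 before the call.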
1981 
1982 NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
1983  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1984  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
1985 
1986  llvm::SmallVector<llvm::Value *> args;
1987  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1988  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1989  args.push_back(mt.lookupValue(curOp.getC()));
1990 
1991  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1992  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1993  unsigned type = (isASigned << 1) | isBSigned;
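  // 'type' is a 2-bit index (a-signedness << 1 | b-signedness), so e.g. 0b10
  // selects the signed-a/unsigned-b variant nvvm_idp4a_s_u below.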
1994  const llvm::Intrinsic::ID ids[] = {
1995  llvm::Intrinsic::nvvm_idp4a_u_u,
1996  llvm::Intrinsic::nvvm_idp4a_u_s,
1997  llvm::Intrinsic::nvvm_idp4a_s_u,
1998  llvm::Intrinsic::nvvm_idp4a_s_s,
1999  };
2000  return {ids[type], args};
2001 }
2002 
2003 NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
2004  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2005  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
2006 
2007  llvm::SmallVector<llvm::Value *> args;
2008  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2009  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2010  args.push_back(builder.getInt1(curOp.getBHi()));
2011  args.push_back(mt.lookupValue(curOp.getC()));
2012 
2013  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2014  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2015  unsigned type = (isASigned << 1) | isBSigned;
2016  const llvm::Intrinsic::ID ids[] = {
2017  llvm::Intrinsic::nvvm_idp2a_u_u,
2018  llvm::Intrinsic::nvvm_idp2a_u_s,
2019  llvm::Intrinsic::nvvm_idp2a_s_u,
2020  llvm::Intrinsic::nvvm_idp2a_s_s,
2021  };
2022  return {ids[type], args};
2023 }
2024 
2025 static llvm::Value *getParamCastedAddr(llvm::Value *addr,
2026  llvm::IRBuilderBase &builder) {
2027  return builder.CreateAddrSpaceCast(
2028  addr,
2029  llvm::PointerType::get(builder.getContext(),
2030  llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
2031 }
2032 
2033 NVVM::IDArgPair
2034 PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
2035  LLVM::ModuleTranslation &mt,
2036  llvm::IRBuilderBase &builder) {
2037  using MemSpace = NVVM::NVVMMemorySpace;
2038  using CacheLevel = NVVM::PrefetchCacheLevel;
2039 
2040  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
2041  std::optional<NVVM::CacheEvictionPriority> evictPriority =
2042  op.getEvictPriority();
2043  unsigned addressSpace =
2044  llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
2045  .getAddressSpace();
2046 
2047  llvm::SmallVector<llvm::Value *> args;
2048  llvm::Value *addr = mt.lookupValue(op.getAddr());
2049  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
2050  : addr);
2051 
2052  if (op.getTensormap())
2053  return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
2054 
2055  assert(cacheLevel && "expected cache level for non-tensormap prefetch");
2056 
2057  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
2058  return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
2059 
2060  if (evictPriority && *cacheLevel == CacheLevel::L2) {
2061  switch (*evictPriority) {
2062  case NVVM::CacheEvictionPriority::EvictLast:
2063  return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
2064  case NVVM::CacheEvictionPriority::EvictNormal:
2065  return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
2066  default:
2067  llvm_unreachable("Invalid cache eviction priority");
2068  }
2069  }
2070 
2071  switch (static_cast<MemSpace>(addressSpace)) {
2072  case MemSpace::Generic:
2073  return *cacheLevel == CacheLevel::L1
2074  ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
2075  : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
2076  case MemSpace::Global:
2077  return *cacheLevel == CacheLevel::L1
2078  ? NVVM::IDArgPair(
2079  {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
2080  : NVVM::IDArgPair(
2081  {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
2082  case MemSpace::Local:
2083  return *cacheLevel == CacheLevel::L1
2084  ? NVVM::IDArgPair(
2085  {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
2086  : NVVM::IDArgPair(
2087  {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
2088  default:
2089  llvm_unreachable("Invalid pointer address space");
2090  }
2091 }
2092 
2093 bool NVVM::InlinePtxOp::getAsmValues(
2094  RewriterBase &rewriter,
2095  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2096  &asmValues) {
2097  for (auto arg : getReadWriteArgs())
2098  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
2099  for (auto arg : getResults())
2100  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
2101  for (auto arg : getReadOnlyArgs())
2102  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
2103  if (getPredicate())
2104  asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
2105  return false; // No manual mapping needed
2106 }
2107 
2108 NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
2109  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2110  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
2111  llvm::SmallVector<llvm::Value *> args;
2112  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
2113  args.push_back(mt.lookupValue(curOp.getMbarrier()));
2114 
2115  llvm::Intrinsic::ID intrinsicID =
2116  curOp.getMulticast()
2117  ? llvm::Intrinsic::
2118  nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
2119  : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
2120 
2121  return {intrinsicID, args};
2122 }
2123 
2124 NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
2125  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2126  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
2127  llvm::SmallVector<llvm::Value *> args;
2128  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
2129 
2130  llvm::Intrinsic::ID intrinsicID;
2131 
2132  switch (curOp.getQueryType()) {
2133  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2134  intrinsicID =
2135  llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
2136  break;
2137  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2138  intrinsicID = llvm::Intrinsic::
2139  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
2140  break;
2141  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2142  intrinsicID = llvm::Intrinsic::
2143  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
2144  break;
2145  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2146  intrinsicID = llvm::Intrinsic::
2147  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
2148  break;
2149  }
2150  return {intrinsicID, args};
2151 }
2152 
2153 //===----------------------------------------------------------------------===//
2154 // NVVMDialect initialization, type parsing, and registration.
2155 //===----------------------------------------------------------------------===//
2156 
2157 // TODO: This should be the llvm.nvvm dialect once this is supported.
2158 void NVVMDialect::initialize() {
2159  addOperations<
2160 #define GET_OP_LIST
2161 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
2162  >();
2163  addAttributes<
2164 #define GET_ATTRDEF_LIST
2165 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
2166  >();
2167 
2168  // Support unknown operations because not all NVVM operations are
2169  // registered.
2170  allowUnknownOperations();
2171  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
2172  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
2173 }
2174 
2175 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
2176  NamedAttribute attr) {
2177  StringAttr attrName = attr.getName();
2178  // Kernel function attribute should be attached to functions.
2179  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
2180  if (!isa<LLVM::LLVMFuncOp>(op)) {
2181  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
2182  << "' attribute attached to unexpected op";
2183  }
2184  }
2185  // If maxntid / reqntid / cluster_dim is present, it must be an integer
2186  // array with at most 3 dimensions.
2187  if (attrName == NVVMDialect::getMaxntidAttrName() ||
2188  attrName == NVVMDialect::getReqntidAttrName() ||
2189  attrName == NVVMDialect::getClusterDimAttrName()) {
2190  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
2191  if (!values || values.empty() || values.size() > 3) {
2192  return op->emitError()
2193  << "'" << attrName
2194  << "' attribute must be an integer array with at most 3 elements";
2195  }
2196  }
2197  // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
2198  // attribute
2199  if (attrName == NVVMDialect::getMinctasmAttrName() ||
2200  attrName == NVVMDialect::getMaxnregAttrName() ||
2201  attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
2202  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
2203  return op->emitError()
2204  << "'" << attrName << "' attribute must be integer constant";
2205  }
2206  }
2207  // blocksareclusters must be used along with reqntid and cluster_dim
2208  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
2209  if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
2210  !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
2211  return op->emitError()
2212  << "'" << attrName << "' attribute must be used along with "
2213  << "'" << NVVMDialect::getReqntidAttrName() << "' and "
2214  << "'" << NVVMDialect::getClusterDimAttrName() << "'";
2215  }
2216  }
2217 
2218  return success();
2219 }
2220 
2221 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
2222  unsigned regionIndex,
2223  unsigned argIndex,
2224  NamedAttribute argAttr) {
2225  auto funcOp = dyn_cast<FunctionOpInterface>(op);
2226  if (!funcOp)
2227  return success();
2228 
2229  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
2230  StringAttr attrName = argAttr.getName();
2231  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
2232  if (!isKernel) {
2233  return op->emitError()
2234  << "'" << attrName
2235  << "' attribute must be present only on kernel arguments";
2236  }
2237  if (!isa<UnitAttr>(argAttr.getValue()))
2238  return op->emitError() << "'" << attrName << "' must be a unit attribute";
2239  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
2240  return op->emitError()
2241  << "'" << attrName
2242  << "' attribute requires the argument to also have attribute '"
2243  << LLVM::LLVMDialect::getByValAttrName() << "'";
2244  }
2245  }
2246 
2247  return success();
2248 }
2249 
2250 //===----------------------------------------------------------------------===//
2251 // NVVM Address Space Attr
2252 //===----------------------------------------------------------------------===//
2253 
2254 unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
2255  return static_cast<unsigned>(getValue());
2256 }
2257 
2258 bool NVVMMemorySpaceAttr::isValidLoad(
2259  Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
2260  const ::mlir::DataLayout *dataLayout,
2261  function_ref<InFlightDiagnostic()> emitError) const {
2262  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
2263  dataLayout, emitError);
2264 }
2265 
2266 bool NVVMMemorySpaceAttr::isValidStore(
2267  Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
2268  const ::mlir::DataLayout *dataLayout,
2269  function_ref<InFlightDiagnostic()> emitError) const {
2270  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
2271  dataLayout, emitError);
2272 }
2273 
2274 bool NVVMMemorySpaceAttr::isValidAtomicOp(
2275  ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
2276  std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
2277  function_ref<InFlightDiagnostic()> emitError) const {
2278  // TODO: update this method once `ptr.atomic_rmw` is implemented.
2279  assert(false && "unimplemented, see TODO in the source.");
2280  return false;
2281 }
2282 
2283 bool NVVMMemorySpaceAttr::isValidAtomicXchg(
2284  Type type, ptr::AtomicOrdering successOrdering,
2285  ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
2286  const ::mlir::DataLayout *dataLayout,
2287  function_ref<InFlightDiagnostic()> emitError) const {
2288  // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
2289  assert(false && "unimplemented, see TODO in the source.");
2290  return false;
2291 }
2292 
2293 bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
2294  Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
2295  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
2296  // dialect.
2297  assert(false && "unimplemented, see TODO in the source.");
2298  return false;
2299 }
2300 
2301 bool NVVMMemorySpaceAttr::isValidPtrIntCast(
2302  Type intLikeTy, Type ptrLikeTy,
2303  function_ref<InFlightDiagnostic()> emitError) const {
2304  // TODO: update this method once the int-cast ops are added to the `ptr`
2305  // dialect.
2306  assert(false && "unimplemented, see TODO in the source.");
2307  return false;
2308 }
2309 
2310 //===----------------------------------------------------------------------===//
2311 // NVVM target attribute.
2312 //===----------------------------------------------------------------------===//
2313 LogicalResult
2314 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2315  int optLevel, StringRef triple, StringRef chip,
2316  StringRef features, DictionaryAttr flags,
2317  ArrayAttr files, bool verifyTarget) {
2318  if (optLevel < 0 || optLevel > 3) {
2319  emitError() << "The optimization level must be a number between 0 and 3.";
2320  return failure();
2321  }
2322  if (triple.empty()) {
2323  emitError() << "The target triple cannot be empty.";
2324  return failure();
2325  }
2326  if (chip.empty()) {
2327  emitError() << "The target chip cannot be empty.";
2328  return failure();
2329  }
2330  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
2331  return mlir::isa_and_nonnull<StringAttr>(attr);
2332  })) {
2333  emitError() << "All the elements in the `link` array must be strings.";
2334  return failure();
2335  }
2336  return success();
2337 }
2338 
2339 LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
2340  if (!getVerifyTarget())
2341  return success();
2342 
2343  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
2344  if (!gpuModuleOp) {
2345  return emitError(gpuModule->getLoc(),
2346  "NVVM target attribute must be attached to a GPU module");
2347  }
2348 
2349  const NVVMCheckSMVersion targetSMVersion =
2350  NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
2351  if (!targetSMVersion.isMinimumSMVersion()) {
2352  return emitError(gpuModule->getLoc(),
2353  "Minimum NVVM target SM version is sm_20");
2354  }
2355 
2356  WalkResult walkResult = gpuModuleOp->walk([&](Operation *op) {
2357  if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
2358  const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
2359  if (!requirement.isCompatibleWith(targetSMVersion)) {
2360  op->emitOpError() << "is not supported on " << getChip();
2361  return WalkResult::interrupt();
2362  }
2363  }
2364  return WalkResult::advance();
2365  });
2366 
2367  return failure(walkResult.wasInterrupted());
2368 }
2369 
2370 #define GET_OP_CLASSES
2371 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
2372 
2373 #define GET_ATTRDEF_CLASSES
2374 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"