1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/OperationSupport.h"
30 #include "mlir/IR/Types.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/FormatVariadic.h"
36 #include "llvm/Support/NVPTXAddrSpace.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cassert>
39 #include <optional>
40 #include <string>
41 
42 using namespace mlir;
43 using namespace NVVM;
44 
45 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
46 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
47 
48 static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
49 
50 //===----------------------------------------------------------------------===//
51 // Verifier methods
52 //===----------------------------------------------------------------------===//
53 
54 // This verifier is shared among the following Ops:
55 // CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
56 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
57 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
58  bool isIm2Col,
59  size_t numIm2ColOffsets,
60  Location loc) {
61  if (tensorDims < 1 || tensorDims > 5)
62  return emitError(loc, "expects coordinates between 1 to 5 dimension");
63 
64  // For Im2Col mode, there are two constraints:
65  if (isIm2Col) {
66  // 1. Tensor must always be at least 3-d.
67  if (tensorDims < 3)
68  return emitError(
69  loc,
70  "to use im2col mode, the tensor has to be at least 3-dimensional");
71  // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
72  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
73  return emitError(
74  loc, "im2col offsets must be 2 less than number of coordinates");
75  }
76  return success();
77 }
78 
79 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
80  TMAStoreMode mode = getMode();
81  // We lower through inline-ptx when getPredicate() is true.
82  // a) Only TILE mode is supported
83  // b) Cache-hint is not supported
84  if (getPredicate()) {
85  if (mode != TMAStoreMode::TILE)
86  return emitError("Inline-ptx lowering supported only for Tile mode.");
87  if (getL2CacheHint())
88  return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
89  }
90 
91  size_t dims = getCoordinates().size();
92  switch (mode) {
93  case TMAStoreMode::TILE:
94  return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
95  case TMAStoreMode::IM2COL:
96  return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
97  case TMAStoreMode::TILE_SCATTER4:
98  if (dims != 5)
99  return emitError("Scatter4 mode expects 5 coordinates");
100  }
101  return success();
102 }
103 
104 LogicalResult CpAsyncOp::verify() {
105  if (getModifier() != LoadCacheModifierKind::CG &&
106  getModifier() != LoadCacheModifierKind::CA)
107  return emitError("Only CG and CA cache modifiers are supported.");
108  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
109  return emitError("expected byte size to be either 4, 8 or 16.");
110  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
111  return emitError("CG cache modifier is only support for 16 bytes copy.");
112  return success();
113 }
114 
115 // These verification params can be shared across the TMA Load and Prefetch Ops.
116 static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
117  TMALoadMode mode, Location loc) {
118  if (tensorDims < 1 || tensorDims > 5)
119  return emitError(loc, "expects coordinates between 1 to 5 dimension");
120 
121  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
122  size_t expectedIm2colOff) -> LogicalResult {
123  if (isIm2col && (tensorDims < 3))
124  return emitError(loc)
125  << "to use " << stringifyEnum(mode)
126  << " mode, the tensor has to be at least 3-dimensional";
127 
128  if (numIm2colOff != expectedIm2colOff)
129  return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
130  << " (provided " << numIm2colOff << ")";
131 
132  return success();
133  };
134 
135  switch (mode) {
136  case TMALoadMode::TILE:
137  return checkTMALoadParams(mode, false, 0);
138  case TMALoadMode::IM2COL:
139  return checkTMALoadParams(mode, true, tensorDims - 2);
140  case TMALoadMode::IM2COL_W:
141  case TMALoadMode::IM2COL_W_128:
142  return checkTMALoadParams(mode, true, 2);
143  case TMALoadMode::TILE_GATHER4:
144  return (tensorDims == 5)
145  ? checkTMALoadParams(mode, false, 0)
146  : emitError(loc, "Gather4 mode expects 5 coordinates");
147  }
148  return success();
149 }
150 
151 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
152  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
153  getMode(), getLoc());
154 }
155 
156 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
157  TMALoadMode mode = getMode();
158  bool isCTAOnly = getIsCTAOnly();
159  if (getPredicate()) { // Inline-asm based lowering
160  if (isCTAOnly)
161  return emitError("Predicate is supported only for shared::cluster mode.");
162  if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
163  return emitError(
164  "Predicate is supported only for Tile and Im2col modes.");
165  } else { // Intrinsics-based lowering
166  NVVMMemorySpace expectedAS =
167  isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
168  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
169  .getAddressSpace();
170  if (AS != expectedAS)
171  return emitError()
172  << (isCTAOnly
173  ? "Shared::cta destination requires address-space 3."
174  : "Shared::cluster destination requires address-space 7.");
175  // Checks specific to shared::cta mode
176  if (isCTAOnly) {
177  if (getMulticastMask())
178  return emitError("Multicast is not supported with shared::cta mode.");
179  if (getGroup())
180  return emitError("CTAGroup is not supported with shared::cta mode.");
181  }
182  }
183 
184  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
185  getMode(), getLoc());
186 }
187 
188 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
189  TMAStoreMode mode = getMode();
190  size_t dims = getCoordinates().size();
191  switch (mode) {
192  case TMAStoreMode::TILE:
193  return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
194  case TMAStoreMode::IM2COL:
195  return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
196  case TMAStoreMode::TILE_SCATTER4:
197  return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
198  }
199  return success();
200 }
201 
202 LogicalResult ConvertFloatToTF32Op::verify() {
203  using RndMode = NVVM::FPRoundingMode;
204  switch (getRnd()) {
205  case RndMode::RNA:
206  if (getRelu())
207  return emitError("Relu not supported with rna rounding mode.");
208  break;
209  case RndMode::RN:
210  case RndMode::RZ:
211  break;
212  default:
213  return emitError(
214  "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
215  }
216  return success();
217 }
218 
219 LogicalResult ConvertF32x2ToF6x2Op::verify() {
220  mlir::MLIRContext *ctx = getContext();
221 
222  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
223  return emitOpError("Only ")
224  << mlir::Float6E2M3FNType::get(ctx) << " and "
225  << mlir::Float6E3M2FNType::get(ctx)
226  << " types are supported for conversions from f32x2 to f6x2.";
227  }
228  return success();
229 }
230 
231 LogicalResult ConvertF32x2ToF8x2Op::verify() {
232  using RndMode = NVVM::FPRoundingMode;
233  using SatMode = NVVM::SaturationMode;
234 
235  bool isRoundingModeRN = getRnd() == RndMode::RN;
236  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
237  bool isRoundingModeRP = getRnd() == RndMode::RP;
238  bool isSatFinite = getSat() == SatMode::SATFINITE;
239 
240  bool hasRelu = getRelu();
241 
242  mlir::MLIRContext *ctx = getContext();
243 
244  return llvm::TypeSwitch<Type, LogicalResult>(getDstTy())
245  .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
246  [&](mlir::Type) -> LogicalResult {
247  if (!isRoundingModeRN) {
248  return emitOpError("Only RN rounding mode is supported for "
249  "conversions from f32x2 to ")
250  << mlir::Float8E4M3FNType::get(ctx) << " and "
251  << mlir::Float8E5M2Type::get(ctx) << " types";
252  }
253  if (!isSatFinite) {
254  return emitOpError("Only SATFINITE saturation mode is supported "
255  "for conversions "
256  "from f32x2 to ")
257  << mlir::Float8E4M3FNType::get(ctx) << " and "
258  << mlir::Float8E5M2Type::get(ctx) << " types";
259  }
260  return success();
261  })
262  .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
263  if (!(isRoundingModeRZ || isRoundingModeRP)) {
264  return emitOpError("Only RZ and RP rounding modes are supported for "
265  "conversions from f32x2 to ")
266  << mlir::Float8E8M0FNUType::get(ctx) << " type";
267  }
268  if (hasRelu) {
269  return emitOpError("relu not supported for conversions to ")
270  << mlir::Float8E8M0FNUType::get(ctx) << " type";
271  }
272  return success();
273  })
274  .Default([&](mlir::Type) {
275  return emitOpError("Only ")
276  << mlir::Float8E4M3FNType::get(ctx) << ", "
277  << mlir::Float8E5M2Type::get(ctx) << ", and "
278  << mlir::Float8E8M0FNUType::get(ctx)
279  << " types are "
280  "supported for conversions from f32x2 to f8x2";
281  });
282 }
283 
284 LogicalResult ConvertF16x2ToF8x2Op::verify() {
285  mlir::MLIRContext *ctx = getContext();
286 
287  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
288  return emitOpError("Only ")
289  << mlir::Float8E4M3FNType::get(ctx) << " and "
290  << mlir::Float8E5M2Type::get(ctx)
291  << " types are supported for conversions from f16x2 to f8x2.";
292  }
293  return success();
294 }
295 
296 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
297  using RndMode = NVVM::FPRoundingMode;
298 
299  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
300  return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
301  << " type is supported for conversions from "
302  "bf16x2 to f8x2.";
303 
304  auto rnd = getRnd();
305  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
306  return emitOpError("Only RZ and RP rounding modes are supported for "
307  "conversions from bf16x2 to f8x2.");
308 
309  return success();
310 }
311 
312 LogicalResult ConvertF32x2ToF4x2Op::verify() {
313  mlir::MLIRContext *ctx = getContext();
314 
315  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
316  return emitOpError("Only ")
317  << mlir::Float4E2M1FNType::get(ctx)
318  << " type is supported for conversions from f32x2 to f4x2.";
319 
320  return success();
321 }
322 
323 LogicalResult ConvertF8x2ToF16x2Op::verify() {
324  mlir::MLIRContext *ctx = getContext();
325 
326  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
327  return emitOpError("Only ")
328  << mlir::Float8E4M3FNType::get(ctx) << " and "
329  << mlir::Float8E5M2Type::get(ctx)
330  << " types are supported for conversions from f8x2 to f16x2.";
331 
332  return success();
333 }
334 
335 LogicalResult ConvertF8x2ToBF16x2Op::verify() {
336  mlir::MLIRContext *ctx = getContext();
337  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
338  return emitOpError("Only ")
339  << mlir::Float8E8M0FNUType::get(ctx)
340  << " type is supported for conversions from f8x2 to bf16x2.";
341 
342  return success();
343 }
344 
345 LogicalResult ConvertF6x2ToF16x2Op::verify() {
346  mlir::MLIRContext *ctx = getContext();
347 
348  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
349  return emitOpError("Only ")
350  << mlir::Float6E2M3FNType::get(ctx) << " and "
351  << mlir::Float6E3M2FNType::get(ctx)
352  << " types are supported for conversions from f6x2 to f16x2.";
353 
354  return success();
355 }
356 
357 LogicalResult ConvertF4x2ToF16x2Op::verify() {
358  mlir::MLIRContext *ctx = getContext();
359 
360  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
361  return emitOpError("Only ")
362  << mlir::Float4E2M1FNType::get(ctx)
363  << " type is supported for conversions from f4x2 to f16x2.";
364 
365  return success();
366 }
367 
368 LogicalResult BulkStoreOp::verify() {
369  if (getInitVal() != 0)
370  return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
371  return success();
372 }
373 
374 LogicalResult PMEventOp::verify() {
375  auto eventId = getEventId();
376  auto maskedEventId = getMaskedEventId();
377  if (!maskedEventId && !eventId) {
378  return emitOpError() << "either `id` or `mask` must be set";
379  }
380 
381  if (maskedEventId && eventId) {
382  return emitOpError() << "`id` and `mask` cannot be set at the same time";
383  }
384 
385  if (eventId) {
386  if (eventId < 0 || eventId > 15) {
387  return emitOpError() << "`id` must be between 0 and 15";
388  }
389  }
390 
391  return llvm::success();
392 }
393 
394 // Given the element type of an operand and whether or not it is an accumulator,
395 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
396 // operand's element type.
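// For example (illustrative, based on the mapping implemented below):
// vector<2xf16> and f16 map to MMATypes::f16; f32 maps to MMATypes::f32 when
// used as an accumulator and to MMATypes::tf32 otherwise; integer element
// types map to MMATypes::s32 only in the accumulator position.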
397 std::optional<mlir::NVVM::MMATypes>
398 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
399  auto half2Type =
400  VectorType::get(2, Float16Type::get(operandElType.getContext()));
401  if (operandElType.isF64())
402  return NVVM::MMATypes::f64;
403  if (operandElType.isF16() || operandElType == half2Type)
404  return NVVM::MMATypes::f16;
405  if (operandElType.isF32() && isAccumulator)
406  return NVVM::MMATypes::f32;
407  if (operandElType.isF32() && !isAccumulator)
408  return NVVM::MMATypes::tf32;
409  if (llvm::isa<IntegerType>(operandElType)) {
410  if (isAccumulator)
411  return NVVM::MMATypes::s32;
412  return std::nullopt;
413  }
414 
415  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
416  if (structType.getBody().empty())
417  return std::nullopt;
418  return inferOperandMMAType(structType.getBody()[0], isAccumulator);
419  }
420 
421  return std::nullopt;
422 }
423 
424 static bool isInt4PtxType(MMATypes type) {
425  return (type == MMATypes::u4 || type == MMATypes::s4);
426 }
427 
428 static bool isInt8PtxType(MMATypes type) {
429  return (type == MMATypes::u8 || type == MMATypes::s8);
430 }
431 
432 static bool isIntegerPtxType(MMATypes type) {
433  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
434  type == MMATypes::s32;
435 }
436 
437 MMATypes MmaOp::accumPtxType() {
438  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
439  getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
440  assert(val.has_value() && "accumulator PTX type should always be inferrable");
441  return val.value();
442 }
443 
444 MMATypes MmaOp::resultPtxType() {
445  std::optional<mlir::NVVM::MMATypes> val =
446  inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
447  assert(val.has_value() && "result PTX type should always be inferrable");
448  return val.value();
449 }
450 
451 void MmaOp::print(OpAsmPrinter &p) {
452  SmallVector<Type, 4> regTypes;
453  struct OperandFragment {
454  StringRef operandName;
455  StringRef ptxTypeAttr;
456  SmallVector<Value, 4> regs;
457  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
458  : operandName(name), ptxTypeAttr(ptxTypeName) {}
459  };
460 
461  std::array<OperandFragment, 3> frags{
462  OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
463  OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
464  OperandFragment("C", "")};
465  SmallVector<StringRef, 4> ignoreAttrNames{
466  mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
467 
468  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
469  auto &frag = frags[fragIdx];
470  auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
471  for (auto operandIdx = varOperandSpec.first;
472  operandIdx < varOperandSpec.first + varOperandSpec.second;
473  operandIdx++) {
474  frag.regs.push_back(this->getOperand(operandIdx));
475  if (operandIdx == 0) {
476  regTypes.push_back(this->getOperand(operandIdx).getType());
477  }
478  }
479  std::optional<MMATypes> inferredType =
480  inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
481  if (inferredType)
482  ignoreAttrNames.push_back(frag.ptxTypeAttr);
483  }
484 
485  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
486  p << " " << frag.operandName;
487  p << "[";
488  p.printOperands(frag.regs);
489  p << "] ";
490  };
491 
492  for (const auto &frag : frags) {
493  printMmaOperand(frag);
494  }
495 
496  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
497 
498  // Print the types of the operands and result.
499  p << " : "
500  << "(";
501  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
502  frags[1].regs[0].getType(),
503  frags[2].regs[0].getType()},
504  p);
505  p << ")";
506  p.printArrowTypeList(TypeRange{this->getRes().getType()});
507 }
508 
509 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
510  ValueRange operandA, ValueRange operandB, ValueRange operandC,
511  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
512  std::optional<MMAIntOverflow> intOverflow,
513  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
514  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
515 
516  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
517  MLIRContext *ctx = builder.getContext();
518  result.addAttribute(
519  "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
520 
521  result.addOperands(operandA);
522  result.addOperands(operandB);
523  result.addOperands(operandC);
524 
525  if (multiplicandPtxTypes) {
526  result.addAttribute("multiplicandAPtxType",
527  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
528  result.addAttribute("multiplicandBPtxType",
529  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
530  } else {
531  if (auto res = inferOperandMMAType(operandA[0].getType(), false))
532  result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
533  if (auto res = inferOperandMMAType(operandB[0].getType(), false))
534  result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
535  }
536 
537  if (multiplicandLayouts) {
538  result.addAttribute("layoutA",
539  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
540  result.addAttribute("layoutB",
541  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
542  } else {
543  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
544  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
545  }
546 
547  if (intOverflow.has_value())
548  result.addAttribute("intOverflowBehavior",
549  MMAIntOverflowAttr::get(ctx, *intOverflow));
550  if (b1Op.has_value())
551  result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
552 
553  result.addTypes(resultType);
554  result.addAttribute(
555  MmaOp::getOperandSegmentSizeAttr(),
556  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
557  static_cast<int32_t>(operandB.size()),
558  static_cast<int32_t>(operandC.size())}));
559 }
560 
561 // <operation> :=
562 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
563 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
564 // `->` type($res)
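// For example, an m16n8k16 f16 MMA in this custom form looks roughly like the
// following (illustrative only; attribute syntax abbreviated):
//   %d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 16>}
//        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//           -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>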
565 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
566  struct OperandFragment {
567  std::optional<MMATypes> elemtype;
568  SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
569  SmallVector<Type> regTypes;
570  };
571 
572  Builder &builder = parser.getBuilder();
573  std::array<OperandFragment, 4> frags;
574 
575  NamedAttrList namedAttributes;
576 
577  // A helper to parse the operand segments.
578  auto parseMmaOperand = [&](StringRef operandName,
579  OperandFragment &frag) -> LogicalResult {
580  if (parser.parseKeyword(operandName).failed())
581  return failure();
582  if (parser
583  .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
584  .failed())
585  return failure();
586  return success();
587  };
588 
589  // Parse the operand segments.
590  if (parseMmaOperand("A", frags[0]).failed())
591  return failure();
592  if (parseMmaOperand("B", frags[1]).failed())
593  return failure();
594  if (parseMmaOperand("C", frags[2]).failed())
595  return failure();
596 
597  if (parser.parseOptionalAttrDict(namedAttributes).failed())
598  return failure();
599 
600  // Parse the type specification and resolve operands.
601  SmallVector<Type, 3> operandTypes;
602  if (failed(parser.parseColon()))
603  return failure();
604  if (failed(parser.parseLParen()))
605  return failure();
606  if (failed(parser.parseTypeList(operandTypes)))
607  return failure();
608  if (failed(parser.parseRParen()))
 return failure();
609  if (operandTypes.size() != 3)
610  return parser.emitError(
611  parser.getNameLoc(),
612  "expected one type for each operand segment but got " +
613  Twine(operandTypes.size()) + " types");
614  for (const auto &iter : llvm::enumerate(operandTypes)) {
615  auto &frag = frags[iter.index()];
616  frag.regTypes.resize(frag.regs.size(), iter.value());
617  if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
618  parser.getNameLoc(), result.operands)))
619  return failure();
620  frag.elemtype = inferOperandMMAType(frag.regTypes[0],
621  /*isAccumulator*/ iter.index() < 2);
622  }
623 
624  Type resultType;
625  if (parser.parseArrow() || parser.parseType(resultType))
626  return failure();
627  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
628 
629  std::array<StringRef, 2> names{"multiplicandAPtxType",
630  "multiplicandBPtxType"};
631  for (unsigned idx = 0; idx < names.size(); idx++) {
632  const auto &frag = frags[idx];
633  std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
634  if (!frag.elemtype.has_value() && !attr.has_value()) {
635  return parser.emitError(
636  parser.getNameLoc(),
637  "attribute " + names[idx] +
638  " is not provided explicitly and cannot be inferred");
639  }
640  if (!attr.has_value())
641  result.addAttribute(
642  names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
643  }
644 
645  result.addTypes(resultType);
646  if (!namedAttributes.empty())
647  result.addAttributes(namedAttributes);
648  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
649  builder.getDenseI32ArrayAttr({
650  static_cast<int32_t>(frags[0].regs.size()),
651  static_cast<int32_t>(frags[1].regs.size()),
652  static_cast<int32_t>(frags[2].regs.size()),
653  }));
654  return success();
655 }
656 
657 LogicalResult MmaOp::verify() {
658  MLIRContext *context = getContext();
659  auto f16Ty = Float16Type::get(context);
660  auto i32Ty = IntegerType::get(context, 32);
661  auto f16x2Ty = VectorType::get(2, f16Ty);
662  auto f32Ty = Float32Type::get(context);
663  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
664  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
665 
666  auto s32x4StructTy =
667  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
668  auto f32x8StructTy =
669  LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
670  auto f16x2x2StructTy =
671  LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
672  auto f32x4StructTy =
673  LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
674  auto s32x2StructTy =
675  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
676 
677  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
678  getShapeAttr().getK()};
679 
680  // These variables define the set of allowed data types for matrices A, B, C,
681  // and result.
682  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
683  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
684  AllowedShapes allowedShapes;
685  AllowedTypes expectedA;
686  AllowedTypes expectedB;
687  AllowedTypes expectedC;
688  SmallVector<Type> expectedResult;
689 
690  // When M = 16, we just need to calculate the number of 8xk tiles, where
691  // k is a factor that depends on the data type.
692  if (mmaShape[0] == 16) {
693  int64_t kFactor;
694  Type multiplicandFragType;
695  switch (*getMultiplicandAPtxType()) {
696  case MMATypes::tf32:
697  kFactor = 4;
698  multiplicandFragType = i32Ty;
699  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
700  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
701  break;
702  case MMATypes::bf16:
703  kFactor = 8;
704  multiplicandFragType = i32Ty;
705  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
706  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
707  break;
708  case MMATypes::f16:
709  kFactor = 8;
710  multiplicandFragType = f16x2Ty;
711  expectedResult.push_back(f16x2x2StructTy);
712  expectedResult.push_back(f32x4StructTy);
713  break;
714  case MMATypes::s4:
715  case MMATypes::u4:
716  kFactor = 32;
717  break;
718  case MMATypes::b1:
719  kFactor = 128;
720  break;
721  case MMATypes::s8:
722  case MMATypes::u8:
723  kFactor = 16;
724  break;
725  default:
726  return emitError("invalid shape or multiplicand type: " +
727  stringifyEnum(getMultiplicandAPtxType().value()));
728  }
729 
730  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
731  expectedResult.push_back(s32x4StructTy);
732  expectedC.emplace_back(4, i32Ty);
733  multiplicandFragType = i32Ty;
734  } else {
735  expectedC.emplace_back(2, f16x2Ty);
736  expectedC.emplace_back(4, f32Ty);
737  }
738 
739  int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
740  int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
741  expectedA.emplace_back(unitA, multiplicandFragType);
742  expectedB.emplace_back(unitB, multiplicandFragType);
743  allowedShapes.push_back({16, 8, kFactor});
744  allowedShapes.push_back({16, 8, kFactor * 2});
745 
746  if (resultPtxType() != accumPtxType())
747  return emitOpError("ctype does not match dtype");
748  }
749 
750  // In the M=8 case, there is only 1 possible case per data type.
751  if (mmaShape[0] == 8) {
752  if (*getMultiplicandAPtxType() == MMATypes::f16) {
753  expectedA.emplace_back(2, f16x2Ty);
754  expectedB.emplace_back(2, f16x2Ty);
755  expectedResult.push_back(f16x2x4StructTy);
756  expectedResult.push_back(f32x8StructTy);
757  expectedC.emplace_back(4, f16x2Ty);
758  expectedC.emplace_back(8, f32Ty);
759  allowedShapes.push_back({8, 8, 4});
760  }
761  if (*getMultiplicandAPtxType() == MMATypes::f64) {
762  Type f64Ty = Float64Type::get(context);
763  expectedA.emplace_back(1, f64Ty);
764  expectedB.emplace_back(1, f64Ty);
765  expectedC.emplace_back(2, f64Ty);
766  expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
767  context, SmallVector<Type>(2, f64Ty)));
768  allowedShapes.push_back({8, 8, 4});
769  }
770  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
771  expectedA.push_back({i32Ty});
772  expectedB.push_back({i32Ty});
773  expectedC.push_back({i32Ty, i32Ty});
774  expectedResult.push_back(s32x2StructTy);
775  if (isInt4PtxType(getMultiplicandAPtxType().value()))
776  allowedShapes.push_back({8, 8, 32});
777  if (isInt8PtxType(getMultiplicandAPtxType().value()))
778  allowedShapes.push_back({8, 8, 16});
779  if (getMultiplicandAPtxType().value() == MMATypes::b1)
780  allowedShapes.push_back({8, 8, 128});
781  }
782  }
783 
784  std::string errorMessage;
785  llvm::raw_string_ostream errorStream(errorMessage);
786 
787  // Check that we matched an existing shape/dtype combination.
788  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
789  !llvm::is_contained(allowedShapes, mmaShape)) {
790  errorStream << "unimplemented variant for MMA shape <";
791  llvm::interleaveComma(mmaShape, errorStream);
792  errorStream << ">";
793  return emitOpError(errorMessage);
794  }
795 
796  // Verify the operand types for segments of A, B, and C operands.
797  std::array<StringRef, 3> operandNames{"A", "B", "C"};
798  for (const auto &iter : llvm::enumerate(
799  SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
800  auto spec = this->getODSOperandIndexAndLength(iter.index());
801  SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
802  operand_type_begin() + spec.first +
803  spec.second);
804  bool match = llvm::is_contained(iter.value(), operandTySeg);
805 
806  if (!match) {
807  errorStream << "Could not match types for the "
808  << operandNames[iter.index()]
809  << " operands; expected one of ";
810  for (const auto &x : iter.value()) {
811  errorStream << x.size() << "x" << x[0] << " ";
812  }
813  errorStream << "but got ";
814  llvm::interleaveComma(operandTySeg, errorStream);
815  return emitOpError(errorMessage);
816  }
817  }
818 
819  // Check the result type
820  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
821  return expectedResultType == getResult().getType();
822  })) {
823  errorStream
824  << "Could not match allowed types for the result; expected one of ";
825  llvm::interleaveComma(expectedResult, errorStream);
826  errorStream << " but got " << getResult().getType();
827  return emitOpError(errorMessage);
828  }
829 
830  // Ensure that binary MMA variants have a b1 MMA operation defined.
831  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
832  return emitOpError("op requires " + getB1OpAttrName().strref() +
833  " attribute");
834  }
835 
836  // Ensure int4/int8 MMA variants specify the accum overflow behavior
837  // attribute.
838  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
839  isInt8PtxType(*getMultiplicandAPtxType())) {
840  if (!getIntOverflowBehavior())
841  return emitOpError("op requires " +
842  getIntOverflowBehaviorAttrName().strref() +
843  " attribute");
844  }
845 
846  // Validate layout combinations. According to the operation description, most
847  // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
848  // can use other layout combinations.
849  bool isM8N8K4_F16 =
850  (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
851  getMultiplicandAPtxType() == MMATypes::f16);
852 
853  if (!isM8N8K4_F16) {
854  // For all other shapes/types, layoutA must be row and layoutB must be col
855  if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
856  return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
857  "layoutB = #nvvm.mma_layout<col> for shape <")
858  << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
859  << "> with element types "
860  << stringifyEnum(*getMultiplicandAPtxType()) << " and "
861  << stringifyEnum(*getMultiplicandBPtxType())
862  << ". Only m8n8k4 with f16 supports other layouts.";
863  }
864  }
865 
866  return success();
867 }
868 
869 LogicalResult ShflOp::verify() {
870  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
871  return success();
872  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
873  auto elementType = (type && type.getBody().size() == 2)
874  ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
875  : nullptr;
876  if (!elementType || elementType.getWidth() != 1)
877  return emitError("expected return type to be a two-element struct with "
878  "i1 as the second element");
879  return success();
880 }
881 
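// For example, with type == f16 this returns (vector<2xf16>, 8) for the a/b
// fragments and (vector<2xf16>, 4) for c/d, matching the register counts the
// WMMA verifiers below expect.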
882 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
883  NVVM::MMAFrag frag, int nRow,
884  int nCol,
885  MLIRContext *context) {
886  unsigned numberElements = 0;
887  Type elementType;
888  OpBuilder builder(context);
889  Type f16x2 = VectorType::get(2, builder.getF16Type());
890  if (type == NVVM::MMATypes::f16) {
891  elementType = f16x2;
892  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
893  numberElements = 8;
894  else
895  numberElements = 4;
896  } else if (type == NVVM::MMATypes::f32) {
897  elementType = builder.getF32Type();
898  numberElements = 8;
899  } else if (type == NVVM::MMATypes::f64) {
900  elementType = builder.getF64Type();
901  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
902  numberElements = 1;
903  else
904  numberElements = 2;
905  } else if (type == NVVM::MMATypes::tf32) {
906  elementType = builder.getI32Type();
907  numberElements = 4;
908  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
909  elementType = builder.getI32Type();
910  int parallelSize = 0;
911  if (frag == NVVM::MMAFrag::a)
912  parallelSize = nRow;
913  if (frag == NVVM::MMAFrag::b)
914  parallelSize = nCol;
915 
916  // m == 16 && n == 16 && k == 16
917  if (parallelSize == 16)
918  numberElements = 2;
919  // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
920  else if (parallelSize == 8)
921  numberElements = 1;
922  else if (parallelSize == 32)
923  numberElements = 4;
924  } else if (type == NVVM::MMATypes::s32) {
925  elementType = builder.getI32Type();
926  numberElements = 8;
927  }
928  assert(numberElements != 0 && elementType != nullptr);
929  return std::make_pair(elementType, numberElements);
930 }
931 
932 static std::pair<mlir::Type, unsigned>
933 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
934  int k, MLIRContext *context) {
935  int nRow, nCol;
936  if (frag == NVVM::MMAFrag::a) {
937  nRow = m;
938  nCol = k;
939  } else if (frag == NVVM::MMAFrag::b) {
940  nRow = k;
941  nCol = n;
942  } else {
943  nRow = m;
944  nCol = n;
945  }
946  assert(nRow && nCol);
947  return inferMMAType(type, frag, nRow, nCol, context);
948 }
949 
950 LogicalResult NVVM::WMMALoadOp::verify() {
951  unsigned addressSpace =
952  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
953  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
954  addressSpace != NVVMMemorySpace::Shared)
955  return emitOpError("expected source pointer in memory "
956  "space 0, 1, 3");
957 
958  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
959  getEltype(), getFrag()) == 0)
960  return emitOpError() << "invalid attribute combination";
961  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
962  getEltype(), getFrag(), getM(), getN(), getK(), getContext());
963  // Special case for f64 fragments
964  Type f64Ty = Float64Type::get(getContext());
965  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
966  if (getType() != f64Ty)
967  return emitOpError("expected destination type to be f64");
968  return success();
969  }
970  // Everything else is a struct
971  Type dstType = LLVM::LLVMStructType::getLiteral(
972  getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
973  if (getType() != dstType)
974  return emitOpError("expected destination type is a structure of ")
975  << typeInfo.second << " elements of type " << typeInfo.first;
976  return success();
977 }
978 
979 LogicalResult NVVM::WMMAStoreOp::verify() {
980  unsigned addressSpace =
981  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
982  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
983  addressSpace != NVVMMemorySpace::Shared)
984  return emitOpError("expected operands to be a source pointer in memory "
985  "space 0, 1, 3");
986 
987  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
988  getEltype()) == 0)
989  return emitOpError() << "invalid attribute combination";
990  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
991  getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
992  if (getArgs().size() != typeInfo.second)
993  return emitOpError() << "expected " << typeInfo.second << " data operands";
994  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
995  return operands.getType() != typeInfo.first;
996  }))
997  return emitOpError() << "expected data operands of type " << typeInfo.first;
998  return success();
999 }
1000 
1001 LogicalResult NVVM::WMMAMmaOp::verify() {
1002  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
1003  getLayoutB(), getEltypeA(),
1004  getEltypeB()) == 0)
1005  return emitOpError() << "invalid attribute combination";
1006  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
1007  getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
1008  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
1009  getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
1010  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
1011  getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
1012  SmallVector<Type, 32> arguments;
1013  arguments.append(typeInfoA.second, typeInfoA.first);
1014  arguments.append(typeInfoB.second, typeInfoB.first);
1015  arguments.append(typeInfoC.second, typeInfoC.first);
1016  unsigned numArgs = arguments.size();
1017  if (getArgs().size() != numArgs)
1018  return emitOpError() << "expected " << numArgs << " arguments";
1019  for (unsigned i = 0; i < numArgs; i++) {
1020  if (getArgs()[i].getType() != arguments[i])
1021  return emitOpError() << "expected argument " << i << " to be of type "
1022  << arguments[i];
1023  }
1024  Type dstType = LLVM::LLVMStructType::getLiteral(
1025  getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
1026  if (getType() != dstType)
1027  return emitOpError("expected destination type is a structure of ")
1028  << typeInfoC.second << " elements of type " << typeInfoC.first;
1029  return success();
1030 }
1031 
1032 LogicalResult NVVM::LdMatrixOp::verify() {
1033  uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
1034  if (m == 8 && n == 8) {
1035  if (num != 1 && num != 2 && num != 4) {
1036  return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
1037  "matrix");
1038  }
1039  if (getEltType() != LdStMatrixEltType::B16) {
1040  return emitOpError("expected element type to be b16 for 8x8 matrix");
1041  }
1042  } else if (m == 8 && n == 16) {
1043  if (num != 1 && num != 2 && num != 4) {
1044  return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
1045  "matrix");
1046  }
1047  if (getLayout() != MMALayout::row) {
1048  return emitOpError("expected layout to be row for 8x16 matrix");
1049  }
1050  if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
1051  getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
1052  return emitOpError("expected element type to be b8x16.b4x16_p64 or "
1053  "b8x16.b6x16_p32 for 8x16 matrix");
1054  }
1055  } else if (m == 16 && n == 16) {
1056  if (num != 1 && num != 2) {
1057  return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
1058  "matrix");
1059  }
1060  if (getLayout() != MMALayout::col) {
1061  return emitOpError("expected layout to be col for 16x16 matrix");
1062  }
1063  if (getEltType() != LdStMatrixEltType::B8 &&
1064  getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
1065  getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
1066  return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
1067  "b8x16.b6x16_p32 for 16x16 matrix");
1068  }
1069  } else {
1070  return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
1071  }
1072 
1073  Type i32 = IntegerType::get(getContext(), 32);
1074  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
1075  if (numElements == 1 && getType() != i32)
1076  return emitOpError("expected destination type is i32");
1077  if (numElements == 2 || numElements == 4) {
1078  Type dstType = LLVM::LLVMStructType::getLiteral(
1079  getContext(), SmallVector<Type>(numElements, i32));
1080  if (getType() != dstType)
1081  return emitOpError("expected destination type is a structure of ")
1082  << numElements << " elements of type i32";
1083  }
1084 
1085  return success();
1086 }
1087 
1088 LogicalResult NVVM::StMatrixOp::verify() {
1089  int numMatrix = getSources().size();
1090  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
1091  return emitOpError("expected num attribute to be 1, 2 or 4");
1092 
1093  int m = getShape().getM(), n = getShape().getN();
1094  if (m == 8 && n == 8) {
1095  if (getEltType() != NVVM::LdStMatrixEltType::B16) {
1096  return emitOpError("expected element type to be B16 for 8x8 matrix");
1097  }
1098  } else if (m == 16 && n == 8) {
1099  if (getEltType() != NVVM::LdStMatrixEltType::B8) {
1100  return emitOpError("expected element type to be B8 for 16x8 matrix");
1101  }
1102  if (getLayout() != NVVM::MMALayout::col) {
1103  return emitOpError("expected layout to be col for 16x8 matrix");
1104  }
1105  } else {
1106  return emitOpError("expected shape to be 8x8 or 16x8");
1107  }
1108 
1109  return success();
1110 }
1111 
1112 static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
1113  if (typeA == NVVM::WGMMATypes::tf32)
1114  return 8;
1115  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
1116  return 16;
1117  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
1118  return 32;
1119  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
1120  return 32;
1121  if (typeA == NVVM::WGMMATypes::b1)
1122  return 256;
1123  return failure();
1124 }
1125 
1126 static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
1127  NVVM::WGMMATypes typeA,
1128  NVVM::WGMMATypes typeB) {
1129  switch (typeA) {
1130  case NVVM::WGMMATypes::f16:
1131  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1132  typeB == NVVM::WGMMATypes::f16)
1133  return success();
1134  break;
1135  case NVVM::WGMMATypes::tf32:
1136  if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
1137  return success();
1138  break;
1139  case NVVM::WGMMATypes::u8:
1140  case NVVM::WGMMATypes::s8:
1141  if (typeD == NVVM::WGMMATypes::s32 &&
1142  (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
1143  return success();
1144  break;
1145  case NVVM::WGMMATypes::b1:
1146  if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
1147  return success();
1148  break;
1149  case NVVM::WGMMATypes::bf16:
1150  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1151  typeB == NVVM::WGMMATypes::bf16)
1152  return success();
1153  break;
1154  case NVVM::WGMMATypes::e4m3:
1155  case NVVM::WGMMATypes::e5m2:
1156  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1157  (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
1158  return success();
1159  break;
1160  case WGMMATypes::f32:
1161  case WGMMATypes::s32:
1162  llvm_unreachable("unsupported input types");
1163  break;
1164  }
1165  return failure();
1166 }
1167 
1168 static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
1169  SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
1170  72, 80, 88, 96, 104, 112, 120, 128,
1171  136, 144, 152, 160, 168, 176, 184, 192,
1172  200, 208, 216, 224, 232, 240, 248, 256};
1173  SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
1174  80, 96, 112, 128, 144, 160,
1175  176, 192, 208, 224, 240, 256};
1176  switch (typeA) {
1177  case WGMMATypes::f16:
1178  case WGMMATypes::tf32:
1179  case WGMMATypes::bf16:
1180  case WGMMATypes::e4m3:
1181  case WGMMATypes::e5m2:
1182  if (llvm::is_contained(allowedN, sizeN))
1183  return success();
1184  break;
1185  case WGMMATypes::u8:
1186  case WGMMATypes::s8:
1187  case WGMMATypes::b1:
1188  if (llvm::is_contained(allowedNshort, sizeN))
1189  return success();
1190  break;
1191  case WGMMATypes::f32:
1192  case WGMMATypes::s32:
1193  llvm_unreachable("unsupported input types");
1194  break;
1195  }
1196  return failure();
1197 }
1198 
1199 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
1200  Value outValue = getResults();
1201  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
1202  if (!stype)
1203  return emitOpError() << "expected results to be struct";
1204  int outputSize = stype.getBody().size();
1205  WGMMATypes typeD = getTypeD();
1206  WGMMATypes typeA = getTypeA();
1207  WGMMATypes typeB = getTypeB();
1208 
1209  for (Type t : stype.getBody()) {
1210  if (t != stype.getBody().front())
1211  return emitOpError()
1212  << "all elements in struct must be same type but there is " << t;
1213  }
1214 
1215  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
1216  typeD != WGMMATypes::s32) {
1217  return emitOpError() << "does not support the given output type "
1218  << NVVM::stringifyWGMMATypes(typeD);
1219  }
1220  if (typeD == WGMMATypes::s32 &&
1221  (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
1222  return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
1223  }
1224 
1225  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
1226  return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
1227  << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
1228  << NVVM::stringifyWGMMATypes(typeB)
1229  << ", it is not supported.";
1230  }
1231 
1232  // Check M
1233  if (getShape().getM() != 64)
1234  return emitOpError() << "shape 'm' must be 64";
1235 
1236  // Check K
1237  FailureOr<int> allowedK = getAllowedSizeK(typeA);
1238  if (failed(allowedK) || allowedK.value() != getShape().getK())
1239  return emitOpError() << "shape 'k' must be " << allowedK.value()
1240  << " for input type "
1241  << NVVM::stringifyWGMMATypes(typeA);
1242 
1243  // Check N
1244  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
1245  return emitOpError() << "has input type "
1246  << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
1247  << getShape().getN() << ", it is not supported.";
1248  }
1249 
1250  // Check transpose (only available for f16/bf16)
1251  // Matrices A should be stored in row-major and B in column-major.
1252  // Only f16/bf16 matrices can be stored in either column-major or row-major
1253  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
1254  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
1255  (getLayoutA() == mlir::NVVM::MMALayout::col ||
1256  getLayoutB() == mlir::NVVM::MMALayout::row)) {
1257  return emitOpError()
1258  << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
1259  << " and layout_b = " << stringifyMMALayout(getLayoutB())
1260  << " for input types " << stringifyWGMMATypes(typeA) << " and "
1261  << stringifyWGMMATypes(typeB)
1262  << " requires transpose. However, this is only supported for: "
1263  << stringifyMMATypes(MMATypes::f16) << " and "
1264  << stringifyMMATypes(MMATypes::bf16);
1265  }
1266 
1267  // Check result registers
1268  int expectedOutput = 0;
1269  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
1270  expectedOutput = getShape().getN() / 2;
1271  if (typeD == WGMMATypes::f16)
1272  expectedOutput = getShape().getN() / 4;
1273  if (outputSize != expectedOutput) {
1274  return emitOpError() << "results " << expectedOutput
1275  << ", however output struct has " << outputSize
1276  << " elements";
1277  }
1278  // Check satfinite (only available for s32 accumulator)
1279  if (typeD != WGMMATypes::s32 &&
1280  getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1281  NVVM::MMAIntOverflow::satfinite) {
1282  return emitOpError()
1283  << " `satfinite` can be only used with s32 accumulator, however "
1284  "the current accumulator is "
1285  << NVVM::stringifyWGMMATypes(typeD);
1286  }
1287 
1288  return success();
1289 }
1290 
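// For reference, with typeD = f32, typeA = typeB = f16 and shape m64n8k16, the
// string assembled below is roughly (assuming the operand numbering produced
// by getAsmValues(), with the result and inout structs each expanded to four
// registers: $0-$3 outputs, $4-$7 inouts, $8/$9 descriptors, $10 scale-d,
// $11/$12 scale-a/scale-b, $13/$14 layout immediates):
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $10, 0;
//   wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16
//     {$0, $1, $2, $3}, $8, $9, p, $11, $12, $13, $14;
//   }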
1291 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1292 
1293  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1294  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1295 
1296  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1297 
1298  int expectedOutputRegisters = 0;
1299  if (getTypeD() == WGMMATypes::f16)
1300  expectedOutputRegisters = getShape().getN() / 4;
1301  else
1302  expectedOutputRegisters = getShape().getN() / 2;
1303 
1304  std::string ptx;
1305  llvm::raw_string_ostream ss(ptx);
1306 
1307  ss << "{\n"
1308  ".reg .pred p;\n"
1309  "setp.ne.b32 p, $"
1310  << ((expectedOutputRegisters * 2) + 2)
1311  << ", 0;\n"
1312  "wgmma.mma_async.sync.aligned.m"
1313  << m << "n" << n << "k" << k << "." << outputTypeName << "."
1314  << stringifyWGMMATypes(getTypeA()) << "."
1315  << stringifyWGMMATypes(getTypeB());
1316  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1317  NVVM::MMAIntOverflow::satfinite)
1318  ss << ".satfinite";
1319  ss << " {";
1320  int regCnt = 0;
1321  for (; regCnt < expectedOutputRegisters; ++regCnt) {
1322  ss << "$" << regCnt;
1323  if (regCnt != expectedOutputRegisters - 1)
1324  ss << ", ";
1325  }
1326 
1327  ss << "},";
1328  // Need to map read/write registers correctly.
1329  regCnt = (regCnt * 2);
1330  ss << " $" << (regCnt) << ","
1331  << " $" << (regCnt + 1) << ","
1332  << " p";
1333  if (getTypeD() != WGMMATypes::s32) {
1334  ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1335  }
1336  // Don't add transpose parameters unless needed.
1337  if (isF16) {
1338  ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1339  }
1340  ss << ";\n"
1341  << "}\n";
1342  return ptx;
1343 }
1344 
1345 bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
1346  RewriterBase &rewriter,
1347  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1348  &asmValues) {
1349  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1350  if (getResults())
1351  asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1352  if (getInouts())
1353  asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1354  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1355  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1356  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1357  mlir::NVVM::PTXRegisterMod::Read});
1358  if (getTypeD() != WGMMATypes::s32) {
1359  asmValues.push_back(
1360  {makeConstantI32(rewriter,
1361  getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1362  mlir::NVVM::PTXRegisterMod::Read});
1363  asmValues.push_back(
1364  {makeConstantI32(rewriter,
1365  getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1366  mlir::NVVM::PTXRegisterMod::Read});
1367  }
1368  if (isF16) {
1369  asmValues.push_back(
1370  {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1371  mlir::NVVM::PTXRegisterMod::Read});
1372  asmValues.push_back(
1373  {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1374  mlir::NVVM::PTXRegisterMod::Read});
1375  }
1376  return true; // Has manual mapping
1377 }
1378 
1379 LogicalResult NVVM::FenceProxyOp::verify() {
1380  if (getKind() == NVVM::ProxyKind::TENSORMAP)
1381  return emitOpError() << "tensormap proxy is not a supported proxy kind";
1382  if (getKind() == NVVM::ProxyKind::GENERIC)
1383  return emitOpError() << "generic proxy not a supported proxy kind";
1384  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1385  return emitOpError() << "async_shared fence requires space attribute";
1386  }
1387  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1388  return emitOpError() << "only async_shared fence can have space attribute";
1389  }
1390  return success();
1391 }
1392 
1393 LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1394  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1395  return emitOpError("uni-directional proxies only support generic for "
1396  "from_proxy attribute");
1397 
1398  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1399  return emitOpError("uni-directional proxies only support tensormap "
1400  "for to_proxy attribute");
1401 
1402  return success();
1403 }
1404 
1405 LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1406  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1407  return emitOpError("uni-directional proxies only support generic for "
1408  "from_proxy attribute");
1409 
1410  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1411  return emitOpError("uni-directional proxies only support tensormap "
1412  "for to_proxy attribute");
1413 
1414  return success();
1415 }
1416 
1417 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1418  if (getRegCount() % 8)
1419  return emitOpError("new register size must be multiple of 8");
1420  if (getRegCount() < 24 || getRegCount() > 256)
1421  return emitOpError("new register size must be in between 24 to 256");
1422  return success();
1423 }
1424 
1425 LogicalResult NVVM::BarrierOp::verify() {
1426  if (getNumberOfThreads() && !getBarrierId())
1427  return emitOpError(
1428  "barrier id is missing, it should be set between 0 to 15");
1429  return success();
1430 }
1431 
1432 LogicalResult NVVM::Tcgen05CpOp::verify() {
1433  auto mc = getMulticast();
1434 
1435  using SH = Tcgen05CpShape;
1436  using MC = Tcgen05CpMulticast;
1437  switch (getShape()) {
1438  case SH::SHAPE_128x256b:
1439  case SH::SHAPE_128x128b:
1440  case SH::SHAPE_4x256b:
1441  if (mc != MC::NONE)
1442  return emitError("Invalid multicast type for tcgen05.cp Op");
1443  break;
1444  case SH::SHAPE_64x128b:
1445  if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1446  return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1447  "warpx2_02_13 for tcgen05.cp Op");
1448  break;
1449  case SH::SHAPE_32x128b:
1450  if (mc != MC::WARPX4)
1451  return emitError(
1452  "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1453  break;
1454  }
1455  return success();
1456 }
1457 
1458 LogicalResult NVVM::MatchSyncOp::verify() {
1459  if (getKind() == NVVM::MatchSyncKind::all) {
1460  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1461  if (!type || type.getBody().size() != 2 ||
1462  !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1463  return emitOpError("match.sync 'all' returns a two element struct with "
1464  "first element as i32 and second element as i1");
1465  }
1466  } else {
1467  if (!getType().isInteger(32)) {
1468  return emitOpError("match.sync 'any' returns an i32");
1469  }
1470  }
1471  return success();
1472 }
1473 
1474 LogicalResult NVVM::VoteSyncOp::verify() {
1475  if (getKind() == NVVM::VoteSyncKind::ballot) {
1476  if (!getType().isInteger(32)) {
1477  return emitOpError("vote.sync 'ballot' returns an i32");
1478  }
1479  } else {
1480  if (!getType().isInteger(1)) {
1481  return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1482  }
1483  }
1484  return success();
1485 }
1486 
1487 LogicalResult NVVM::PrefetchOp::verify() {
1488  using MemSpace = NVVM::NVVMMemorySpace;
1489  using CacheLevel = NVVM::PrefetchCacheLevel;
1490 
1491  unsigned addressSpace =
1492  llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1493  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1494  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
1495 
1496  if (getTensormap() && cacheLevel)
1497  return emitOpError("cannot specify both tensormap and cache level");
1498 
1499  if (getTensormap()) {
1500  if (addressSpace != MemSpace::Generic &&
1501  addressSpace != MemSpace::Constant) {
1502  return emitOpError(
1503  "prefetch tensormap requires a generic or constant pointer");
1504  }
1505 
1506  if (evictPriority) {
1507  return emitOpError(
1508  "prefetch tensormap does not support eviction priority");
1509  }
1510 
1511  if (getInParamSpace() && addressSpace != MemSpace::Generic) {
1512  return emitOpError(
1513  "in_param_space can only be specified for a generic pointer");
1514  }
1515 
1516  } else if (cacheLevel) {
1517  if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
1518  addressSpace != MemSpace::Local) {
1519  return emitOpError("prefetch to cache level requires a generic, global, "
1520  "or local pointer");
1521  }
1522 
1523  if (getUniform()) {
1524  if (*cacheLevel != CacheLevel::L1) {
1525  return emitOpError(
1526  "unsupported cache level, the only supported uniform "
1527  "cache level is L1");
1528  }
1529 
1530  if (addressSpace != MemSpace::Generic) {
1531  return emitOpError(
1532  "prefetch to uniform cache requires a generic pointer");
1533  }
1534  }
1535 
1536  if (evictPriority) {
1537  if (*cacheLevel != CacheLevel::L2)
1538  return emitOpError(
1539  "cache eviction priority supported only for cache level L2");
1540 
1541  if (addressSpace != MemSpace::Global)
1542  return emitOpError("cache eviction priority requires a global pointer");
1543 
1544  if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1545  *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1546  return emitOpError(
1547  "unsupported cache eviction priority, only evict_last and "
1548  "evict_normal are supported");
1549  }
1550 
1551  if (getPredicate())
1552  return emitOpError("predicate supported only on prefetch tensormap");
1553 
1554  } else {
1555  return emitOpError(
1556  "requires specification of either cache level or tensormap");
1557  }
1558 
1559  return success();
1560 }
1561 
1562 LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
1563  switch (getQueryType()) {
1564  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
1565  if (!getType().isInteger(1))
1566  return emitOpError("is_canceled query type returns an i1");
1567  break;
1568  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
1569  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
1570  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
1571  if (!getType().isInteger(32)) {
1572  return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
1573  "get_first_cta_id_z query types return an i32");
1574  }
1575  break;
1576  }
1577  return success();
1578 }
1579 
1580 /// Packs the given `field` into the `result`.
1581 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
1582 static llvm::Value *
1583 packValInto64Bits(llvm::IRBuilderBase &builder,
1584  llvm::Value *result, // the `result` (unset bits are zero)
1585  llvm::Value *field, // `field` to pack into `result`
1586  unsigned sizeInBits, // Size of `field` in bits
1587  unsigned start) { // Starting bit within `result`
1588  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1589 
1590  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1591  if (mask != 0xffffffffu)
1592  field = builder.CreateAnd(field, builder.getInt32(mask));
1593 
1594  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1595  field = builder.CreateShl(field, start);
1596 
1597  return builder.CreateOr(result, field);
1598 }
1599 
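1599 
// Bit layout of the 64-bit shared-memory descriptor assembled below, as
// encoded by the packValInto64Bits calls: bits 0-13 start address,
// bits 16-29 leading-dimension offset, bits 32-45 stride-dimension offset,
// bits 46-48 a fixed 0b001 pattern, bits 49-51 base offset,
// bit 52 leading-dimension mode, and bits 61-63 swizzle mode.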
1600 void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1601  LLVM::ModuleTranslation &mt,
1602  llvm::IRBuilderBase &builder) {
1603  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1604  llvm::Value *smemDesc = builder.getInt64(0);
1605 
1606  smemDesc = packValInto64Bits(builder, smemDesc,
1607  mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1608  smemDesc = packValInto64Bits(
1609  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1610  smemDesc = packValInto64Bits(
1611  builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1612 
1613  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1614  smemDesc = packValInto64Bits(builder, smemDesc,
1615  mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1616  smemDesc = packValInto64Bits(
1617  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1618  smemDesc = packValInto64Bits(builder, smemDesc,
1619  mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1620 
1621  mt.mapValue(thisOp.getRes()) = smemDesc;
1622 }
1623 
1624 //===----------------------------------------------------------------------===//
1625 // getPtx methods
1626 //===----------------------------------------------------------------------===//
1627 
1628 std::string NVVM::MBarrierInitOp::getPtx() {
1629  unsigned addressSpace =
1630  llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1631  return (addressSpace == NVVMMemorySpace::Shared)
1632  ? std::string("mbarrier.init.shared.b64 [%0], %1;")
1633  : std::string("mbarrier.init.b64 [%0], %1;");
1634 }
1635 
1636 //===----------------------------------------------------------------------===//
1637 // getIntrinsicID/getIntrinsicIDAndArgs methods
1638 //===----------------------------------------------------------------------===//
1639 
1640 mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
1641  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1642  auto thisOp = cast<NVVM::MBarrierInitOp>(op);
1643  unsigned addressSpace =
1644  llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
1645  .getAddressSpace();
1646  llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
1647  ? llvm::Intrinsic::nvvm_mbarrier_init_shared
1648  : llvm::Intrinsic::nvvm_mbarrier_init;
1649 
1650  // Fill the Intrinsic Args
1651  llvm::SmallVector<llvm::Value *> args;
1652  args.push_back(mt.lookupValue(thisOp.getAddr()));
1653  args.push_back(mt.lookupValue(thisOp.getCount()));
1654 
1655  return {id, std::move(args)};
1656 }
1657 
1658 mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
1659  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1660  auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
1661  unsigned addressSpace =
1662  llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
1663  .getAddressSpace();
1664  llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
1665  ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
1666  : llvm::Intrinsic::nvvm_mbarrier_inval;
1667 
1668  return {id, {mt.lookupValue(thisOp.getAddr())}};
1669 }
1670 
1671 #define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1672  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1673 
1674 #define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1675  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
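// For example, GET_CP_ASYNC_ID(ca, 4, /*has_cpsize=*/true) selects
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s, while
// GET_CP_ASYNC_ID(cg, 16, /*has_cpsize=*/false) selects
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16.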
1676 
1677 llvm::Intrinsic::ID
1678 CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1679  llvm::SmallVector<llvm::Value *> &args) {
1680  llvm::Intrinsic::ID id;
1681 
1682  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1683  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1684  switch (cpAsyncOp.getSize()) {
1685  case 4:
1686  id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1687  break;
1688  case 8:
1689  id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1690  break;
1691  case 16:
1692  id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1693  ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1694  : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1695  break;
1696  default:
1697  llvm_unreachable("Invalid copy size in CpAsyncOp.");
1698  }
1699 
1700  // Fill the Intrinsic Args
1701  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1702  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1703  if (hasCpSize)
1704  args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1705 
1706  return id;
1707 }
1708 
1709 mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1710  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1711  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1712  llvm::SmallVector<llvm::Value *> args;
1713  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1714 
1715  // Fill the Intrinsic Args
1716  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1717  args.push_back(mt.lookupValue(thisOp.getSize()));
1718 
1719  mlir::Value cacheHint = thisOp.getL2CacheHint();
1720  const bool hasCacheHint = static_cast<bool>(cacheHint);
1721  llvm::Value *i64Unused =
1722  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1723  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1724  args.push_back(builder.getInt1(hasCacheHint));
1725 
1726  return {id, std::move(args)};
1727 }
1728 
1729 mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1730  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1731  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
1732  llvm::SmallVector<llvm::Value *> args;
1733 
1734  // Fill the Intrinsic Args: dst, mbar, src, size.
1735  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1736  args.push_back(mt.lookupValue(thisOp.getMbar()));
1737  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1738  args.push_back(mt.lookupValue(thisOp.getSize()));
1739 
1740  // Multicast mask, if available.
1741  mlir::Value multicastMask = thisOp.getMulticastMask();
1742  const bool hasMulticastMask = static_cast<bool>(multicastMask);
1743  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
1744  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
1745 
1746  // Cache hint, if available.
1747  mlir::Value cacheHint = thisOp.getL2CacheHint();
1748  const bool hasCacheHint = static_cast<bool>(cacheHint);
1749  llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
1750  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1751 
1752  // Flag arguments for multicast and cachehint.
1753  args.push_back(builder.getInt1(hasMulticastMask));
1754  args.push_back(builder.getInt1(hasCacheHint));
1755 
1756  llvm::Intrinsic::ID id =
1757  llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
1758 
1759  return {id, std::move(args)};
1760 }
1761 
1762 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1763  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1764  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
1765  llvm::SmallVector<llvm::Value *> args;
1766  llvm::Intrinsic::ID id =
1767  llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1768 
1769  // Fill the Intrinsic Args
1770  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1771  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1772  args.push_back(mt.lookupValue(thisOp.getSize()));
1773 
1774  mlir::Value cacheHint = thisOp.getL2CacheHint();
1775  const bool hasCacheHint = static_cast<bool>(cacheHint);
1776  llvm::Value *i64Unused =
1777  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1778  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1779  args.push_back(builder.getInt1(hasCacheHint));
1780 
1781  // Choose the bytemask variant
1782  if (mlir::Value byteMask = thisOp.getByteMask()) {
1783  args.push_back(mt.lookupValue(byteMask));
1784  id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1785  }
1786 
1787  return {id, std::move(args)};
1788 }
1789 
1790 bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
1791  RewriterBase &rewriter,
1792  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1793  &asmValues) {
1794  // Add all the operands but not the attrs to the asmValues list.
1795  // The attrs here are used to generate the right variants for
1796  // intrinsics-lowering. So, we ignore them while generating inline-PTX.
1797  for (auto val : getOperands())
1798  asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
1799 
1800  return false;
1801 }
1802 
1803 mlir::NVVM::IDArgPair
1804 CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1805  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1806  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
1807  const bool isCTAOnly = thisOp.getIsCTAOnly();
1808  llvm::SmallVector<llvm::Value *> args;
1809 
1810  // Fill the Intrinsic Args
1811  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1812  args.push_back(mt.lookupValue(thisOp.getMbar()));
1813  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1814 
1815  // Coordinates and im2col-offsets
1816  for (mlir::Value v : thisOp.getCoordinates())
1817  args.push_back(mt.lookupValue(v));
1818  for (mlir::Value v : thisOp.getIm2colOffsets())
1819  args.push_back(mt.lookupValue(v));
1820 
1821  // MulticastMask, if available
1822  mlir::Value mcMask = thisOp.getMulticastMask();
1823  const bool hasMC = static_cast<bool>(mcMask);
1824  llvm::Value *i16Zero =
1825  llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
1826 
1827  // CacheHint, if available
1828  mlir::Value cacheHint = thisOp.getL2CacheHint();
1829  const bool hasCacheHint = static_cast<bool>(cacheHint);
1830  llvm::Value *i64Zero =
1831  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1832 
1833  // Flag argument CTAGroup
1834  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
1835  // Hence, the +1 to getGroup().
1836  const int32_t val =
1837  thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
1838  llvm::Value *cg =
1839  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
1840 
1841  if (!isCTAOnly) {
1842  // For shared::cluster, all the arguments that we build are applicable.
1843  args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
1844  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
1845  args.push_back(builder.getInt1(hasMC));
1846  args.push_back(builder.getInt1(hasCacheHint));
1847  args.push_back(cg);
1848  } else {
1849  // For shared::cta, only cache-hint is applicable.
1850  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
1851  args.push_back(builder.getInt1(hasCacheHint));
1852  }
1853 
1854  constexpr size_t numDims = 5; // 1D to 5D
1855  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
1856  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
1857  using TableTy = std::array<rowTy, numModes>;
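// IDTable[mode][dim]: one row per TMA load mode (tile, im2col, im2col_w,
// im2col_w_128, tile_gather4) and one column per coordinate count; index 0
// and any unsupported mode/dimension combination hold notIntrinsic.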
1858  static constexpr TableTy IDTable{
1859  {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
1860  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
1861  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
1862  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
1863  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
1864  {notIntrinsic, notIntrinsic, notIntrinsic,
1865  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
1866  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
1867  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
1868  {notIntrinsic, notIntrinsic, notIntrinsic,
1869  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
1870  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
1871  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
1872  {notIntrinsic, notIntrinsic, notIntrinsic,
1873  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
1874  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
1875  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
1876  {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
1877  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
1878 
1879  static constexpr TableTy IDTableCTA{
1880  {{notIntrinsic,
1881  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
1882  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
1883  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
1884  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
1885  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
1886  {notIntrinsic, notIntrinsic, notIntrinsic,
1887  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
1888  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
1889  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
1890  {notIntrinsic, notIntrinsic, notIntrinsic,
1891  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
1892  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
1893  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
1894  {notIntrinsic, notIntrinsic, notIntrinsic,
1895  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
1896  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
1897  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
1898  {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
1899  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
1900 
1901  static_assert(
1902  (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
1903  (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
1904  "TMALoadModes must match number of rows in IDTable and IDTableCTA");
1905  size_t mode = static_cast<size_t>(thisOp.getMode());
1906  size_t dim = thisOp.getCoordinates().size();
1907  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
1908  assert(id != notIntrinsic &&
1909  "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
1910 
1911  return {id, std::move(args)};
1912 }
1913 
1914 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
1915  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1916  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
1917  llvm::SmallVector<llvm::Value *> args;
1918 
1919  // Fill the Intrinsic Args
1920  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1921 
1922  for (auto v : thisOp.getCoordinates())
1923  args.push_back(mt.lookupValue(v));
1924  for (auto v : thisOp.getIm2colOffsets())
1925  args.push_back(mt.lookupValue(v));
1926 
1927  mlir::Value cacheHint = thisOp.getL2CacheHint();
1928  const bool hasCacheHint = static_cast<bool>(cacheHint);
1929  llvm::Value *i64Unused =
1930  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1931  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1932  args.push_back(builder.getInt1(hasCacheHint));
1933 
1934  const unsigned NI = llvm::Intrinsic::not_intrinsic;
1935  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1936  {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
1937  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
1938  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
1939  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
1940  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
1941  {NI, NI, NI,
1942  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
1943  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
1944  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
1945  {NI, NI, NI,
1946  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
1947  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
1948  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
1949  {NI, NI, NI,
1950  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
1951  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
1952  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
1953  {NI, NI, NI, NI, NI,
1954  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
1955 
1956  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
1957  "TMALoadModes must match number of rows in IDTable");
1958  size_t mode = static_cast<size_t>(thisOp.getMode());
1959  size_t dim = thisOp.getCoordinates().size();
1960  llvm::Intrinsic::ID id = IDTable[mode][dim];
1961  if (id == llvm::Intrinsic::not_intrinsic)
1962  llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
1963 
1964  return {id, std::move(args)};
1965 }
1966 
1967 mlir::NVVM::IDArgPair
1968 CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1969  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1970  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
1971  llvm::SmallVector<llvm::Value *> args;
1972 
1973  // Fill the Intrinsic Args
1974  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1975  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1976 
1977  for (auto v : thisOp.getCoordinates())
1978  args.push_back(mt.lookupValue(v));
1979 
1980  mlir::Value cacheHint = thisOp.getL2CacheHint();
1981  const bool hasCacheHint = static_cast<bool>(cacheHint);
1982  llvm::Value *i64Unused =
1983  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1984  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1985  args.push_back(builder.getInt1(hasCacheHint));
1986 
1987  const unsigned NI = llvm::Intrinsic::not_intrinsic;
1988  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1989  {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
1990  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
1991  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
1992  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
1993  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
1994  {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
1995  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
1996  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
1997  {NI, NI, NI, NI, NI,
1998  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
1999 
2000  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
2001  "TMAStoreModes must match number of rows in IDTable");
2002  size_t mode = static_cast<size_t>(thisOp.getMode());
2003  size_t dim = thisOp.getCoordinates().size();
2004  llvm::Intrinsic::ID id = IDTable[mode][dim];
2005  if (id == llvm::Intrinsic::not_intrinsic)
2006  llvm_unreachable(
2007  "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
2008 
2009  return {id, std::move(args)};
2010 }
2011 
2012 NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
2013  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2014  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
2015  llvm::LLVMContext &ctx = mt.getLLVMContext();
2016 
2017  llvm::SmallVector<llvm::Value *> args;
2018 
2019  // Arguments to the intrinsic:
2020  // shared_mem_ptr, tmaDesc, tensorDims
2021  // cache_hint(if applicable) and flag(boolean)
2022  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2023  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2024 
2025  for (Value v : thisOp.getCoordinates())
2026  args.push_back(mt.lookupValue(v));
2027 
2028  mlir::Value cacheHint = thisOp.getL2CacheHint();
2029  const bool hasCacheHint = static_cast<bool>(cacheHint);
2030  llvm::Value *i64ZeroValue =
2031  llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2032  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
2033  args.push_back(builder.getInt1(hasCacheHint));
2034 
2035  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
2036 
2037  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
2038  constexpr unsigned numLayouts = 2; // TILE, IM2COL
2039  constexpr unsigned maxDim = 5; // 1D to 5D
2040  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
2041  using layoutTable = std::array<row, numLayouts>;
2042  using fullTable = std::array<layoutTable, numRedKinds>;
2043  static constexpr fullTable IDTable{
2044  {// RedTy::ADD
2045  {{{{notIntrinsic,
2046  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
2047  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
2048  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
2049  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
2050  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
2051  {{notIntrinsic, notIntrinsic, notIntrinsic,
2052  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
2053  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
2054  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
2055  // RedTy::MIN
2056  {{{{notIntrinsic,
2057  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
2058  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
2059  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
2060  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
2061  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
2062  {{notIntrinsic, notIntrinsic, notIntrinsic,
2063  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
2064  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
2065  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
2066  // RedTy::MAX
2067  {{{{notIntrinsic,
2068  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
2069  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
2070  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
2071  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
2072  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
2073  {{notIntrinsic, notIntrinsic, notIntrinsic,
2074  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
2075  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
2076  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
2077  // RedTy::INC
2078  {{{{notIntrinsic,
2079  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
2080  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
2081  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
2082  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
2083  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
2084  {{notIntrinsic, notIntrinsic, notIntrinsic,
2085  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
2086  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
2087  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
2088  // RedTy::DEC
2089  {{{{notIntrinsic,
2090  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
2091  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
2092  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
2093  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
2094  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
2095  {{notIntrinsic, notIntrinsic, notIntrinsic,
2096  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
2097  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
2098  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
2099  // RedTy::AND
2100  {{{{notIntrinsic,
2101  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
2102  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
2103  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
2104  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
2105  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
2106  {{notIntrinsic, notIntrinsic, notIntrinsic,
2107  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
2108  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
2109  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
2110  // RedTy::OR
2111  {{{{notIntrinsic,
2112  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
2113  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
2114  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
2115  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
2116  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
2117  {{notIntrinsic, notIntrinsic, notIntrinsic,
2118  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
2119  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
2120  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
2121  // RedTy::XOR
2122  {{{{notIntrinsic,
2123  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
2124  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
2125  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
2126  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
2127  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
2128  {{notIntrinsic, notIntrinsic, notIntrinsic,
2129  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
2130  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
2131  llvm::Intrinsic::
2132  nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
2133 
2134  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
2135  "TMAReduxKinds must match number of rows in IDTable");
2136 
2137  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
2138  size_t mode = static_cast<size_t>(thisOp.getMode());
2139  size_t dim = thisOp.getCoordinates().size();
2140 
2141  assert(redKind < IDTable.size() &&
2142  "Invalid redKind for CpAsyncBulkTensorReduceOp");
2143  assert(mode < IDTable[redKind].size() &&
2144  "Invalid mode for CpAsyncBulkTensorReduceOp");
2145  assert(dim < IDTable[redKind][mode].size() &&
2146  "Invalid dim for CpAsyncBulkTensorReduceOp");
2147 
2148  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
2149 
2150  assert(intrinsicID != notIntrinsic &&
2151  "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
2152 
2153  return {intrinsicID, std::move(args)};
2154 }
2155 
2156 #define _none
2157 
2158 #define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2159  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
2160  : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
2161 
2162 #define GET_CVT_F2TF32_ID(rnd, relu, sf) \
2163  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2164  : CVT_F2TF32_ID_IMPL(rnd, relu, )
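// For example, with hasSatFinite and hasRelu both set,
// GET_CVT_F2TF32_ID(rn, _relu, _satfinite) evaluates to
// llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite; with both unset it
// evaluates to llvm::Intrinsic::nvvm_f2tf32_rn.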
2165 
2166 llvm::Intrinsic::ID
2167 ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2168  NVVM::SaturationMode sat, bool hasRelu) {
2169  using RndMode = NVVM::FPRoundingMode;
2170  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2171  switch (rnd) {
2172  case RndMode::RN:
2173  return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
2174  case RndMode::RZ:
2175  return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
2176  case RndMode::RNA:
2177  return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
2178  default:
2179  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
2180  }
2181 }
2182 
2183 NVVM::IDArgPair
2184 ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
2185  LLVM::ModuleTranslation &mt,
2186  llvm::IRBuilderBase &builder) {
2187  llvm::SmallVector<llvm::Value *> args;
2188  args.push_back(mt.lookupValue(op.getA()));
2189  args.push_back(mt.lookupValue(op.getB()));
2190 
2191  bool hasRelu = op.getRelu();
2192 
2193  llvm::Intrinsic::ID intId =
2194  hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
2195  : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
2196 
2197  return {intId, std::move(args)};
2198 }
2199 
2200 #define GET_F32x2_TO_F6x2_ID(type, has_relu) \
2201  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
2202  : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
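// For example, GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu) selects
// llvm::Intrinsic::nvvm_ff_to_e2m3x2_rn_relu_satfinite when hasRelu is true
// and llvm::Intrinsic::nvvm_ff_to_e2m3x2_rn_satfinite otherwise.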
2203 
2204 llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
2205  bool hasRelu) {
2206  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2207  .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2208  return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
2209  })
2210  .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2211  return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
2212  })
2213  .Default([](mlir::Type) {
2214  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
2215  return llvm::Intrinsic::not_intrinsic;
2216  });
2217 }
2218 
2219 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
2220  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
2221  : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
2222 
2223 #define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
2224  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
2225  : llvm::Intrinsic::nvvm_ff_to_##type##_rn
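// For example, GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu) picks between
// nvvm_ff_to_e4m3x2_rn_relu and nvvm_ff_to_e4m3x2_rn, while
// GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite) picks between
// nvvm_ff_to_ue8m0x2_rz_satfinite and nvvm_ff_to_ue8m0x2_rz.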
2226 
2227 llvm::Intrinsic::ID
2228 ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
2229  NVVM::SaturationMode sat, bool hasRelu) {
2230  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2231  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
2232  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
2233 
2234  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2235  .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2236  return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
2237  })
2238  .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2239  return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
2240  })
2241  .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2242  if (hasRoundingModeRZ)
2243  return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
2244  else if (hasRoundingModeRP)
2245  return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
2246 
2247  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2248  })
2249  .Default([](mlir::Type) {
2250  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2251  return llvm::Intrinsic::not_intrinsic;
2252  });
2253 }
2254 
2255 #define GET_F16x2_TO_F8X2_ID(type, has_relu) \
2256  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
2257  : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
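// For example, GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu) picks between
// nvvm_f16x2_to_e5m2x2_rn_relu and nvvm_f16x2_to_e5m2x2_rn.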
2258 
2259 llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
2260  bool hasRelu) {
2261  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2262  .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2263  return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
2264  })
2265  .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2266  return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
2267  })
2268  .Default([](mlir::Type) {
2269  llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
2270  return llvm::Intrinsic::not_intrinsic;
2271  });
2272 }
2273 
2274 #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
2275  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
2276  : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
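// For example, GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite) picks between
// nvvm_bf16x2_to_ue8m0x2_rz_satfinite and nvvm_bf16x2_to_ue8m0x2_rz.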
2277 
2278 llvm::Intrinsic::ID
2279 ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2280  NVVM::SaturationMode sat) {
2281  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2282  switch (rnd) {
2283  case NVVM::FPRoundingMode::RZ:
2284  return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
2285  case NVVM::FPRoundingMode::RP:
2286  return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
2287  default:
2288  llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
2289  }
2290 }
2291 
2292 NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
2293  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2294  auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
2295 
2296  bool hasRelu = curOp.getRelu();
2297 
2298  llvm::Intrinsic::ID intId =
2299  llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
2300  .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
2301  return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
2302  : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
2303  })
2304  .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
2305  return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
2306  : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
2307  })
2308  .Default([](mlir::Type type) {
2309  llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
2310  return llvm::Intrinsic::not_intrinsic;
2311  });
2312 
2313  llvm::Value *packedI16 =
2314  builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2315  llvm::Type::getInt16Ty(builder.getContext()));
2316 
2317  return {intId, {packedI16}};
2318 }
2319 
2320 NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
2321  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2322  auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
2323 
2324  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
2325  llvm::Value *packedI16 =
2326  builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2327  llvm::Type::getInt16Ty(builder.getContext()));
2328 
2329  return {intId, {packedI16}};
2330 }
2331 
2332 NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
2333  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2334  auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
2335 
2336  bool hasRelu = curOp.getRelu();
2337 
2338  llvm::Intrinsic::ID intId =
2339  llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
2340  .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
2341  return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
2342  : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
2343  })
2344  .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
2345  return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
2346  : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
2347  })
2348  .Default([](mlir::Type type) {
2349  llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
2350  return llvm::Intrinsic::not_intrinsic;
2351  });
2352 
2353  llvm::Value *packedI16 =
2354  builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2355  llvm::Type::getInt16Ty(builder.getContext()));
2356 
2357  return {intId, {packedI16}};
2358 }
2359 
2360 NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
2361  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2362  auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
2363 
2364  bool hasRelu = curOp.getRelu();
2365 
2366  llvm::Intrinsic::ID intId =
2367  llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
2368  .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
2369  return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
2370  : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
2371  })
2372  .Default([](mlir::Type type) {
2373  llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
2374  return llvm::Intrinsic::not_intrinsic;
2375  });
2376 
2377  llvm::Value *extendedI16 =
2378  builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
2379  llvm::Type::getInt16Ty(builder.getContext()));
2380 
2381  return {intId, {extendedI16}};
2382 }
2383 
2384 llvm::Intrinsic::ID
2385 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
2386  LLVM::ModuleTranslation &mt,
2387  llvm::SmallVector<llvm::Value *> &args) {
2388  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
2389  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2390  .getAddressSpace();
2391  bool isShared = as == NVVMMemorySpace::Shared;
2392  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2393 
2394  llvm::Intrinsic::ID id;
2395  if (isShared) {
2396  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
2397  : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
2398  } else {
2399  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
2400  : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
2401  }
2402 
2403  // Fill the Intrinsic Args
2404  args.push_back(mt.lookupValue(curOp.getAddr()));
2405  args.push_back(mt.lookupValue(curOp.getNCols()));
2406 
2407  return id;
2408 }
2409 
2410 llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
2411  Operation &op, LLVM::ModuleTranslation &mt,
2412  llvm::SmallVector<llvm::Value *> &args) {
2413  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
2414  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
2415  ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
2416  : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
2417 
2418  // Fill the Intrinsic Args
2419  args.push_back(mt.lookupValue(curOp.getTaddr()));
2420  args.push_back(mt.lookupValue(curOp.getNCols()));
2421 
2422  return id;
2423 }
2424 
2425 #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
2426  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
2427  : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
2428 
2429 #define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
2430  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
2431  : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
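// For example, GET_TCGEN05_COMMIT_ID(cg2, /*is_shared=*/true, /*has_mc=*/true)
// evaluates to llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2, while
// GET_TCGEN05_COMMIT_ID(cg1, false, false) evaluates to
// llvm::Intrinsic::nvvm_tcgen05_commit_cg1.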
2432 
2433 llvm::Intrinsic::ID
2434 Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
2435  LLVM::ModuleTranslation &mt,
2436  llvm::SmallVector<llvm::Value *> &args) {
2437  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
2438  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2439  .getAddressSpace();
2440  bool isShared = as == NVVMMemorySpace::Shared;
2441  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
2442  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2443 
2444  llvm::Intrinsic::ID id =
2445  is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
2446  : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
2447 
2448  // Fill the Intrinsic Args
2449  args.push_back(mt.lookupValue(curOp.getAddr()));
2450  if (hasMulticast)
2451  args.push_back(mt.lookupValue(curOp.getMulticastMask()));
2452 
2453  return id;
2454 }
2455 
2456 #define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
2457  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
2458 
2459 #define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
2460  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
2461  : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
2462 
2463 #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
2464  [&]() -> auto { \
2465  if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
2466  return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
2467  if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
2468  return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
2469  return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
2470  }()
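// For example, GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA) yields
// nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2 for the B6x16_P32 source format in
// 2-CTA mode, and nvvm_tcgen05_cp_128x256b_cg1 when no source-format
// decompression is requested in 1-CTA mode.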
2471 
2472 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
2473  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
2474  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
2475  auto srcFmt = curOp.getSrcFormat();
2476  auto mc = curOp.getMulticast();
2477 
2478  switch (curOp.getShape()) {
2479  case Tcgen05CpShape::SHAPE_128x256b:
2480  return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
2481  case Tcgen05CpShape::SHAPE_128x128b:
2482  return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
2483  case Tcgen05CpShape::SHAPE_4x256b:
2484  return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
2485  case Tcgen05CpShape::SHAPE_32x128b:
2486  return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
2487  case Tcgen05CpShape::SHAPE_64x128b:
2488  return (mc == Tcgen05CpMulticast::WARPX2_01_23)
2489  ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
2490  : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
2491  }
2492  llvm_unreachable("Invalid shape in tcgen05 cp Op");
2493 }
2494 
2495 // Returns true if the given vector length is valid for the given shape;
2496 // models the table mentioned in the tcgen05.{ld, st} Op description.
2497 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
2498  unsigned vecLen) {
2499  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
2500  return vecLen >= 2;
2501  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
2502  return vecLen >= 4;
2503  return true;
2504 }
2505 
2506 LogicalResult Tcgen05LdOp::verify() {
2507  LogicalResult result = success();
2508  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2509  result = emitError("shape 16x32bx2 requires offset argument");
2510 
2511  auto resTy = getRes().getType();
2512  unsigned resLen = isa<VectorType>(resTy)
2513  ? llvm::cast<VectorType>(resTy).getNumElements()
2514  : 1;
2515  if (!isValidVectorLength(getShape(), resLen))
2516  result = emitError(llvm::formatv("invalid result type length {0} for shape "
2517  "{1} in tcgen05.ld Op",
2518  resLen, stringifyEnum(getShape())));
2519 
2520  return result;
2521 }
2522 
2523 LogicalResult Tcgen05StOp::verify() {
2524  LogicalResult result = success();
2525  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2526  result = emitError("shape 16x32bx2 requires offset argument");
2527 
2528  auto valTy = getVal().getType();
2529  unsigned valLen = isa<VectorType>(valTy)
2530  ? llvm::cast<VectorType>(valTy).getNumElements()
2531  : 1;
2532  if (!isValidVectorLength(getShape(), valLen))
2533  result = emitError(llvm::formatv("invalid input length {0} for shape "
2534  "{1} in tcgen05.st Op",
2535  valLen, stringifyEnum(getShape())));
2536 
2537  return result;
2538 }
2539 
2540 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
2541 /// have ConstantRangeAttr.
2542 static void nvvmInferResultRanges(Operation *op, Value result,
2543  ArrayRef<::mlir::ConstantIntRanges> argRanges,
2544  SetIntRangeFn setResultRanges) {
2545  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
2546  setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
2547  rangeAttr.getLower(), rangeAttr.getUpper()});
2548  }
2549 }
2550 
2551 /// Verify the range attribute satisfies LLVM ConstantRange constructor
2552 /// requirements for NVVM SpecialRangeableRegisterOp.
2553 static LogicalResult
2554 verifyConstantRangeAttr(Operation *op,
2555  std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
2556  if (!rangeAttr)
2557  return success();
2558 
2559  const llvm::APInt &lower = rangeAttr->getLower();
2560  const llvm::APInt &upper = rangeAttr->getUpper();
2561 
2562  // Check LLVM ConstantRange constructor condition
2563  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
2564  unsigned bitWidth = lower.getBitWidth();
2565  llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
2566  llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
2567  return op->emitOpError(
2568  "invalid range attribute: Lower == Upper, but they aren't min (")
2569  << llvm::toString(minVal, 10, false) << ") or max ("
2570  << llvm::toString(maxVal, 10, false)
2571  << ") value! This is an invalid constant range.";
2572  }
2573 
2574  return success();
2575 }
2576 
2577 static llvm::Value *getAsPackedI32(llvm::Value *arg,
2578  llvm::IRBuilderBase &builder) {
2579  return builder.CreateBitCast(arg,
2580  llvm::Type::getInt32Ty(builder.getContext()));
2581 }
2582 
2583 NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
2584  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2585  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
2586 
2587  llvm::SmallVector<llvm::Value *> args;
2588  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2589  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2590  args.push_back(mt.lookupValue(curOp.getC()));
2591 
2592  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2593  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2594  unsigned type = (isASigned << 1) | isBSigned;
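  // The 2-bit index selects the signedness variant below:
  // 0b00 -> u_u, 0b01 -> u_s, 0b10 -> s_u, 0b11 -> s_s.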
2595  const llvm::Intrinsic::ID ids[] = {
2596  llvm::Intrinsic::nvvm_idp4a_u_u,
2597  llvm::Intrinsic::nvvm_idp4a_u_s,
2598  llvm::Intrinsic::nvvm_idp4a_s_u,
2599  llvm::Intrinsic::nvvm_idp4a_s_s,
2600  };
2601  return {ids[type], args};
2602 }
2603 
2604 NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
2605  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2606  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
2607 
2608  llvm::SmallVector<llvm::Value *> args;
2609  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2610  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2611  args.push_back(builder.getInt1(curOp.getBHi()));
2612  args.push_back(mt.lookupValue(curOp.getC()));
2613 
2614  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2615  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2616  unsigned type = (isASigned << 1) | isBSigned;
2617  const llvm::Intrinsic::ID ids[] = {
2618  llvm::Intrinsic::nvvm_idp2a_u_u,
2619  llvm::Intrinsic::nvvm_idp2a_u_s,
2620  llvm::Intrinsic::nvvm_idp2a_s_u,
2621  llvm::Intrinsic::nvvm_idp2a_s_s,
2622  };
2623  return {ids[type], args};
2624 }
2625 
2626 static llvm::Value *getParamCastedAddr(llvm::Value *addr,
2627  llvm::IRBuilderBase &builder) {
2628  return builder.CreateAddrSpaceCast(
2629  addr,
2630  llvm::PointerType::get(builder.getContext(),
2631  llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
2632 }
2633 
2634 NVVM::IDArgPair
2635 PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
2636  LLVM::ModuleTranslation &mt,
2637  llvm::IRBuilderBase &builder) {
2638  using MemSpace = NVVM::NVVMMemorySpace;
2639  using CacheLevel = NVVM::PrefetchCacheLevel;
2640 
2641  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
2642  std::optional<NVVM::CacheEvictionPriority> evictPriority =
2643  op.getEvictPriority();
2644  unsigned addressSpace =
2645  llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
2646  .getAddressSpace();
2647 
2648  llvm::SmallVector<llvm::Value *> args;
2649  llvm::Value *addr = mt.lookupValue(op.getAddr());
2650  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
2651  : addr);
2652 
2653  if (op.getTensormap())
2654  return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
2655 
2656  assert(cacheLevel && "expected cache level for non-tensormap prefetch");
2657 
2658  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
2659  return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
2660 
2661  if (evictPriority && *cacheLevel == CacheLevel::L2) {
2662  switch (*evictPriority) {
2663  case NVVM::CacheEvictionPriority::EvictLast:
2664  return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
2665  case NVVM::CacheEvictionPriority::EvictNormal:
2666  return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
2667  default:
2668  llvm_unreachable("Invalid cache eviction priority");
2669  }
2670  }
2671 
2672  switch (static_cast<MemSpace>(addressSpace)) {
2673  case MemSpace::Generic:
2674  return *cacheLevel == CacheLevel::L1
2675  ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
2676  : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
2677  case MemSpace::Global:
2678  return *cacheLevel == CacheLevel::L1
2679  ? NVVM::IDArgPair(
2680  {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
2681  : NVVM::IDArgPair(
2682  {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
2683  case MemSpace::Local:
2684  return *cacheLevel == CacheLevel::L1
2685  ? NVVM::IDArgPair(
2686  {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
2687  : NVVM::IDArgPair(
2688  {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
2689  default:
2690  llvm_unreachable("Invalid pointer address space");
2691  }
2692 }
2693 
2694 bool NVVM::InlinePtxOp::getAsmValues(
2695  RewriterBase &rewriter,
2696  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2697  &asmValues) {
2698  for (auto arg : getReadWriteArgs())
2699  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
2700  for (auto arg : getResults())
2701  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
2702  for (auto arg : getReadOnlyArgs())
2703  asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
2704  if (getPredicate())
2705  asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
2706  return false; // No manual mapping needed
2707 }
2708 
2709 NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
2710  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2711  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
2712  llvm::SmallVector<llvm::Value *> args;
2713  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
2714  args.push_back(mt.lookupValue(curOp.getMbarrier()));
2715 
2716  llvm::Intrinsic::ID intrinsicID =
2717  curOp.getMulticast()
2718  ? llvm::Intrinsic::
2719  nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
2720  : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
2721 
2722  return {intrinsicID, args};
2723 }
2724 
2725 NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
2726  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2727  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
2728  llvm::SmallVector<llvm::Value *> args;
2729  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
2730 
2731  llvm::Intrinsic::ID intrinsicID;
2732 
2733  switch (curOp.getQueryType()) {
2734  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2735  intrinsicID =
2736  llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
2737  break;
2738  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2739  intrinsicID = llvm::Intrinsic::
2740  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
2741  break;
2742  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2743  intrinsicID = llvm::Intrinsic::
2744  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
2745  break;
2746  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2747  intrinsicID = llvm::Intrinsic::
2748  nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
2749  break;
2750  }
2751  return {intrinsicID, args};
2752 }
2753 
2754 //===----------------------------------------------------------------------===//
2755 // NVVMDialect initialization, type parsing, and registration.
2756 //===----------------------------------------------------------------------===//
2757 
2758 // TODO: This should be the llvm.nvvm dialect once this is supported.
2759 void NVVMDialect::initialize() {
2760  addOperations<
2761 #define GET_OP_LIST
2762 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
2763  >();
2764  addAttributes<
2765 #define GET_ATTRDEF_LIST
2766 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
2767  >();
2768 
2769  // Support unknown operations because not all NVVM operations are
2770  // registered.
2771  allowUnknownOperations();
2772  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
2773  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
2774 }
2775 
2776 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
2777  NamedAttribute attr) {
2778  StringAttr attrName = attr.getName();
2779  // Kernel function attribute should be attached to functions.
2780  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
2781  if (!isa<LLVM::LLVMFuncOp>(op)) {
2782  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
2783  << "' attribute attached to unexpected op";
2784  }
2785  }
2786  // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
2787  // dim
2788  if (attrName == NVVMDialect::getMaxntidAttrName() ||
2789  attrName == NVVMDialect::getReqntidAttrName() ||
2790  attrName == NVVMDialect::getClusterDimAttrName()) {
2791  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
2792  if (!values || values.empty() || values.size() > 3) {
2793  return op->emitError()
2794  << "'" << attrName
2795  << "' attribute must be integer array with maximum 3 index";
2796  }
2797  }
2798  // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
2799  // attribute
2800  if (attrName == NVVMDialect::getMinctasmAttrName() ||
2801  attrName == NVVMDialect::getMaxnregAttrName() ||
2802  attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
2803  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
2804  return op->emitError()
2805  << "'" << attrName << "' attribute must be integer constant";
2806  }
2807  }
2808  // blocksareclusters must be used along with reqntid and cluster_dim
2809  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
2810  if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
2811  !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
2812  return op->emitError()
2813  << "'" << attrName << "' attribute must be used along with "
2814  << "'" << NVVMDialect::getReqntidAttrName() << "' and "
2815  << "'" << NVVMDialect::getClusterDimAttrName() << "'";
2816  }
2817  }
2818 
2819  return success();
2820 }
2821 
2822 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
2823  unsigned regionIndex,
2824  unsigned argIndex,
2825  NamedAttribute argAttr) {
2826  auto funcOp = dyn_cast<FunctionOpInterface>(op);
2827  if (!funcOp)
2828  return success();
2829 
2830  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
2831  StringAttr attrName = argAttr.getName();
2832  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
2833  if (!isKernel) {
2834  return op->emitError()
2835  << "'" << attrName
2836  << "' attribute must be present only on kernel arguments";
2837  }
2838  if (!isa<UnitAttr>(argAttr.getValue()))
2839  return op->emitError() << "'" << attrName << "' must be a unit attribute";
2840  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
2841  return op->emitError()
2842  << "'" << attrName
2843  << "' attribute requires the argument to also have attribute '"
2844  << LLVM::LLVMDialect::getByValAttrName() << "'";
2845  }
2846  }
2847 
2848  return success();
2849 }
2850 
2851 //===----------------------------------------------------------------------===//
2852 // NVVM Address Space Attr
2853 //===----------------------------------------------------------------------===//
2854 
2855 unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
2856  return static_cast<unsigned>(getValue());
2857 }
2858 
2859 bool NVVMMemorySpaceAttr::isValidLoad(
2860  Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
2861  const ::mlir::DataLayout *dataLayout,
2862  function_ref<InFlightDiagnostic()> emitError) const {
2863  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
2864  dataLayout, emitError);
2865 }
2866 
2867 bool NVVMMemorySpaceAttr::isValidStore(
2868  Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
2869  const ::mlir::DataLayout *dataLayout,
2870  function_ref<InFlightDiagnostic()> emitError) const {
2871  return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
2872  dataLayout, emitError);
2873 }
2874 
2875 bool NVVMMemorySpaceAttr::isValidAtomicOp(
2876  ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
2877  std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
2878  function_ref<InFlightDiagnostic()> emitError) const {
2879  // TODO: update this method once `ptr.atomic_rmw` is implemented.
2880  assert(false && "unimplemented, see TODO in the source.");
2881  return false;
2882 }
2883 
2884 bool NVVMMemorySpaceAttr::isValidAtomicXchg(
2885  Type type, ptr::AtomicOrdering successOrdering,
2886  ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
2887  const ::mlir::DataLayout *dataLayout,
2888  function_ref<InFlightDiagnostic()> emitError) const {
2889  // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
2890  assert(false && "unimplemented, see TODO in the source.");
2891  return false;
2892 }
2893 
2894 bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
2895  Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
2896  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
2897  // dialect.
2898  assert(false && "unimplemented, see TODO in the source.");
2899  return false;
2900 }
2901 
2902 bool NVVMMemorySpaceAttr::isValidPtrIntCast(
2903  Type intLikeTy, Type ptrLikeTy,
2904  function_ref<InFlightDiagnostic()> emitError) const {
2905  // TODO: update this method once the int-cast ops are added to the `ptr`
2906  // dialect.
2907  assert(false && "unimplemented, see TODO in the source.");
2908  return false;
2909 }
2910 
2911 //===----------------------------------------------------------------------===//
2912 // NVVM target attribute.
2913 //===----------------------------------------------------------------------===//
2914 LogicalResult
2915 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2916  int optLevel, StringRef triple, StringRef chip,
2917  StringRef features, DictionaryAttr flags,
2918  ArrayAttr files, bool verifyTarget) {
2919  if (optLevel < 0 || optLevel > 3) {
2920  emitError() << "The optimization level must be a number between 0 and 3.";
2921  return failure();
2922  }
2923  if (triple.empty()) {
2924  emitError() << "The target triple cannot be empty.";
2925  return failure();
2926  }
2927  if (chip.empty()) {
2928  emitError() << "The target chip cannot be empty.";
2929  return failure();
2930  }
2931  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
2932  return mlir::isa_and_nonnull<StringAttr>(attr);
2933  })) {
2934  emitError() << "All the elements in the `link` array must be strings.";
2935  return failure();
2936  }
2937  return success();
2938 }
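// Illustrative sketch (editorial addition, not part of the original source):
// a target attribute accepted by the verifier above, as it might appear on a
// GPU module in textual IR:
//
//   gpu.module @kernels [#nvvm.target<O = 2, chip = "sm_90",
//                                     features = "+ptx80">] { ... }
//
// An empty `chip` or `triple`, an optimization level outside [0, 3], or a
// non-string element in `link` is rejected by the checks above.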
2939 
2940 LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
2941  if (!getVerifyTarget())
2942  return success();
2943 
2944  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
2945  if (!gpuModuleOp) {
2946  return emitError(gpuModule->getLoc(),
2947  "NVVM target attribute must be attached to a GPU module");
2948  }
2949 
2950  const NVVMCheckSMVersion targetSMVersion =
2951  NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
2952  if (!targetSMVersion.isMinimumSMVersion()) {
2953  return emitError(gpuModule->getLoc(),
2954  "Minimum NVVM target SM version is sm_20");
2955  }
2956 
2957  gpuModuleOp->walk([&](Operation *op) {
2958  if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
2959  const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
2960  if (!requirement.isCompatibleWith(targetSMVersion)) {
2961  op->emitOpError() << "is not supported on " << getChip();
2962  return WalkResult::interrupt();
2963  }
2964  }
2965  return WalkResult::advance();
2966  });
2967 
2968  return success();
2969 }
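// Editorial note (not part of the original source): with target verification
// enabled, a module whose target chip is, say, sm_80 but which contains an op
// whose RequiresSMInterface reports a higher minimum (e.g. sm_90) hits the
// "is not supported on sm_80" diagnostic emitted by the walk above.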
2970 
2971 #define GET_OP_CLASSES
2972 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
2973 
2974 #define GET_ATTRDEF_CLASSES
2975 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"