//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
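// For example, a 5-D im2col-mode copy must carry exactly 3 im2col offsets
// (tensorDims - 2), while tile-mode copies carry none.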
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimensions are supported.");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16 bytes copy.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}

LogicalResult CvtFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
  }
  return success();
}

LogicalResult CvtF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  switch (getType()) {
  case CVTFP8Type::E4M3:
  case CVTFP8Type::E5M2:
    if (!isRoundingModeRN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (!isSatFinite)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  case CVTFP8Type::UE8M0:
    if (!(isRoundingModeRZ || isRoundingModeRP))
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (hasRelu)
      return emitOpError(
          "relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  return success();
}

LogicalResult CvtF16x2ToF8x2Op::verify() {
  if (getType() == CVTFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");

  return success();
}

LogicalResult CvtBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (getType() != CVTFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  auto rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

  return success();
}

LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ")
           << getInitVal();
  return success();
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
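// For example, an f16 scalar or a vector<2xf16> maps to MMATypes::f16; f32
// maps to MMATypes::f32 when used as an accumulator and to MMATypes::tf32
// otherwise; integer element types are only valid as s32 accumulators.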
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//   `->` type($res)
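// For example (operand names and exact attribute spelling are illustrative):
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 8>}
//          : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//            -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>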
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 2> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator*/ iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
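  // For example, with f16 operands (kFactor = 8) and shape m16n8k16, the
  // expected fragment sizes below are unitA = (16 / 8) * (16 / 8) = 4 and
  // unitB = (8 / 8) * (16 / 8) = 2 registers of vector<2xf16>.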
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be the same type but there is "
             << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored in row-major and B in column-major layout.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " result elements, however output struct has "
                         << outputSize << " elements";
  }
  // Check satfinite (only available for s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can only be used with an s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

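// For example (illustrative; whitespace adjusted, register numbering follows
// getAsmValues below), for typeD = f32, typeA = typeB = f16 and shape
// m64n8k16 this returns:
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $10, 0;
//   wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16
//       {$0, $1, $2, $3}, $8, $9, p, $11, $12, $13, $14;
//   }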
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}

LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}

LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two-element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}

LogicalResult NVVM::VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' return an i1");
    }
  }
  return success();
}

llvm::Value *
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                        llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//

#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
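
// For example, GET_CP_ASYNC_ID(ca, 4, /*has_cpsize=*/true) expands to
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s.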

llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the Intrinsic Args.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return id;
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()
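
// For example, GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, 3, /*is_im2col=*/true)
// yields llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d.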

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

#define _none

#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )
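
// For example, GET_CVT_F2TF32_ID(rn, _relu, _satfinite) selects
// llvm::Intrinsic::nvvm_f2tf32_rn_satfinite when hasRelu is false and
// hasSatFinite is true (the `_none` placeholder expands to nothing).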

llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}

#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
                                                     bool hasRelu) {
  switch (type) {
  case NVVM::CVTFP6Type::E2M3:
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  case NVVM::CVTFP6Type::E3M2:
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  }
}

#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
                                                     NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  switch (type) {
  case NVVM::CVTFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::CVTFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::CVTFP8Type::UE8M0:
    if (hasRoundingModeRZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    else if (hasRoundingModeRP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
  }
  llvm_unreachable("Invalid conversion in CvtF32x2ToF8x2Op");
}

#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
                                                     bool hasRelu) {
  switch (type) {
  case NVVM::CVTFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::CVTFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
  }
}

#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                  NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}

llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the Intrinsic Args.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the Intrinsic Args.
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
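
// For example, GET_TCGEN05_COMMIT_ID(cg2, /*is_shared=*/true, /*has_mc=*/true)
// selects llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2.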

llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the Intrinsic Args.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return id;
}

#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
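
// For example, GET_TCGEN05_CP_ID(_128x256b, Tcgen05CpSrcFormat::B6x16_P32,
// /*is_2cta=*/false) yields
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg1.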
1521 
1522 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
1523  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
1524  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1525  auto srcFmt = curOp.getSrcFormat();
1526  auto mc = curOp.getMulticast();
1527 
1528  switch (curOp.getShape()) {
1529  case Tcgen05CpShape::SHAPE_128x256b:
1530  return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
1531  case Tcgen05CpShape::SHAPE_128x128b:
1532  return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
1533  case Tcgen05CpShape::SHAPE_4x256b:
1534  return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
1535  case Tcgen05CpShape::SHAPE_32x128b:
1536  return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
1537  case Tcgen05CpShape::SHAPE_64x128b:
1538  return (mc == Tcgen05CpMulticast::WARPX2_01_23)
1539  ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
1540  : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
1541  }
1542  llvm_unreachable("Invalid shape in tcgen05 cp Op");
1543 }
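// Note that SHAPE_32x128b always carries the "_warpx4" suffix, while
// SHAPE_64x128b picks one of the two "_warpx2" variants from the multicast
// attribute, e.g. with the default source format and a single CTA group:
//   (SHAPE_64x128b, WARPX2_01_23) -> nvvm_tcgen05_cp_64x128b_warpx2_01_23_cg1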
1544 
1545 // Returns true when the given vector length is valid for the given shape;
1546 // models the table in the tcgen05.{ld, st} Op description.
1547 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
1548  unsigned vecLen) {
1549  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1550  return vecLen >= 2;
1551  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1552  return vecLen >= 4;
1553  return true;
1554 }
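// In tabular form (minimum vector length per shape):
//   SHAPE_16X128B -> at least 2 elements
//   SHAPE_16X256B -> at least 4 elements
//   all other shapes -> any length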
1555 
1556 LogicalResult Tcgen05LdOp::verify() {
1557  LogicalResult result = success();
1558  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1559  result = emitError("shape 16x32bx2 requires offset argument");
1560 
1561  auto resTy = getRes().getType();
1562  unsigned resLen = isa<VectorType>(resTy)
1563  ? llvm::cast<VectorType>(resTy).getNumElements()
1564  : 1;
1565  if (!isValidVectorLength(getShape(), resLen))
1566  result = emitError(llvm::formatv("invalid result type length {0} for shape "
1567  "{1} in tcgen05.ld Op",
1568  resLen, stringifyEnum(getShape())));
1569 
1570  return result;
1571 }
1572 
1573 LogicalResult Tcgen05StOp::verify() {
1574  LogicalResult result = success();
1575  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1576  result = emitError("shape 16x32bx2 requires offset argument");
1577 
1578  auto valTy = getVal().getType();
1579  unsigned valLen = isa<VectorType>(valTy)
1580  ? llvm::cast<VectorType>(valTy).getNumElements()
1581  : 1;
1582  if (!isValidVectorLength(getShape(), valLen))
1583  result = emitError(llvm::formatv("invalid input length {0} for shape "
1584  "{1} in tcgen05.st Op",
1585  valLen, stringifyEnum(getShape())));
1586 
1587  return result;
1588 }
1589 
1590 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
1591 /// have ConstantRangeAttr.
1592 static void nvvmInferResultRanges(Operation *op, Value result,
1593  ArrayRef<::mlir::ConstantIntRanges> argRanges,
1594  SetIntRangeFn setResultRanges) {
1595  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
1596  setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
1597  rangeAttr.getLower(), rangeAttr.getUpper()});
1598  }
1599 }
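// For example (illustrative MLIR; exact assembly syntax may differ), an op
// such as
//   %tid = nvvm.read.ptx.sreg.tid.x range <i32, 0, 1024> : i32
// carries a ConstantRangeAttr, so integer range analysis can bound the
// result to [0, 1024).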
1600 
1601 llvm::Intrinsic::ID
1602 DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
1603  NVVM::DotAccumulate4WayType b_type) {
1604  bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
1605  bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
1606  unsigned type = (is_a_siext << 1) | is_b_siext;
1607  switch (type) {
1608  case 0:
1609  return llvm::Intrinsic::nvvm_idp4a_u_u;
1610  case 1:
1611  return llvm::Intrinsic::nvvm_idp4a_u_s;
1612  case 2:
1613  return llvm::Intrinsic::nvvm_idp4a_s_u;
1614  case 3:
1615  return llvm::Intrinsic::nvvm_idp4a_s_s;
1616  default:
1617  llvm_unreachable("Invalid DP4a type");
1618  }
1619 }
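// The two-bit key packs the signedness of A into bit 1 and of B into bit 0,
// e.g. (a_type = S8, b_type = U8) gives 0b10 and selects nvvm_idp4a_s_u.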
1620 
1621 //===----------------------------------------------------------------------===//
1622 // NVVMDialect initialization, type parsing, and registration.
1623 //===----------------------------------------------------------------------===//
1624 
1625 // TODO: This should be the llvm.nvvm dialect once this is supported.
1626 void NVVMDialect::initialize() {
1627  addOperations<
1628 #define GET_OP_LIST
1629 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1630  >();
1631  addAttributes<
1632 #define GET_ATTRDEF_LIST
1633 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1634  >();
1635 
1636  // Support unknown operations because not all NVVM operations are
1637  // registered.
1638  allowUnknownOperations();
1639  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1640  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1641 }
1642 
1643 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1644  NamedAttribute attr) {
1645  StringAttr attrName = attr.getName();
1646  // Kernel function attribute should be attached to functions.
1647  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1648  if (!isa<LLVM::LLVMFuncOp>(op)) {
1649  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1650  << "' attribute attached to unexpected op";
1651  }
1652  }
1653  // If maxntid / reqntid / cluster_dim are present, they must be integer
1654  // arrays with at most 3 elements.
1655  if (attrName == NVVMDialect::getMaxntidAttrName() ||
1656  attrName == NVVMDialect::getReqntidAttrName() ||
1657  attrName == NVVMDialect::getClusterDimAttrName()) {
1658  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1659  if (!values || values.empty() || values.size() > 3)
1660  return op->emitError()
1661  << "'" << attrName
1662  << "' attribute must be an integer array with at most 3 elements";
1663  }
1664  // If minctasm / maxnreg / cluster_max_blocks are present, they must be
1665  // integer attributes.
1666  if (attrName == NVVMDialect::getMinctasmAttrName() ||
1667  attrName == NVVMDialect::getMaxnregAttrName() ||
1668  attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
1669  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1670  return op->emitError()
1671  << "' attribute must be an integer constant";
1672  }
1673 
1674  return success();
1675 }
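// A minimal sketch of attributes this verifier accepts (illustrative MLIR;
// the function name is hypothetical):
//   llvm.func @kernel() attributes {nvvm.kernel,
//                                   nvvm.maxntid = array<i32: 128, 1, 1>,
//                                   nvvm.minctasm = 2 : i32} { ... }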
1676 
1677 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1678  unsigned regionIndex,
1679  unsigned argIndex,
1680  NamedAttribute argAttr) {
1681  auto funcOp = dyn_cast<FunctionOpInterface>(op);
1682  if (!funcOp)
1683  return success();
1684 
1685  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1686  StringAttr attrName = argAttr.getName();
1687  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1688  if (!isKernel) {
1689  return op->emitError()
1690  << "'" << attrName
1691  << "' attribute must be present only on kernel arguments";
1692  }
1693  if (!isa<UnitAttr>(argAttr.getValue()))
1694  return op->emitError() << "'" << attrName << "' must be a unit attribute";
1695  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1696  return op->emitError()
1697  << "'" << attrName
1698  << "' attribute requires the argument to also have attribute '"
1699  << LLVM::LLVMDialect::getByValAttrName() << "'";
1700  }
1701  }
1702 
1703  return success();
1704 }
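// A grid_constant argument must sit on a kernel and also carry llvm.byval,
// e.g. (illustrative; names are hypothetical):
//   llvm.func @kern(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
//       attributes {nvvm.kernel} { ... }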
1705 
1706 //===----------------------------------------------------------------------===//
1707 // NVVM target attribute.
1708 //===----------------------------------------------------------------------===//
1709 LogicalResult
1710 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1711  int optLevel, StringRef triple, StringRef chip,
1712  StringRef features, DictionaryAttr flags,
1713  ArrayAttr files) {
1714  if (optLevel < 0 || optLevel > 3) {
1715  emitError() << "The optimization level must be a number between 0 and 3.";
1716  return failure();
1717  }
1718  if (triple.empty()) {
1719  emitError() << "The target triple cannot be empty.";
1720  return failure();
1721  }
1722  if (chip.empty()) {
1723  emitError() << "The target chip cannot be empty.";
1724  return failure();
1725  }
1726  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1727  return mlir::isa_and_nonnull<StringAttr>(attr);
1728  })) {
1729  emitError() << "All the elements in the `link` array must be strings.";
1730  return failure();
1731  }
1732  return success();
1733 }
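// A target attribute that passes this verifier might look like
// (illustrative):
//   #nvvm.target<O = 3, chip = "sm_90", triple = "nvptx64-nvidia-cuda">
// i.e. an optimization level in [0, 3] plus non-empty chip and triple
// strings.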
1734 
1735 #define GET_OP_CLASSES
1736 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1737 
1738 #define GET_ATTRDEF_CLASSES
1739 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"