1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
22 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/Types.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/FormatVariadic.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <cassert>
38 #include <optional>
39 #include <string>
40 
41 using namespace mlir;
42 using namespace NVVM;
43 
44 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
45 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // Verifier methods
49 //===----------------------------------------------------------------------===//
50 
51 // This verifier is shared among the following Ops:
52 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
53 // CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
54 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
55 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
56  bool isIm2Col,
57  size_t numIm2ColOffsets,
58  Location loc) {
59  if (tensorDims < 1 || tensorDims > 5)
60  return emitError(loc, "expects coordinates between 1 to 5 dimension");
61 
62  // For Im2Col mode, there are two constraints:
63  if (isIm2Col) {
64  // 1. Tensor must always be at least 3-d.
65  if (tensorDims < 3)
66  return emitError(
67  loc,
68  "to use im2col mode, the tensor has to be at least 3-dimensional");
69  // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
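  // For example, a 5-D im2col copy that supplies offsets must supply
  // exactly three of them (5 - 2).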
70  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
71  return emitError(
72  loc, "im2col offsets must be 2 less than number of coordinates");
73  }
74  return success();
75 }
76 
77 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
78  size_t numIm2ColOffsets = getIm2colOffsets().size();
79  bool isIm2Col = numIm2ColOffsets > 0;
80  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
81  numIm2ColOffsets, getLoc());
82 }
83 
84 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
85  if (getCoordinates().size() > 5)
86  return emitError("Maximum 5 coordinates and dimension is supported.");
87  return success();
88 }
89 
90 LogicalResult CpAsyncOp::verify() {
91  if (getModifier() != LoadCacheModifierKind::CG &&
92  getModifier() != LoadCacheModifierKind::CA)
93  return emitError("Only CG and CA cache modifiers are supported.");
94  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
95  return emitError("expected byte size to be either 4, 8 or 16.");
96  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
97  return emitError("CG cache modifier is only support for 16 bytes copy.");
98  return success();
99 }
100 
101 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
102  size_t numIm2ColOffsets = getIm2colOffsets().size();
103  bool isIm2Col = numIm2ColOffsets > 0;
104  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
105  numIm2ColOffsets, getLoc());
106 }
107 
108 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
109  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
110  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
111  getLoc());
112 }
113 
114 LogicalResult ConvertFloatToTF32Op::verify() {
115  using RndMode = NVVM::FPRoundingMode;
116  switch (getRnd()) {
117  case RndMode::RNA:
118  if (getRelu())
119  return emitError("Relu not supported with rna rounding mode.");
120  break;
121  case RndMode::RN:
122  case RndMode::RZ:
123  break;
124  default:
125  return emitError(
126  "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
127  }
128  return success();
129 }
130 
131 LogicalResult ConvertF32x2ToF8x2Op::verify() {
132  using RndMode = NVVM::FPRoundingMode;
133  using SatMode = NVVM::SaturationMode;
134 
135  bool isRoundingModeRN = getRnd() == RndMode::RN;
136  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
137  bool isRoundingModeRP = getRnd() == RndMode::RP;
138  bool isSatFinite = getSat() == SatMode::SATFINITE;
139 
140  bool hasRelu = getRelu();
141 
142  switch (getType()) {
143  case ConvertFP8Type::E4M3:
144  case ConvertFP8Type::E5M2:
145  if (!isRoundingModeRN)
146  return emitOpError("Only RN rounding mode is supported for conversions "
147  "from f32x2 to .e4m3x2 or .e5m2x2 types");
148  if (!isSatFinite)
149  return emitOpError("Only SATFINITE saturation mode is supported for "
150  "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
151  break;
152  case ConvertFP8Type::UE8M0:
153  if (!(isRoundingModeRZ || isRoundingModeRP))
154  return emitOpError("Only RZ or RP rounding modes are supported for "
155  "conversions from f32x2 to .ue8m0x2 type");
156  if (hasRelu)
157  return emitOpError("relu not supported for conversions to .ue8m0x2 type");
158  break;
159  }
160  return success();
161 }
162 
163 LogicalResult ConvertF16x2ToF8x2Op::verify() {
164  if (getType() == ConvertFP8Type::UE8M0)
165  return emitOpError("Only .e4m3 or .e5m2 types are supported for "
166  "conversions from f16x2 to f8x2.");
167 
168  return success();
169 }
170 
171 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
172  using RndMode = NVVM::FPRoundingMode;
173 
174  if (getType() != ConvertFP8Type::UE8M0)
175  return emitOpError(
176  "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
177 
178  auto rnd = getRnd();
179  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
180  return emitOpError("Only RZ and RP rounding modes are supported for "
181  "conversions from bf16x2 to f8x2.");
182 
183  return success();
184 }
185 
186 LogicalResult BulkStoreOp::verify() {
187  if (getInitVal() != 0)
188  return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
189  return success();
190 }
191 
192 // Given the element type of an operand and whether or not it is an accumulator,
193 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
194 // operand's element type.
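// For example, vector<2xf16> maps to f16, while f32 maps to f32 when the
// operand is an accumulator and to tf32 otherwise.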
195 std::optional<mlir::NVVM::MMATypes>
196 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
197  auto half2Type =
198  VectorType::get(2, Float16Type::get(operandElType.getContext()));
199  if (operandElType.isF64())
200  return NVVM::MMATypes::f64;
201  if (operandElType.isF16() || operandElType == half2Type)
202  return NVVM::MMATypes::f16;
203  if (operandElType.isF32() && isAccumulator)
204  return NVVM::MMATypes::f32;
205  if (operandElType.isF32() && !isAccumulator)
206  return NVVM::MMATypes::tf32;
207  if (llvm::isa<IntegerType>(operandElType)) {
208  if (isAccumulator)
209  return NVVM::MMATypes::s32;
210  return std::nullopt;
211  }
212 
213  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
214  if (structType.getBody().empty())
215  return std::nullopt;
216  return inferOperandMMAType(structType.getBody()[0], isAccumulator);
217  }
218 
219  return std::nullopt;
220 }
221 
222 static bool isInt4PtxType(MMATypes type) {
223  return (type == MMATypes::u4 || type == MMATypes::s4);
224 }
225 
226 static bool isInt8PtxType(MMATypes type) {
227  return (type == MMATypes::u8 || type == MMATypes::s8);
228 }
229 
230 static bool isIntegerPtxType(MMATypes type) {
231  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
232  type == MMATypes::s32;
233 }
234 
235 MMATypes MmaOp::accumPtxType() {
236  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
237  getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
238  assert(val.has_value() && "accumulator PTX type should always be inferrable");
239  return val.value();
240 }
241 
242 MMATypes MmaOp::resultPtxType() {
243  std::optional<mlir::NVVM::MMATypes> val =
244  inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
245  assert(val.has_value() && "result PTX type should always be inferrable");
246  return val.value();
247 }
248 
249 void MmaOp::print(OpAsmPrinter &p) {
250  SmallVector<Type, 4> regTypes;
251  struct OperandFragment {
252  StringRef operandName;
253  StringRef ptxTypeAttr;
254  SmallVector<Value, 4> regs;
255  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
256  : operandName(name), ptxTypeAttr(ptxTypeName) {}
257  };
258 
259  std::array<OperandFragment, 3> frags{
260  OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
261  OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
262  OperandFragment("C", "")};
263  SmallVector<StringRef, 4> ignoreAttrNames{
264  mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
265 
266  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
267  auto &frag = frags[fragIdx];
268  auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
269  for (auto operandIdx = varOperandSpec.first;
270  operandIdx < varOperandSpec.first + varOperandSpec.second;
271  operandIdx++) {
272  frag.regs.push_back(this->getOperand(operandIdx));
273  if (operandIdx == 0) {
274  regTypes.push_back(this->getOperand(operandIdx).getType());
275  }
276  }
277  std::optional<MMATypes> inferredType =
278  inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
279  if (inferredType)
280  ignoreAttrNames.push_back(frag.ptxTypeAttr);
281  }
282 
283  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
284  p << " " << frag.operandName;
285  p << "[";
286  p.printOperands(frag.regs);
287  p << "] ";
288  };
289 
290  for (const auto &frag : frags) {
291  printMmaOperand(frag);
292  }
293 
294  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
295 
296  // Print the types of the operands and result.
297  p << " : "
298  << "(";
299  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
300  frags[1].regs[0].getType(),
301  frags[2].regs[0].getType()},
302  p);
303  p << ")";
304  p.printArrowTypeList(TypeRange{this->getRes().getType()});
305 }
306 
307 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
308  ValueRange operandA, ValueRange operandB, ValueRange operandC,
309  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
310  std::optional<MMAIntOverflow> intOverflow,
311  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
312  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
313 
314  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
315  MLIRContext *ctx = builder.getContext();
316  result.addAttribute(
317  "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
318 
319  result.addOperands(operandA);
320  result.addOperands(operandB);
321  result.addOperands(operandC);
322 
323  if (multiplicandPtxTypes) {
324  result.addAttribute("multiplicandAPtxType",
325  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
326  result.addAttribute("multiplicandBPtxType",
327  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
328  } else {
329  if (auto res = inferOperandMMAType(operandA[0].getType(), false))
330  result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
331  if (auto res = inferOperandMMAType(operandB[0].getType(), false))
332  result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
333  }
334 
335  if (multiplicandLayouts) {
336  result.addAttribute("layoutA",
337  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
338  result.addAttribute("layoutB",
339  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
340  } else {
341  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
342  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
343  }
344 
345  if (intOverflow.has_value())
346  result.addAttribute("intOverflowBehavior",
347  MMAIntOverflowAttr::get(ctx, *intOverflow));
348  if (b1Op.has_value())
349  result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
350 
351  result.addTypes(resultType);
352  result.addAttribute(
353  MmaOp::getOperandSegmentSizeAttr(),
354  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
355  static_cast<int32_t>(operandB.size()),
356  static_cast<int32_t>(operandC.size())}));
357 }
358 
359 // <operation> :=
360 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
361 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
362 // `->` type($res)
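// For example (m16n8k16 with f16 operands and an f16 accumulator), the custom
// form looks roughly like:
//   nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//       {layoutA = ..., layoutB = ..., shape = ...}
//       : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>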
363 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
364  struct OperandFragment {
365  std::optional<MMATypes> elemtype;
366  SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
367  SmallVector<Type> regTypes;
368  };
369 
370  Builder &builder = parser.getBuilder();
371  std::array<OperandFragment, 4> frags;
372 
373  NamedAttrList namedAttributes;
374 
375  // A helper to parse the operand segments.
376  auto parseMmaOperand = [&](StringRef operandName,
377  OperandFragment &frag) -> LogicalResult {
378  if (parser.parseKeyword(operandName).failed())
379  return failure();
380  if (parser
381  .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
382  .failed())
383  return failure();
384  return success();
385  };
386 
387  // Parse the operand segments.
388  if (parseMmaOperand("A", frags[0]).failed())
389  return failure();
390  if (parseMmaOperand("B", frags[1]).failed())
391  return failure();
392  if (parseMmaOperand("C", frags[2]).failed())
393  return failure();
394 
395  if (parser.parseOptionalAttrDict(namedAttributes).failed())
396  return failure();
397 
398  // Parse the type specification and resolve operands.
399  SmallVector<Type, 3> operandTypes;
400  if (failed(parser.parseColon()))
401  return failure();
402  if (failed(parser.parseLParen()))
403  return failure();
404  if (failed(parser.parseTypeList(operandTypes)))
405  return failure();
406  if (failed(parser.parseRParen())) return failure();
407  if (operandTypes.size() != 3)
408  return parser.emitError(
409  parser.getNameLoc(),
410  "expected one type for each operand segment but got " +
411  Twine(operandTypes.size()) + " types");
412  for (const auto &iter : llvm::enumerate(operandTypes)) {
413  auto &frag = frags[iter.index()];
414  frag.regTypes.resize(frag.regs.size(), iter.value());
415  if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
416  parser.getNameLoc(), result.operands)))
417  return failure();
418  frag.elemtype = inferOperandMMAType(frag.regTypes[0],
419  /*isAccumulator*/ iter.index() < 2);
420  }
421 
422  Type resultType;
423  if (parser.parseArrow() || parser.parseType(resultType))
424  return failure();
425  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
426 
427  std::array<StringRef, 2> names{"multiplicandAPtxType",
428  "multiplicandBPtxType"};
429  for (unsigned idx = 0; idx < names.size(); idx++) {
430  const auto &frag = frags[idx];
431  std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
432  if (!frag.elemtype.has_value() && !attr.has_value()) {
433  return parser.emitError(
434  parser.getNameLoc(),
435  "attribute " + names[idx] +
436  " is not provided explicitly and cannot be inferred");
437  }
438  if (!attr.has_value())
439  result.addAttribute(
440  names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
441  }
442 
443  result.addTypes(resultType);
444  if (!namedAttributes.empty())
445  result.addAttributes(namedAttributes);
446  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
447  builder.getDenseI32ArrayAttr({
448  static_cast<int32_t>(frags[0].regs.size()),
449  static_cast<int32_t>(frags[1].regs.size()),
450  static_cast<int32_t>(frags[2].regs.size()),
451  }));
452  return success();
453 }
454 
455 LogicalResult MmaOp::verify() {
456  MLIRContext *context = getContext();
457  auto f16Ty = Float16Type::get(context);
458  auto i32Ty = IntegerType::get(context, 32);
459  auto f16x2Ty = VectorType::get(2, f16Ty);
460  auto f32Ty = Float32Type::get(context);
461  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
462  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
463 
464  auto s32x4StructTy =
465  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
466  auto f32x8StructTy =
467  LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
468  auto f16x2x2StructTy =
469  LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
470  auto f32x4StructTy =
471  LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
472  auto s32x2StructTy =
473  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
474 
475  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
476  getShapeAttr().getK()};
477 
478  // These variables define the set of allowed data types for matrices A, B, C,
479  // and result.
480  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
481  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
482  AllowedShapes allowedShapes;
483  AllowedTypes expectedA;
484  AllowedTypes expectedB;
485  AllowedTypes expectedC;
486  SmallVector<Type> expectedResult;
487 
488  // When M = 16, we just need to calculate the number of 8xk tiles, where
489  // k is a factor that depends on the data type.
490  if (mmaShape[0] == 16) {
491  int64_t kFactor;
492  Type multiplicandFragType;
493  switch (*getMultiplicandAPtxType()) {
494  case MMATypes::tf32:
495  kFactor = 4;
496  multiplicandFragType = i32Ty;
497  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
498  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
499  break;
500  case MMATypes::bf16:
501  kFactor = 8;
502  multiplicandFragType = i32Ty;
503  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
504  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
505  break;
506  case MMATypes::f16:
507  kFactor = 8;
508  multiplicandFragType = f16x2Ty;
509  expectedResult.push_back(f16x2x2StructTy);
510  expectedResult.push_back(f32x4StructTy);
511  break;
512  case MMATypes::s4:
513  case MMATypes::u4:
514  kFactor = 32;
515  break;
516  case MMATypes::b1:
517  kFactor = 128;
518  break;
519  case MMATypes::s8:
520  case MMATypes::u8:
521  kFactor = 16;
522  break;
523  default:
524  return emitError("invalid shape or multiplicand type: " +
525  stringifyEnum(getMultiplicandAPtxType().value()));
526  }
527 
528  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
529  expectedResult.push_back(s32x4StructTy);
530  expectedC.emplace_back(4, i32Ty);
531  multiplicandFragType = i32Ty;
532  } else {
533  expectedC.emplace_back(2, f16x2Ty);
534  expectedC.emplace_back(4, f32Ty);
535  }
536 
537  int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
538  int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
539  expectedA.emplace_back(unitA, multiplicandFragType);
540  expectedB.emplace_back(unitB, multiplicandFragType);
541  allowedShapes.push_back({16, 8, kFactor});
542  allowedShapes.push_back({16, 8, kFactor * 2});
543  }
544 
545  // In the M=8 case, there is only 1 possible case per data type.
546  if (mmaShape[0] == 8) {
547  if (*getMultiplicandAPtxType() == MMATypes::f16) {
548  expectedA.emplace_back(2, f16x2Ty);
549  expectedB.emplace_back(2, f16x2Ty);
550  expectedResult.push_back(f16x2x4StructTy);
551  expectedResult.push_back(f32x8StructTy);
552  expectedC.emplace_back(4, f16x2Ty);
553  expectedC.emplace_back(8, f32Ty);
554  allowedShapes.push_back({8, 8, 4});
555  }
556  if (*getMultiplicandAPtxType() == MMATypes::f64) {
557  Type f64Ty = Float64Type::get(context);
558  expectedA.emplace_back(1, f64Ty);
559  expectedB.emplace_back(1, f64Ty);
560  expectedC.emplace_back(2, f64Ty);
561  expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
562  context, SmallVector<Type>(2, f64Ty)));
563  allowedShapes.push_back({8, 8, 4});
564  }
565  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
566  expectedA.push_back({i32Ty});
567  expectedB.push_back({i32Ty});
568  expectedC.push_back({i32Ty, i32Ty});
569  expectedResult.push_back(s32x2StructTy);
570  if (isInt4PtxType(getMultiplicandAPtxType().value()))
571  allowedShapes.push_back({8, 8, 32});
572  if (isInt8PtxType(getMultiplicandAPtxType().value()))
573  allowedShapes.push_back({8, 8, 16});
574  if (getMultiplicandAPtxType().value() == MMATypes::b1)
575  allowedShapes.push_back({8, 8, 128});
576  }
577  }
578 
579  std::string errorMessage;
580  llvm::raw_string_ostream errorStream(errorMessage);
581 
582  // Check that we matched an existing shape/dtype combination.
583  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
584  !llvm::is_contained(allowedShapes, mmaShape)) {
585  errorStream << "unimplemented variant for MMA shape <";
586  llvm::interleaveComma(mmaShape, errorStream);
587  errorStream << ">";
588  return emitOpError(errorMessage);
589  }
590 
591  // Verify the operand types for segments of A, B, and C operands.
592  std::array<StringRef, 3> operandNames{"A", "B", "C"};
593  for (const auto &iter : llvm::enumerate(
594  SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
595  auto spec = this->getODSOperandIndexAndLength(iter.index());
596  SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
597  operand_type_begin() + spec.first +
598  spec.second);
599  bool match = llvm::is_contained(iter.value(), operandTySeg);
600 
601  if (!match) {
602  errorStream << "Could not match types for the "
603  << operandNames[iter.index()]
604  << " operands; expected one of ";
605  for (const auto &x : iter.value()) {
606  errorStream << x.size() << "x" << x[0] << " ";
607  }
608  errorStream << "but got ";
609  llvm::interleaveComma(operandTySeg, errorStream);
610  return emitOpError(errorMessage);
611  }
612  }
613 
614  // Check the result type
615  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
616  return expectedResultType == getResult().getType();
617  })) {
618  errorStream
619  << "Could not match allowed types for the result; expected one of ";
620  llvm::interleaveComma(expectedResult, errorStream);
621  errorStream << " but got " << getResult().getType();
622  return emitOpError(errorMessage);
623  }
624 
625  // Ensure that binary MMA variants have a b1 MMA operation defined.
626  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
627  return emitOpError("op requires " + getB1OpAttrName().strref() +
628  " attribute");
629  }
630 
631  // Ensure int4/int8 MMA variants specify the accum overflow behavior
632  // attribute.
633  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
634  isInt8PtxType(*getMultiplicandAPtxType())) {
635  if (!getIntOverflowBehavior())
636  return emitOpError("op requires " +
637  getIntOverflowBehaviorAttrName().strref() +
638  " attribute");
639  }
640 
641  return success();
642 }
643 
644 LogicalResult ShflOp::verify() {
645  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
646  return success();
647  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
648  auto elementType = (type && type.getBody().size() == 2)
649  ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
650  : nullptr;
651  if (!elementType || elementType.getWidth() != 1)
652  return emitError("expected return type to be a two-element struct with "
653  "i1 as the second element");
654  return success();
655 }
656 
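// Computes the (element type, element count) pair describing one WMMA fragment
// for the given matrix type, fragment kind, and shape. For example, an f16 A or
// B fragment is 8 x vector<2xf16>, while an f16 accumulator fragment is
// 4 x vector<2xf16>.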
657 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
658  NVVM::MMAFrag frag, int nRow,
659  int nCol,
660  MLIRContext *context) {
661  unsigned numberElements = 0;
662  Type elementType;
663  OpBuilder builder(context);
664  Type f16x2 = VectorType::get(2, builder.getF16Type());
665  if (type == NVVM::MMATypes::f16) {
666  elementType = f16x2;
667  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
668  numberElements = 8;
669  else
670  numberElements = 4;
671  } else if (type == NVVM::MMATypes::f32) {
672  elementType = builder.getF32Type();
673  numberElements = 8;
674  } else if (type == NVVM::MMATypes::tf32) {
675  elementType = builder.getI32Type();
676  numberElements = 4;
677  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
678  elementType = builder.getI32Type();
679  int parallelSize = 0;
680  if (frag == NVVM::MMAFrag::a)
681  parallelSize = nRow;
682  if (frag == NVVM::MMAFrag::b)
683  parallelSize = nCol;
684 
685  // m == 16 && n == 16 && k == 16
686  if (parallelSize == 16)
687  numberElements = 2;
688  // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
689  else if (parallelSize == 8)
690  numberElements = 1;
691  else if (parallelSize == 32)
692  numberElements = 4;
693  } else if (type == NVVM::MMATypes::s32) {
694  elementType = builder.getI32Type();
695  numberElements = 8;
696  }
697  assert(numberElements != 0 && elementType != nullptr);
698  return std::make_pair(elementType, numberElements);
699 }
700 
701 static std::pair<mlir::Type, unsigned>
702 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
703  int k, MLIRContext *context) {
704  int nRow, nCol;
705  if (frag == NVVM::MMAFrag::a) {
706  nRow = m;
707  nCol = k;
708  } else if (frag == NVVM::MMAFrag::b) {
709  nRow = k;
710  nCol = n;
711  } else {
712  nRow = m;
713  nCol = n;
714  }
715  assert(nRow && nCol);
716  return inferMMAType(type, frag, nRow, nCol, context);
717 }
718 
719 LogicalResult NVVM::WMMALoadOp::verify() {
720  unsigned addressSpace =
721  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
722  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
723  addressSpace != NVVM::kSharedMemorySpace)
724  return emitOpError("expected source pointer in memory "
725  "space 0, 1, 3");
726 
727  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
728  getEltype(), getFrag()) == 0)
729  return emitOpError() << "invalid attribute combination";
730  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
731  getEltype(), getFrag(), getM(), getN(), getK(), getContext());
732  Type dstType = LLVM::LLVMStructType::getLiteral(
733  getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
734  if (getType() != dstType)
735  return emitOpError("expected destination type is a structure of ")
736  << typeInfo.second << " elements of type " << typeInfo.first;
737  return success();
738 }
739 
740 LogicalResult NVVM::WMMAStoreOp::verify() {
741  unsigned addressSpace =
742  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
743  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
744  addressSpace != NVVM::kSharedMemorySpace)
745  return emitOpError("expected operands to be a source pointer in memory "
746  "space 0, 1, 3");
747 
748  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
749  getEltype()) == 0)
750  return emitOpError() << "invalid attribute combination";
751  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
752  getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
753  if (getArgs().size() != typeInfo.second)
754  return emitOpError() << "expected " << typeInfo.second << " data operands";
755  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
756  return operands.getType() != typeInfo.first;
757  }))
758  return emitOpError() << "expected data operands of type " << typeInfo.first;
759  return success();
760 }
761 
762 LogicalResult NVVM::WMMAMmaOp::verify() {
763  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
764  getLayoutB(), getEltypeA(),
765  getEltypeB()) == 0)
766  return emitOpError() << "invalid attribute combination";
767  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
768  getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
769  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
770  getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
771  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
772  getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
773  SmallVector<Type, 32> arguments;
774  arguments.append(typeInfoA.second, typeInfoA.first);
775  arguments.append(typeInfoB.second, typeInfoB.first);
776  arguments.append(typeInfoC.second, typeInfoC.first);
777  unsigned numArgs = arguments.size();
778  if (getArgs().size() != numArgs)
779  return emitOpError() << "expected " << numArgs << " arguments";
780  for (unsigned i = 0; i < numArgs; i++) {
781  if (getArgs()[i].getType() != arguments[i])
782  return emitOpError() << "expected argument " << i << " to be of type "
783  << arguments[i];
784  }
785  Type dstType = LLVM::LLVMStructType::getLiteral(
786  getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
787  if (getType() != dstType)
788  return emitOpError("expected destination type is a structure of ")
789  << typeInfoC.second << " elements of type " << typeInfoC.first;
790  return success();
791 }
792 
793 LogicalResult NVVM::LdMatrixOp::verify() {
794  unsigned addressSpace =
795  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
796  if (addressSpace != NVVM::kSharedMemorySpace)
797  return emitOpError("expected source pointer in memory space 3");
798 
799  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
800  return emitOpError("expected num attribute to be 1, 2 or 4");
801 
802  Type i32 = IntegerType::get(getContext(), 32);
803  if (getNum() == 1 && getType() != i32)
804  return emitOpError("expected destination type is i32");
805  if (getNum() == 2 || getNum() == 4) {
806  Type dstType = LLVM::LLVMStructType::getLiteral(
807  getContext(), SmallVector<Type>(getNum(), i32));
808  if (getType() != dstType)
809  return emitOpError("expected destination type is a structure of ")
810  << getNum() << " elements of type i32";
811  }
812  return success();
813 }
814 
815 LogicalResult NVVM::StMatrixOp::verify() {
816  int numMatrix = getSources().size();
817  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
818  return emitOpError("expected num attribute to be 1, 2 or 4");
819 
820  int m = getShape().getM(), n = getShape().getN();
821  if (m == 8 && n == 8) {
822  if (getEltType() != NVVM::LdStMatrixEltType::B16) {
823  return emitOpError("expected element type to be B16 for 8x8 matrix");
824  }
825  } else if (m == 16 && n == 8) {
826  if (getEltType() != NVVM::LdStMatrixEltType::B8) {
827  return emitOpError("expected element type to be B8 for 16x8 matrix");
828  }
829  if (getLayout() != NVVM::MMALayout::col) {
830  return emitOpError("expected layout to be col for 16x8 matrix");
831  }
832  } else {
833  return emitOpError("expected shape to be 8x8 or 16x8");
834  }
835 
836  return success();
837 }
838 
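// Returns the only K extent WGMMA accepts for a given A/B element type:
// 8 for tf32, 16 for f16/bf16, 32 for s8/u8 and e4m3/e5m2, and 256 for b1.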
839 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
840  if (typeA == NVVM::WGMMATypes::tf32)
841  return 8;
842  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
843  return 16;
844  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
845  return 32;
846  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
847  return 32;
848  if (typeA == NVVM::WGMMATypes::b1)
849  return 256;
850  return failure();
851 }
852 
853 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
854  NVVM::WGMMATypes typeA,
855  NVVM::WGMMATypes typeB) {
856  switch (typeA) {
857  case NVVM::WGMMATypes::f16:
858  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
859  typeB == NVVM::WGMMATypes::f16)
860  return success();
861  break;
862  case NVVM::WGMMATypes::tf32:
863  if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
864  return success();
865  break;
866  case NVVM::WGMMATypes::u8:
867  case NVVM::WGMMATypes::s8:
868  if (typeD == NVVM::WGMMATypes::s32 &&
869  (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
870  return success();
871  break;
872  case NVVM::WGMMATypes::b1:
873  if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
874  return success();
875  break;
876  case NVVM::WGMMATypes::bf16:
877  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
878  typeB == NVVM::WGMMATypes::bf16)
879  return success();
880  break;
881  case NVVM::WGMMATypes::e4m3:
882  case NVVM::WGMMATypes::e5m2:
883  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
884  (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
885  return success();
886  break;
887  case WGMMATypes::f32:
888  case WGMMATypes::s32:
889  llvm_unreachable("unsupported input types");
890  break;
891  }
892  return failure();
893 }
894 
895 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
896  SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
897  72, 80, 88, 96, 104, 112, 120, 128,
898  136, 144, 152, 160, 168, 176, 184, 192,
899  200, 208, 216, 224, 232, 240, 248, 256};
900  SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
901  80, 96, 112, 128, 144, 160,
902  176, 192, 208, 224, 240, 256};
903  switch (typeA) {
904  case WGMMATypes::f16:
905  case WGMMATypes::tf32:
906  case WGMMATypes::bf16:
907  case WGMMATypes::e4m3:
908  case WGMMATypes::e5m2:
909  if (llvm::is_contained(allowedN, sizeN))
910  return success();
911  break;
912  case WGMMATypes::u8:
913  case WGMMATypes::s8:
914  case WGMMATypes::b1:
915  if (llvm::is_contained(allowedNshort, sizeN))
916  return success();
917  break;
918  case WGMMATypes::f32:
919  case WGMMATypes::s32:
920  llvm_unreachable("unsupported input types");
921  break;
922  }
923  return failure();
924 }
925 
926 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
927  Value outValue = getResults();
928  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
929  if (!stype)
930  return emitOpError() << "expected results to be struct";
931  int outputSize = stype.getBody().size();
932  WGMMATypes typeD = getTypeD();
933  WGMMATypes typeA = getTypeA();
934  WGMMATypes typeB = getTypeB();
935 
936  for (Type t : stype.getBody()) {
937  if (t != stype.getBody().front())
938  return emitOpError()
939  << "all elements in struct must be same type but there is " << t;
940  }
941 
942  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
943  typeD != WGMMATypes::s32) {
944  return emitOpError() << "does not support the given output type "
945  << NVVM::stringifyWGMMATypes(typeD);
946  }
947  if (typeD == WGMMATypes::s32 &&
948  (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
949  return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
950  }
951 
952  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
953  return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
954  << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
955  << NVVM::stringifyWGMMATypes(typeB)
956  << ", it is not supported.";
957  }
958 
959  // Check M
960  if (getShape().getM() != 64)
961  return emitOpError() << "shape 'm' must be 64";
962 
963  // Check K
964  FailureOr<int> allowedK = getAllowedSizeK(typeA);
965  if (failed(allowedK) || allowedK.value() != getShape().getK())
966  return emitOpError() << "shape 'k' must be " << allowedK.value()
967  << " for input type "
968  << NVVM::stringifyWGMMATypes(typeA);
969 
970  // Check N
971  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
972  return emitOpError() << "has input type "
973  << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
974  << getShape().getN() << ", it is not supported.";
975  }
976 
977  // Check transpose (only available for f16/bf16)
978  // Matrix A should be stored in row-major and matrix B in column-major.
979  // Only f16/bf16 matrices can be stored in either column-major or row-major
980  // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
981  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
982  (getLayoutA() == mlir::NVVM::MMALayout::col ||
983  getLayoutB() == mlir::NVVM::MMALayout::row)) {
984  return emitOpError()
985  << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
986  << " and layout_b = " << stringifyMMALayout(getLayoutB())
987  << " for input types " << stringifyWGMMATypes(typeA) << " and "
988  << stringifyWGMMATypes(typeB)
989  << " requires transpose. However, this is only supported for: "
990  << stringifyMMATypes(MMATypes::f16) << " and "
991  << stringifyMMATypes(MMATypes::bf16);
992  }
993 
994  // Check result registers
995  int expectedOutput = 0;
996  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
997  expectedOutput = getShape().getN() / 2;
998  if (typeD == WGMMATypes::f16)
999  expectedOutput = getShape().getN() / 4;
1000  if (outputSize != expectedOutput) {
1001  return emitOpError() << "results " << expectedOutput
1002  << ", however output struct has " << outputSize
1003  << " elements";
1004  }
1005  // Check satfinite (only available for s32 accumulator)
1006  if (typeD != WGMMATypes::s32 &&
1007  getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1008  NVVM::MMAIntOverflow::satfinite) {
1009  return emitOpError()
1010  << " `satfinite` can be only used with s32 accumulator, however "
1011  "the current accumulator is "
1012  << NVVM::stringifyWGMMATypes(typeD);
1013  }
1014 
1015  return success();
1016 }
1017 
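// Builds the inline-PTX string for this op: a predicate register is set from
// the scale-d operand, followed by a
// wgmma.mma_async.sync.aligned.m<M>n<N>k<K>.<dtype>.<atype>.<btype> instruction
// whose output register list length matches the count checked in verify().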
1018 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1019 
1020  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1021  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1022 
1023  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1024 
1025  int expectedOutputRegisters = 0;
1026  if (getTypeD() == WGMMATypes::f16)
1027  expectedOutputRegisters = getShape().getN() / 4;
1028  else
1029  expectedOutputRegisters = getShape().getN() / 2;
1030 
1031  std::string ptx;
1032  llvm::raw_string_ostream ss(ptx);
1033 
1034  ss << "{\n"
1035  ".reg .pred p;\n"
1036  "setp.ne.b32 p, $"
1037  << ((expectedOutputRegisters * 2) + 2)
1038  << ", 0;\n"
1039  "wgmma.mma_async.sync.aligned.m"
1040  << m << "n" << n << "k" << k << "." << outputTypeName << "."
1041  << stringifyWGMMATypes(getTypeA()) << "."
1042  << stringifyWGMMATypes(getTypeB());
1043  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1044  NVVM::MMAIntOverflow::satfinite)
1045  ss << ".satfinite";
1046  ss << " {";
1047  int regCnt = 0;
1048  for (; regCnt < expectedOutputRegisters; ++regCnt) {
1049  ss << "$" << regCnt;
1050  if (regCnt != expectedOutputRegisters - 1)
1051  ss << ", ";
1052  }
1053 
1054  ss << "},";
1055  // Need to map read/write registers correctly.
1056  regCnt = (regCnt * 2);
1057  ss << " $" << (regCnt) << ","
1058  << " $" << (regCnt + 1) << ","
1059  << " p";
1060  if (getTypeD() != WGMMATypes::s32) {
1061  ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1062  }
1063  // Don't add transpose parameters unless needed.
1064  if (isF16) {
1065  ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1066  }
1067  ss << ";\n"
1068  << "}\n";
1069  return ptx;
1070 }
1071 
1072 void NVVM::WgmmaMmaAsyncOp::getAsmValues(
1073  RewriterBase &rewriter,
1074  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1075  &asmValues) {
1076  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1077  if (getResults())
1078  asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1079  if (getInouts())
1080  asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1081  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1082  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1083  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1084  mlir::NVVM::PTXRegisterMod::Read});
1085  if (getTypeD() != WGMMATypes::s32) {
1086  asmValues.push_back(
1087  {makeConstantI32(rewriter,
1088  getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1089  mlir::NVVM::PTXRegisterMod::Read});
1090  asmValues.push_back(
1091  {makeConstantI32(rewriter,
1092  getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1093  mlir::NVVM::PTXRegisterMod::Read});
1094  }
1095  if (isF16) {
1096  asmValues.push_back(
1097  {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1098  mlir::NVVM::PTXRegisterMod::Read});
1099  asmValues.push_back(
1100  {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1101  mlir::NVVM::PTXRegisterMod::Read});
1102  }
1103 }
1104 LogicalResult NVVM::FenceProxyOp::verify() {
1105  if (getKind() == NVVM::ProxyKind::TENSORMAP)
1106  return emitOpError() << "tensormap proxy is not a supported proxy kind";
1107  if (getKind() == NVVM::ProxyKind::GENERIC)
1108  return emitOpError() << "generic proxy not a supported proxy kind";
1109  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1110  return emitOpError() << "async_shared fence requires space attribute";
1111  }
1112  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1113  return emitOpError() << "only async_shared fence can have space attribute";
1114  }
1115  return success();
1116 }
1117 
1118 LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1119  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1120  return emitOpError("uni-directional proxies only support generic for "
1121  "from_proxy attribute");
1122 
1123  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1124  return emitOpError("uni-directional proxies only support tensormap "
1125  "for to_proxy attribute");
1126 
1127  return success();
1128 }
1129 
1130 LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1131  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1132  return emitOpError("uni-directional proxies only support generic for "
1133  "from_proxy attribute");
1134 
1135  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1136  return emitOpError("uni-directional proxies only support tensormap "
1137  "for to_proxy attribute");
1138 
1139  return success();
1140 }
1141 
1142 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1143  if (getRegCount() % 8)
1144  return emitOpError("new register size must be multiple of 8");
1145  if (getRegCount() < 24 || getRegCount() > 256)
1146  return emitOpError("new register size must be in between 24 to 256");
1147  return success();
1148 }
1149 
1150 LogicalResult NVVM::BarrierOp::verify() {
1151  if (getNumberOfThreads() && !getBarrierId())
1152  return emitOpError(
1153  "barrier id is missing, it should be set between 0 to 15");
1154  return success();
1155 }
1156 
1157 LogicalResult NVVM::Tcgen05CpOp::verify() {
1158  auto mc = getMulticast();
1159 
1160  using SH = Tcgen05CpShape;
1161  using MC = Tcgen05CpMulticast;
1162  switch (getShape()) {
1163  case SH::SHAPE_128x256b:
1164  case SH::SHAPE_128x128b:
1165  case SH::SHAPE_4x256b:
1166  if (mc != MC::NONE)
1167  return emitError("Invalid multicast type for tcgen05.cp Op");
1168  break;
1169  case SH::SHAPE_64x128b:
1170  if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1171  return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1172  "warpx2_02_13 for tcgen05.cp Op");
1173  break;
1174  case SH::SHAPE_32x128b:
1175  if (mc != MC::WARPX4)
1176  return emitError(
1177  "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1178  break;
1179  }
1180  return success();
1181 }
1182 
1183 LogicalResult NVVM::MatchSyncOp::verify() {
1184  if (getKind() == NVVM::MatchSyncKind::all) {
1185  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1186  if (!type || type.getBody().size() != 2 ||
1187  !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1188  return emitOpError("match.sync 'all' returns a two element struct with "
1189  "first element as i32 and second element as i1");
1190  }
1191  } else {
1192  if (!getType().isInteger(32)) {
1193  return emitOpError("match.sync 'any' returns an i32");
1194  }
1195  }
1196  return success();
1197 }
1198 
1199 LogicalResult NVVM::VoteSyncOp::verify() {
1200  if (getKind() == NVVM::VoteSyncKind::ballot) {
1201  if (!getType().isInteger(32)) {
1202  return emitOpError("vote.sync 'ballot' returns an i32");
1203  }
1204  } else {
1205  if (!getType().isInteger(1)) {
1206  return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1207  }
1208  }
1209  return success();
1210 }
1211 
1212 LogicalResult NVVM::PrefetchOp::verify() {
1213  using MemSpace = NVVM::NVVMMemorySpace;
1214  using CacheLevel = NVVM::PrefetchCacheLevel;
1215 
1216  unsigned addressSpace =
1217  llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1218  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1219 
1220  if (getUniform()) {
1221  if (getCacheLevel() != CacheLevel::L1)
1222  return emitOpError("unsupported cache level, the only supported uniform "
1223  "cache level is L1");
1224 
1225  if (addressSpace != MemSpace::kGenericMemorySpace)
1226  return emitOpError(
1227  "prefetch to uniform cache requires a generic pointer");
1228  }
1229 
1230  if (evictPriority) {
1231  if (getCacheLevel() != CacheLevel::L2)
1232  return emitOpError(
1233  "cache eviction priority supported only for cache level L2");
1234 
1235  if (addressSpace != MemSpace::kGlobalMemorySpace)
1236  return emitOpError("cache eviction priority requires a global pointer");
1237 
1238  if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1239  *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1240  return emitOpError(
1241  "unsupported cache eviction priority, only evict_last and "
1242  "evict_normal are supported");
1243  }
1244 
1245  return success();
1246 }
1247 
1248 /// Packs the given `field` into the `result`.
1249 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
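/// For example, packing a 14-bit field at bit 16 computes
/// result | ((field & 0x3fff) << 16).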
1250 static llvm::Value *
1251 packValInto64Bits(llvm::IRBuilderBase &builder,
1252  llvm::Value *result, // the `result` (unset bits are zero)
1253  llvm::Value *field, // `field` to pack into `result`
1254  unsigned sizeInBits, // Size of `field` in bits
1255  unsigned start) { // Starting bit within `result`
1256  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1257 
1258  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1259  if (mask != 0xffffffffu)
1260  field = builder.CreateAnd(field, builder.getInt32(mask));
1261 
1262  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1263  field = builder.CreateShl(field, start);
1264 
1265  return builder.CreateOr(result, field);
1266 }
1267 
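// Packs the tcgen05.mma shared-memory descriptor fields into a single i64:
// start address (bits 0-13), leading-dim offset (16-29), stride-dim offset
// (32-45), a constant 1 (46-48), base offset (49-51), leading-dim mode (52),
// and swizzle mode (61-63).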
1268 void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1269  LLVM::ModuleTranslation &mt,
1270  llvm::IRBuilderBase &builder) {
1271  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1272  llvm::Value *smemDesc = builder.getInt64(0);
1273 
1274  smemDesc = packValInto64Bits(builder, smemDesc,
1275  mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1276  smemDesc = packValInto64Bits(
1277  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1278  smemDesc = packValInto64Bits(
1279  builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1280 
1281  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1282  smemDesc = packValInto64Bits(builder, smemDesc,
1283  mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1284  smemDesc = packValInto64Bits(
1285  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1286  smemDesc = packValInto64Bits(builder, smemDesc,
1287  mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1288 
1289  mt.mapValue(thisOp.getRes()) = smemDesc;
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // getIntrinsicID/getIntrinsicIDAndArgs methods
1294 //===----------------------------------------------------------------------===//
1295 
1296 #define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1297  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1298 
1299 #define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1300  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
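// For example, GET_CP_ASYNC_ID(ca, 4, true) selects
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s.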
1301 
1302 llvm::Intrinsic::ID
1303 CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1304  llvm::SmallVector<llvm::Value *> &args) {
1305  llvm::Intrinsic::ID id;
1306 
1307  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1308  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1309  switch (cpAsyncOp.getSize()) {
1310  case 4:
1311  id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1312  break;
1313  case 8:
1314  id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1315  break;
1316  case 16:
1317  id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1318  ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1319  : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1320  break;
1321  default:
1322  llvm_unreachable("Invalid copy size in CpAsyncOp.");
1323  }
1324 
1325  // Fill the Intrinsic Args
1326  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1327  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1328  if (hasCpSize)
1329  args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1330 
1331  return id;
1332 }
1333 
1334 mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1335  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1336  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1337  llvm::SmallVector<llvm::Value *> args;
1338  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1339 
1340  // Fill the Intrinsic Args
1341  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1342  args.push_back(mt.lookupValue(thisOp.getSize()));
1343 
1344  mlir::Value cacheHint = thisOp.getL2CacheHint();
1345  const bool hasCacheHint = static_cast<bool>(cacheHint);
1346  llvm::Value *i64Unused =
1347  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1348  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1349  args.push_back(builder.getInt1(hasCacheHint));
1350 
1351  return {id, std::move(args)};
1352 }
1353 
1354 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1355  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1356  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
1357  llvm::SmallVector<llvm::Value *> args;
1358  llvm::Intrinsic::ID id =
1359  llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1360 
1361  // Fill the Intrinsic Args
1362  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1363  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1364  args.push_back(mt.lookupValue(thisOp.getSize()));
1365 
1366  mlir::Value cacheHint = thisOp.getL2CacheHint();
1367  const bool hasCacheHint = static_cast<bool>(cacheHint);
1368  llvm::Value *i64Unused =
1369  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1370  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1371  args.push_back(builder.getInt1(hasCacheHint));
1372 
1373  // Choose the bytemask variant
1374  if (mlir::Value byteMask = thisOp.getByteMask()) {
1375  args.push_back(mt.lookupValue(byteMask));
1376  id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1377  }
1378 
1379  return {id, std::move(args)};
1380 }
1381 
1382 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
1383  bool isIm2Col) {
1384  switch (tensorDims) {
1385  case 1:
1386  return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
1387  case 2:
1388  return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
1389  case 3:
1390  return isIm2Col
1391  ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
1392  : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
1393  case 4:
1394  return isIm2Col
1395  ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
1396  : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
1397  case 5:
1398  return isIm2Col
1399  ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
1400  : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
1401  default:
1402  llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
1403  }
1404 }
1405 
1406 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
1407  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1408 
1409 #define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \
1410  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1411  : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1412 
1413 #define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1414  [&]() -> auto { \
1415  switch (dims) { \
1416  case 1: \
1417  return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1418  case 2: \
1419  return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1420  case 3: \
1421  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1422  case 4: \
1423  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1424  case 5: \
1425  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1426  default: \
1427  llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1428  } \
1429  }()
1430 
1431 llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
1432  int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
1433  using RedTy = NVVM::TMAReduxKind;
1434  switch (kind) {
1435  case RedTy::ADD:
1436  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1437  case RedTy::MIN:
1438  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1439  case RedTy::MAX:
1440  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1441  case RedTy::INC:
1442  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1443  case RedTy::DEC:
1444  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1445  case RedTy::AND:
1446  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1447  case RedTy::OR:
1448  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1449  case RedTy::XOR:
1450  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1451  }
1452  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
1453 }
1454 
1455 #define _none
1456 
1457 #define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1458  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
1459  : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
1460 
1461 #define GET_CVT_F2TF32_ID(rnd, relu, sf) \
1462  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1463  : CVT_F2TF32_ID_IMPL(rnd, relu, )
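// For example, with both hasRelu and hasSatFinite set, GET_CVT_F2TF32_ID(rn,
// _relu, _satfinite) selects llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite.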
1464 
1465 llvm::Intrinsic::ID
1466 ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1467  NVVM::SaturationMode sat, bool hasRelu) {
1468  using RndMode = NVVM::FPRoundingMode;
1469  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1470  switch (rnd) {
1471  case RndMode::RN:
1472  return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
1473  case RndMode::RZ:
1474  return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
1475  case RndMode::RNA:
1476  return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
1477  default:
1478  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
1479  }
1480 }
1481 
1482 #define GET_F32x2_TO_F6x2_ID(type, has_relu) \
1483  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1484  : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1485 
1486 llvm::Intrinsic::ID
1487 ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
1488  switch (type) {
1489  case NVVM::ConvertFP6Type::E2M3:
1490  return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
1491  case NVVM::ConvertFP6Type::E3M2:
1492  return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1493  }
1494  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
1495 }
1496 
1497 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
1498  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1499  : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1500 
1501 #define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
1502  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1503  : llvm::Intrinsic::nvvm_ff_to_##type##_rn
1504 
1505 llvm::Intrinsic::ID
1506 ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
1507  NVVM::FPRoundingMode rnd,
1508  NVVM::SaturationMode sat, bool hasRelu) {
1509  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1510  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1511  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1512 
1513  switch (type) {
1514  case NVVM::ConvertFP8Type::E4M3:
1515  return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1516  case NVVM::ConvertFP8Type::E5M2:
1517  return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1518  case NVVM::ConvertFP8Type::UE8M0:
1519  if (hasRoundingModeRZ)
1520  return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1521  else if (hasRoundingModeRP)
1522  return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1523  }
1524  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
1525 }
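// Worked example: E4M3 with hasRelu = true selects
// llvm::Intrinsic::nvvm_ff_to_e4m3x2_rn_relu, while UE8M0 with RZ rounding and
// SATFINITE saturation selects llvm::Intrinsic::nvvm_ff_to_ue8m0x2_rz_satfinite.
// A UE8M0 conversion with any other rounding mode reaches the llvm_unreachable.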
1526 
1527 #define GET_F16x2_TO_F8X2_ID(type, has_relu) \
1528  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
1529  : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
1530 
1531 llvm::Intrinsic::ID
1532 ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
1533  switch (type) {
1534  case NVVM::ConvertFP8Type::E4M3:
1535  return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1536  case NVVM::ConvertFP8Type::E5M2:
1537  return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1538  default:
1539  llvm_unreachable("Invalid ConvertFP8Type for ConvertF16x2ToF8x2Op");
1540  }
1541 }
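// Example: type = E4M3 with hasRelu = false selects
// llvm::Intrinsic::nvvm_f16x2_to_e4m3x2_rn; with hasRelu = true the _rn_relu
// variant is selected.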
1542 
1543 #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
1544  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
1545  : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
1546 
1547 llvm::Intrinsic::ID
1548 ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1549  NVVM::SaturationMode sat) {
1550  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1551  switch (rnd) {
1552  case NVVM::FPRoundingMode::RZ:
1553  return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
1554  case NVVM::FPRoundingMode::RP:
1555  return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
1556  default:
1557  llvm_unreachable("Invalid rounding mode for ConvertBF16x2ToF8x2Op");
1558  }
1559 }
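// Example: rnd = RZ with SATFINITE saturation selects
// llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite; without SATFINITE the
// plain _rz variant is used.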
1560 
1561 llvm::Intrinsic::ID
1562 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
1563  LLVM::ModuleTranslation &mt,
1564  llvm::SmallVector<llvm::Value *> &args) {
1565  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1566  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1567  .getAddressSpace();
1568  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
1569  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1570 
1571  llvm::Intrinsic::ID id;
1572  if (isShared) {
1573  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
1574  : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
1575  } else {
1576  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
1577  : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
1578  }
1579 
1580  // Fill the Intrinsic Args
1581  args.push_back(mt.lookupValue(curOp.getAddr()));
1582  args.push_back(mt.lookupValue(curOp.getNCols()));
1583 
1584  return id;
1585 }
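// Worked example: an alloc whose address is in shared memory with CTA group 2
// lowers to llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2, with the address
// and nCols passed as the intrinsic arguments.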
1586 
1587 llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
1588  Operation &op, LLVM::ModuleTranslation &mt,
1589  llvm::SmallVector<llvm::Value *> &args) {
1590  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
1591  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
1592  ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
1593  : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
1594 
1595  // Fill the Intrinsic Args
1596  args.push_back(mt.lookupValue(curOp.getTaddr()));
1597  args.push_back(mt.lookupValue(curOp.getNCols()));
1598 
1599  return id;
1600 }
1601 
1602 #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
1603  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
1604  : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
1605 
1606 #define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
1607  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
1608  : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
1609 
1610 llvm::Intrinsic::ID
1611 Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
1612  LLVM::ModuleTranslation &mt,
1613  llvm::SmallVector<llvm::Value *> &args) {
1614  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1615  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1616  .getAddressSpace();
1617  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
1618  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
1619  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1620 
1621  llvm::Intrinsic::ID id =
1622  is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
1623  : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
1624 
1625  // Fill the Intrinsic Args
1626  args.push_back(mt.lookupValue(curOp.getAddr()));
1627  if (hasMulticast)
1628  args.push_back(mt.lookupValue(curOp.getMulticastMask()));
1629 
1630  return id;
1631 }
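// Worked example: a commit on a shared-memory address with a multicast mask in
// CTA_1 mode expands GET_TCGEN05_COMMIT_ID(cg1, /*isShared=*/true,
// /*hasMulticast=*/true) to llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg1,
// and the mask is appended as the second intrinsic argument.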
1632 
1633 #define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
1634  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
1635 
1636 #define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
1637  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
1638  : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
1639 
1640 #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
1641  [&]() -> auto { \
1642  if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
1643  return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
1644  if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
1645  return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
1646  return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
1647  }()
1648 
1649 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
1650  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
1651  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1652  auto srcFmt = curOp.getSrcFormat();
1653  auto mc = curOp.getMulticast();
1654 
1655  switch (curOp.getShape()) {
1656  case Tcgen05CpShape::SHAPE_128x256b:
1657  return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
1658  case Tcgen05CpShape::SHAPE_128x128b:
1659  return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
1660  case Tcgen05CpShape::SHAPE_4x256b:
1661  return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
1662  case Tcgen05CpShape::SHAPE_32x128b:
1663  return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
1664  case Tcgen05CpShape::SHAPE_64x128b:
1665  return (mc == Tcgen05CpMulticast::WARPX2_01_23)
1666  ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
1667  : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
1668  }
1669  llvm_unreachable("Invalid shape in tcgen05 cp Op");
1670 }
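// Worked example: shape 128x256b with source format B6x16_P32 in 2-CTA mode
// resolves to llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2 via the
// macros above.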
1671 
1672 // Checks whether the given vector length is valid for the given shape; this
1673 // models the table in the tcgen05.{ld, st} Op description.
1674 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
1675  unsigned vecLen) {
1676  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1677  return vecLen >= 2;
1678  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1679  return vecLen >= 4;
1680  return true;
1681 }
1682 
1683 LogicalResult Tcgen05LdOp::verify() {
1684  LogicalResult result = success();
1685  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1686  result = emitError("shape 16x32bx2 requires offset argument");
1687 
1688  auto resTy = getRes().getType();
1689  unsigned resLen = isa<VectorType>(resTy)
1690  ? llvm::cast<VectorType>(resTy).getNumElements()
1691  : 1;
1692  if (!isValidVectorLength(getShape(), resLen))
1693  result = emitError(llvm::formatv("invalid result type length {0} for shape "
1694  "{1} in tcgen05.ld Op",
1695  resLen, stringifyEnum(getShape())));
1696 
1697  return result;
1698 }
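// Example: a tcgen05.ld with shape 16x128b must produce a vector of at least
// 2 elements, so a scalar (length-1) result is rejected by the check above.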
1699 
1700 LogicalResult Tcgen05StOp::verify() {
1701  LogicalResult result = success();
1702  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1703  result = emitError("shape 16x32bx2 requires offset argument");
1704 
1705  auto valTy = getVal().getType();
1706  unsigned valLen = isa<VectorType>(valTy)
1707  ? llvm::cast<VectorType>(valTy).getNumElements()
1708  : 1;
1709  if (!isValidVectorLength(getShape(), valLen))
1710  result = emitError(llvm::formatv("invalid input length {0} for shape "
1711  "{1} in tcgen05.st Op",
1712  valLen, stringifyEnum(getShape())));
1713 
1714  return result;
1715 }
1716 
1717 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
1718 /// have ConstantRangeAttr.
1719 static void nvvmInferResultRanges(Operation *op, Value result,
1720  ArrayRef<::mlir::ConstantIntRanges> argRanges,
1721  SetIntRangeFn setResultRanges) {
1722  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
1723  setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
1724  rangeAttr.getLower(), rangeAttr.getUpper()});
1725  }
1726 }
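// Example: a special-register op carrying a "range" attribute with lower = 0
// and upper = 32 reports the result range [0, 32) for both the signed and
// unsigned bounds.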
1727 
1728 static llvm::Value *getAsPackedI32(llvm::Value *arg,
1729  llvm::IRBuilderBase &builder) {
1730  return builder.CreateBitCast(arg,
1731  llvm::Type::getInt32Ty(builder.getContext()));
1732 }
1733 
1734 NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
1735  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1736  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
1737 
1739  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1740  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1741  args.push_back(mt.lookupValue(curOp.getC()));
1742 
1743  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1744  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1745  unsigned type = (isASigned << 1) | isBSigned;
1746  const llvm::Intrinsic::ID ids[] = {
1747  llvm::Intrinsic::nvvm_idp4a_u_u,
1748  llvm::Intrinsic::nvvm_idp4a_u_s,
1749  llvm::Intrinsic::nvvm_idp4a_s_u,
1750  llvm::Intrinsic::nvvm_idp4a_s_s,
1751  };
1752  return {ids[type], args};
1753 }
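// Worked example: aType = SIGNED and bType = UNSIGNED gives index
// (1 << 1) | 0 == 2, i.e. llvm::Intrinsic::nvvm_idp4a_s_u, called with the
// packed i32 forms of a and b followed by c.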
1754 
1755 NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
1756  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1757  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
1758 
1760  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1761  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1762  args.push_back(builder.getInt1(curOp.getBHi()));
1763  args.push_back(mt.lookupValue(curOp.getC()));
1764 
1765  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1766  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1767  unsigned type = (isASigned << 1) | isBSigned;
1768  const llvm::Intrinsic::ID ids[] = {
1769  llvm::Intrinsic::nvvm_idp2a_u_u,
1770  llvm::Intrinsic::nvvm_idp2a_u_s,
1771  llvm::Intrinsic::nvvm_idp2a_s_u,
1772  llvm::Intrinsic::nvvm_idp2a_s_s,
1773  };
1774  return {ids[type], args};
1775 }
1776 
1777 llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
1778  using MemSpace = NVVM::NVVMMemorySpace;
1779  using CacheLevel = NVVM::PrefetchCacheLevel;
1780 
1781  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
1782  std::optional<NVVM::CacheEvictionPriority> evictPriority =
1783  op.getEvictPriority();
1784  unsigned addressSpace =
1785  llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
1786  .getAddressSpace();
1787 
1788  if (op.getUniform() && cacheLevel == CacheLevel::L1)
1789  return llvm::Intrinsic::nvvm_prefetchu_L1;
1790 
1791  if (evictPriority && cacheLevel == CacheLevel::L2) {
1792  switch (*evictPriority) {
1793  case NVVM::CacheEvictionPriority::EvictLast:
1794  return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1795  case NVVM::CacheEvictionPriority::EvictNormal:
1796  return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
1797  default:
1798  llvm_unreachable("Invalid cache eviction priority");
1799  }
1800  }
1801 
1802  switch (addressSpace) {
1803  case MemSpace::kGenericMemorySpace:
1804  return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
1805  : llvm::Intrinsic::nvvm_prefetch_L2;
1806  case MemSpace::kGlobalMemorySpace:
1807  return cacheLevel == CacheLevel::L1
1808  ? llvm::Intrinsic::nvvm_prefetch_global_L1
1809  : llvm::Intrinsic::nvvm_prefetch_global_L2;
1810  case MemSpace::kLocalMemorySpace:
1811  return cacheLevel == CacheLevel::L1
1812  ? llvm::Intrinsic::nvvm_prefetch_local_L1
1813  : llvm::Intrinsic::nvvm_prefetch_local_L2;
1814  default:
1815  llvm_unreachable("Invalid pointer address space");
1816  }
1817 }
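// Worked example: an L2 prefetch with eviction priority EvictLast maps to
// llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last before the address-space
// switch is consulted, while a plain L1 prefetch of a global pointer maps to
// llvm::Intrinsic::nvvm_prefetch_global_L1.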
1818 
1819 //===----------------------------------------------------------------------===//
1820 // NVVMDialect initialization, type parsing, and registration.
1821 //===----------------------------------------------------------------------===//
1822 
1823 // TODO: This should be the llvm.nvvm dialect once this is supported.
1824 void NVVMDialect::initialize() {
1825  addOperations<
1826 #define GET_OP_LIST
1827 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1828  >();
1829  addAttributes<
1830 #define GET_ATTRDEF_LIST
1831 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1832  >();
1833 
1834  // Support unknown operations because not all NVVM operations are
1835  // registered.
1836  allowUnknownOperations();
1837  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1838  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1839 }
1840 
1841 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1842  NamedAttribute attr) {
1843  StringAttr attrName = attr.getName();
1844  // Kernel function attribute should be attached to functions.
1845  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1846  if (!isa<LLVM::LLVMFuncOp>(op)) {
1847  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1848  << "' attribute attached to unexpected op";
1849  }
1850  }
1851  // If maxntid / reqntid / cluster_dim is present, it must be an integer
1852  // array with at most 3 elements.
1853  if (attrName == NVVMDialect::getMaxntidAttrName() ||
1854  attrName == NVVMDialect::getReqntidAttrName() ||
1855  attrName == NVVMDialect::getClusterDimAttrName()) {
1856  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1857  if (!values || values.empty() || values.size() > 3)
1858  return op->emitError()
1859  << "'" << attrName
1860  << "' attribute must be an integer array with at most 3 elements";
1861  }
1862  // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
1863  // attribute
1864  if (attrName == NVVMDialect::getMinctasmAttrName() ||
1865  attrName == NVVMDialect::getMaxnregAttrName() ||
1866  attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
1867  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1868  return op->emitError()
1869  << "'" << attrName << "' attribute must be an integer constant";
1870  }
1871 
1872  return success();
1873 }
1874 
1875 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1876  unsigned regionIndex,
1877  unsigned argIndex,
1878  NamedAttribute argAttr) {
1879  auto funcOp = dyn_cast<FunctionOpInterface>(op);
1880  if (!funcOp)
1881  return success();
1882 
1883  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1884  StringAttr attrName = argAttr.getName();
1885  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1886  if (!isKernel) {
1887  return op->emitError()
1888  << "'" << attrName
1889  << "' attribute must be present only on kernel arguments";
1890  }
1891  if (!isa<UnitAttr>(argAttr.getValue()))
1892  return op->emitError() << "'" << attrName << "' must be a unit attribute";
1893  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1894  return op->emitError()
1895  << "'" << attrName
1896  << "' attribute requires the argument to also have attribute '"
1897  << LLVM::LLVMDialect::getByValAttrName() << "'";
1898  }
1899  }
1900 
1901  return success();
1902 }
1903 
1904 //===----------------------------------------------------------------------===//
1905 // NVVM target attribute.
1906 //===----------------------------------------------------------------------===//
1907 LogicalResult
1909  int optLevel, StringRef triple, StringRef chip,
1910  StringRef features, DictionaryAttr flags,
1911  ArrayAttr files, bool verifyTarget) {
1912  if (optLevel < 0 || optLevel > 3) {
1913  emitError() << "The optimization level must be a number between 0 and 3.";
1914  return failure();
1915  }
1916  if (triple.empty()) {
1917  emitError() << "The target triple cannot be empty.";
1918  return failure();
1919  }
1920  if (chip.empty()) {
1921  emitError() << "The target chip cannot be empty.";
1922  return failure();
1923  }
1924  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1925  return mlir::isa_and_nonnull<StringAttr>(attr);
1926  })) {
1927  emitError() << "All the elements in the `link` array must be strings.";
1928  return failure();
1929  }
1930  return success();
1931 }
1932 
1933 LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
1934  if (!getVerifyTarget())
1935  return success();
1936 
1937  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
1938  if (!gpuModuleOp) {
1939  return emitError(gpuModule->getLoc(),
1940  "NVVM target attribute must be attached to a GPU module");
1941  }
1942 
1943  const NVVMCheckSMVersion targetSMVersion =
1944  NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
1945  if (!targetSMVersion.isMinimumSMVersion()) {
1946  return emitError(gpuModule->getLoc(),
1947  "Minimum NVVM target SM version is sm_20");
1948  }
1949 
1950  gpuModuleOp->walk([&](Operation *op) {
1951  if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
1952  const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
1953  if (!requirement.isCompatibleWith(targetSMVersion)) {
1954  op->emitOpError() << "is not supported on " << getChip();
1955  return WalkResult::interrupt();
1956  }
1957  }
1958  return WalkResult::advance();
1959  });
1960 
1961  return success();
1962 }
1963 
1964 #define GET_OP_CLASSES
1965 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1966 
1967 #define GET_ATTRDEF_CLASSES
1968 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"