1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect
10 // in MLIR. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
23 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Diagnostics.h"
28 #include "mlir/IR/MLIRContext.h"
29 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/Types.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/AsmParser/Parser.h"
35 #include "llvm/IR/Attributes.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/IRBuilder.h"
38 #include "llvm/IR/IntrinsicsNVPTX.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Support/FormatVariadic.h"
42 #include "llvm/Support/SourceMgr.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <cassert>
45 #include <optional>
46 #include <string>
47 
48 using namespace mlir;
49 using namespace NVVM;
50 
51 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
52 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
53 
54 //===----------------------------------------------------------------------===//
55 // Verifier methods
56 //===----------------------------------------------------------------------===//
57 
58 // This verifier is shared among the following Ops:
59 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
60 // CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
61 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
62 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
63  bool isIm2Col,
64  size_t numIm2ColOffsets,
65  Location loc) {
66  if (tensorDims < 1 || tensorDims > 5)
67  return emitError(loc, "expects coordinates between 1 and 5 dimensions");
68 
69  // For Im2Col mode, there are two constraints:
70  if (isIm2Col) {
71  // 1. Tensor must always be at least 3-d.
72  if (tensorDims < 3)
73  return emitError(
74  loc,
75  "to use im2col mode, the tensor has to be at least 3-dimensional");
76  // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
77  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
78  return emitError(
79  loc, "im2col offsets must be 2 less than number of coordinates");
80  }
81  return success();
82 }
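// For illustration (derived from the checks above): a 5-d im2col copy may
// carry exactly 3 im2col offsets (5 - 2), a 3-d im2col copy exactly 1, and a
// 1-d or 2-d tensor cannot use im2col mode at all.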
83 
84 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
85  size_t numIm2ColOffsets = getIm2colOffsets().size();
86  bool isIm2Col = numIm2ColOffsets > 0;
87  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
88  numIm2ColOffsets, getLoc());
89 }
90 
91 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
92  if (getCoordinates().size() > 5)
93  return emitError("A maximum of 5 coordinate dimensions is supported.");
94  return success();
95 }
96 
97 LogicalResult CpAsyncOp::verify() {
98  if (getModifier() != LoadCacheModifierKind::CG &&
99  getModifier() != LoadCacheModifierKind::CA)
100  return emitError("Only CG and CA cache modifiers are supported.");
101  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
102  return emitError("expected byte size to be either 4, 8 or 16.");
103  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
104  return emitError("CG cache modifier is only supported for 16-byte copies.");
105  return success();
106 }
107 
108 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
109  size_t numIm2ColOffsets = getIm2colOffsets().size();
110  bool isIm2Col = numIm2ColOffsets > 0;
111  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
112  numIm2ColOffsets, getLoc());
113 }
114 
115 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
116  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
117  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
118  getLoc());
119 }
120 
121 LogicalResult ConvertFloatToTF32Op::verify() {
122  using RndMode = NVVM::FPRoundingMode;
123  switch (getRnd()) {
124  case RndMode::RNA:
125  if (getRelu())
126  return emitError("Relu not supported with rna rounding mode.");
127  break;
128  case RndMode::RN:
129  case RndMode::RZ:
130  break;
131  default:
132  return emitError(
133  "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
134  }
135  return success();
136 }
137 
138 LogicalResult ConvertF32x2ToF8x2Op::verify() {
139  using RndMode = NVVM::FPRoundingMode;
140  using SatMode = NVVM::SaturationMode;
141 
142  bool isRoundingModeRN = getRnd() == RndMode::RN;
143  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
144  bool isRoundingModeRP = getRnd() == RndMode::RP;
145  bool isSatFinite = getSat() == SatMode::SATFINITE;
146 
147  bool hasRelu = getRelu();
148 
149  switch (getType()) {
150  case ConvertFP8Type::E4M3:
151  case ConvertFP8Type::E5M2:
152  if (!isRoundingModeRN)
153  return emitOpError("Only RN rounding mode is supported for conversions "
154  "from f32x2 to .e4m3x2 or .e5m2x2 types");
155  if (!isSatFinite)
156  return emitOpError("Only SATFINITE saturation mode is supported for "
157  "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
158  break;
159  case ConvertFP8Type::UE8M0:
160  if (!(isRoundingModeRZ || isRoundingModeRP))
161  return emitOpError("Only RZ or RP rounding modes are supported for "
162  "conversions from f32x2 to .ue8m0x2 type");
163  if (hasRelu)
164  return emitOpError("relu not supported for conversions to .ue8m0x2 type");
165  break;
166  }
167  return success();
168 }
169 
170 LogicalResult ConvertF16x2ToF8x2Op::verify() {
171  if (getType() == ConvertFP8Type::UE8M0)
172  return emitOpError("Only .e4m3 or .e5m2 types are supported for "
173  "conversions from f16x2 to f8x2.");
174 
175  return success();
176 }
177 
178 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
179  using RndMode = NVVM::FPRoundingMode;
180 
181  if (getType() != ConvertFP8Type::UE8M0)
182  return emitOpError(
183  "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
184 
185  auto rnd = getRnd();
186  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
187  return emitOpError("Only RZ and RP rounding modes are supported for "
188  "conversions from bf16x2 to f8x2.");
189 
190  return success();
191 }
192 
193 LogicalResult BulkStoreOp::verify() {
194  if (getInitVal() != 0)
195  return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
196  return success();
197 }
198 
199 // Given the element type of an operand and whether or not it is an accumulator,
200 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
201 // operand's element type.
202 std::optional<mlir::NVVM::MMATypes>
203 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
204  auto half2Type =
205  VectorType::get(2, Float16Type::get(operandElType.getContext()));
206  if (operandElType.isF64())
207  return NVVM::MMATypes::f64;
208  if (operandElType.isF16() || operandElType == half2Type)
209  return NVVM::MMATypes::f16;
210  if (operandElType.isF32() && isAccumulator)
211  return NVVM::MMATypes::f32;
212  if (operandElType.isF32() && !isAccumulator)
213  return NVVM::MMATypes::tf32;
214  if (llvm::isa<IntegerType>(operandElType)) {
215  if (isAccumulator)
216  return NVVM::MMATypes::s32;
217  return std::nullopt;
218  }
219 
220  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
221  if (structType.getBody().empty())
222  return std::nullopt;
223  return inferOperandMMAType(structType.getBody()[0], isAccumulator);
224  }
225 
226  return std::nullopt;
227 }
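// For illustration (derived from the rules above): f16 and vector<2xf16> both
// map to MMATypes::f16; an f32 accumulator maps to f32 while an f32
// multiplicand maps to tf32; integer accumulators map to s32; and a struct
// operand is classified by the type of its first member.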
228 
229 static bool isInt4PtxType(MMATypes type) {
230  return (type == MMATypes::u4 || type == MMATypes::s4);
231 }
232 
233 static bool isInt8PtxType(MMATypes type) {
234  return (type == MMATypes::u8 || type == MMATypes::s8);
235 }
236 
237 static bool isIntegerPtxType(MMATypes type) {
238  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
239  type == MMATypes::s32;
240 }
241 
242 MMATypes MmaOp::accumPtxType() {
243  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
244  getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
245  assert(val.has_value() && "accumulator PTX type should always be inferrable");
246  return val.value();
247 }
248 
249 MMATypes MmaOp::resultPtxType() {
250  std::optional<mlir::NVVM::MMATypes> val =
251  inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
252  assert(val.has_value() && "result PTX type should always be inferrable");
253  return val.value();
254 }
255 
256 void MmaOp::print(OpAsmPrinter &p) {
257  SmallVector<Type, 4> regTypes;
258  struct OperandFragment {
259  StringRef operandName;
260  StringRef ptxTypeAttr;
261  SmallVector<Value, 4> regs;
262  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
263  : operandName(name), ptxTypeAttr(ptxTypeName) {}
264  };
265 
266  std::array<OperandFragment, 3> frags{
267  OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
268  OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
269  OperandFragment("C", "")};
270  SmallVector<StringRef, 4> ignoreAttrNames{
271  mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
272 
273  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
274  auto &frag = frags[fragIdx];
275  auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
276  for (auto operandIdx = varOperandSpec.first;
277  operandIdx < varOperandSpec.first + varOperandSpec.second;
278  operandIdx++) {
279  frag.regs.push_back(this->getOperand(operandIdx));
280  if (operandIdx == 0) {
281  regTypes.push_back(this->getOperand(operandIdx).getType());
282  }
283  }
284  std::optional<MMATypes> inferredType =
285  inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
286  if (inferredType)
287  ignoreAttrNames.push_back(frag.ptxTypeAttr);
288  }
289 
290  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
291  p << " " << frag.operandName;
292  p << "[";
293  p.printOperands(frag.regs);
294  p << "] ";
295  };
296 
297  for (const auto &frag : frags) {
298  printMmaOperand(frag);
299  }
300 
301  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
302 
303  // Print the types of the operands and result.
304  p << " : "
305  << "(";
306  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
307  frags[1].regs[0].getType(),
308  frags[2].regs[0].getType()},
309  p);
310  p << ")";
311  p.printArrowTypeList(TypeRange{this->getRes().getType()});
312 }
313 
314 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
315  ValueRange operandA, ValueRange operandB, ValueRange operandC,
316  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
317  std::optional<MMAIntOverflow> intOverflow,
318  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
319  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
320 
321  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
322  MLIRContext *ctx = builder.getContext();
323  result.addAttribute(
324  "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
325 
326  result.addOperands(operandA);
327  result.addOperands(operandB);
328  result.addOperands(operandC);
329 
330  if (multiplicandPtxTypes) {
331  result.addAttribute("multiplicandAPtxType",
332  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
333  result.addAttribute("multiplicandBPtxType",
334  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
335  } else {
336  if (auto res = inferOperandMMAType(operandA[0].getType(), false))
337  result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
338  if (auto res = inferOperandMMAType(operandB[0].getType(), false))
339  result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
340  }
341 
342  if (multiplicandLayouts) {
343  result.addAttribute("layoutA",
344  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
345  result.addAttribute("layoutB",
346  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
347  } else {
348  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
349  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
350  }
351 
352  if (intOverflow.has_value())
353  result.addAttribute("intOverflowBehavior",
354  MMAIntOverflowAttr::get(ctx, *intOverflow));
355  if (b1Op.has_value())
356  result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
357 
358  result.addTypes(resultType);
359  result.addAttribute(
360  MmaOp::getOperandSegmentSizeAttr(),
361  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
362  static_cast<int32_t>(operandB.size()),
363  static_cast<int32_t>(operandC.size())}));
364 }
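// For illustration, a hypothetical call site (names such as `builder`, `loc`,
// `resStructTy` and `aRegs` are placeholders, not defined in this file). With
// the trailing std::nullopt arguments, the multiplicand PTX types are inferred
// from the first A/B operands and the layouts default to row-major A and
// column-major B:
//   builder.create<NVVM::MmaOp>(loc, resStructTy, aRegs, bRegs, cRegs,
//                               /*shape=*/ArrayRef<int64_t>{16, 8, 16},
//                               std::nullopt, std::nullopt, std::nullopt,
//                               std::nullopt);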
365 
366 // <operation> :=
367 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
368 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
369 // `->` type($res)
370 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
371  struct OperandFragment {
372  std::optional<MMATypes> elemtype;
373  SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
374  SmallVector<Type> regTypes;
375  };
376 
377  Builder &builder = parser.getBuilder();
378  std::array<OperandFragment, 4> frags;
379 
380  NamedAttrList namedAttributes;
381 
382  // A helper to parse the operand segments.
383  auto parseMmaOperand = [&](StringRef operandName,
384  OperandFragment &frag) -> LogicalResult {
385  if (parser.parseKeyword(operandName).failed())
386  return failure();
387  if (parser
388  .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
389  .failed())
390  return failure();
391  return success();
392  };
393 
394  // Parse the operand segments.
395  if (parseMmaOperand("A", frags[0]).failed())
396  return failure();
397  if (parseMmaOperand("B", frags[1]).failed())
398  return failure();
399  if (parseMmaOperand("C", frags[2]).failed())
400  return failure();
401 
402  if (parser.parseOptionalAttrDict(namedAttributes).failed())
403  return failure();
404 
405  // Parse the type specification and resolve operands.
406  SmallVector<Type, 3> operandTypes;
407  if (failed(parser.parseColon()))
408  return failure();
409  if (failed(parser.parseLParen()))
410  return failure();
411  if (failed(parser.parseTypeList(operandTypes)))
412  return failure();
413  if (failed(parser.parseRParen())) return failure();
414  if (operandTypes.size() != 3)
415  return parser.emitError(
416  parser.getNameLoc(),
417  "expected one type for each operand segment but got " +
418  Twine(operandTypes.size()) + " types");
419  for (const auto &iter : llvm::enumerate(operandTypes)) {
420  auto &frag = frags[iter.index()];
421  frag.regTypes.resize(frag.regs.size(), iter.value());
422  if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
423  parser.getNameLoc(), result.operands)))
424  return failure();
425  frag.elemtype = inferOperandMMAType(frag.regTypes[0],
426  /*isAccumulator*/ iter.index() < 2);
427  }
428 
429  Type resultType;
430  if (parser.parseArrow() || parser.parseType(resultType))
431  return failure();
432  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
433 
434  std::array<StringRef, 2> names{"multiplicandAPtxType",
435  "multiplicandBPtxType"};
436  for (unsigned idx = 0; idx < names.size(); idx++) {
437  const auto &frag = frags[idx];
438  std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
439  if (!frag.elemtype.has_value() && !attr.has_value()) {
440  return parser.emitError(
441  parser.getNameLoc(),
442  "attribute " + names[idx] +
443  " is not provided explicitly and cannot be inferred");
444  }
445  if (!attr.has_value())
446  result.addAttribute(
447  names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
448  }
449 
450  result.addTypes(resultType);
451  if (!namedAttributes.empty())
452  result.addAttributes(namedAttributes);
453  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
454  builder.getDenseI32ArrayAttr({
455  static_cast<int32_t>(frags[0].regs.size()),
456  static_cast<int32_t>(frags[1].regs.size()),
457  static_cast<int32_t>(frags[2].regs.size()),
458  }));
459  return success();
460 }
461 
462 LogicalResult MmaOp::verify() {
463  MLIRContext *context = getContext();
464  auto f16Ty = Float16Type::get(context);
465  auto i32Ty = IntegerType::get(context, 32);
466  auto f16x2Ty = VectorType::get(2, f16Ty);
467  auto f32Ty = Float32Type::get(context);
468  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
469  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
470 
471  auto s32x4StructTy =
472  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
473  auto f32x8StructTy =
474  LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
475  auto f16x2x2StructTy =
476  LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
477  auto f32x4StructTy =
478  LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
479  auto s32x2StructTy =
480  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
481 
482  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
483  getShapeAttr().getK()};
484 
485  // These variables define the set of allowed data types for matrices A, B, C,
486  // and result.
487  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
488  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
489  AllowedShapes allowedShapes;
490  AllowedTypes expectedA;
491  AllowedTypes expectedB;
492  AllowedTypes expectedC;
493  SmallVector<Type> expectedResult;
494 
495  // When M = 16, we just need to calculate the number of 8xk tiles, where
496  // k is a factor that depends on the data type.
497  if (mmaShape[0] == 16) {
498  int64_t kFactor;
499  Type multiplicandFragType;
500  switch (*getMultiplicandAPtxType()) {
501  case MMATypes::tf32:
502  kFactor = 4;
503  multiplicandFragType = i32Ty;
504  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
505  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
506  break;
507  case MMATypes::bf16:
508  kFactor = 8;
509  multiplicandFragType = i32Ty;
510  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
511  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
512  break;
513  case MMATypes::f16:
514  kFactor = 8;
515  multiplicandFragType = f16x2Ty;
516  expectedResult.push_back(f16x2x2StructTy);
517  expectedResult.push_back(f32x4StructTy);
518  break;
519  case MMATypes::s4:
520  case MMATypes::u4:
521  kFactor = 32;
522  break;
523  case MMATypes::b1:
524  kFactor = 128;
525  break;
526  case MMATypes::s8:
527  case MMATypes::u8:
528  kFactor = 16;
529  break;
530  default:
531  return emitError("invalid shape or multiplicand type: " +
532  stringifyEnum(getMultiplicandAPtxType().value()));
533  }
534 
535  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
536  expectedResult.push_back(s32x4StructTy);
537  expectedC.emplace_back(4, i32Ty);
538  multiplicandFragType = i32Ty;
539  } else {
540  expectedC.emplace_back(2, f16x2Ty);
541  expectedC.emplace_back(4, f32Ty);
542  }
543 
544  int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
545  int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
546  expectedA.emplace_back(unitA, multiplicandFragType);
547  expectedB.emplace_back(unitB, multiplicandFragType);
548  allowedShapes.push_back({16, 8, kFactor});
549  allowedShapes.push_back({16, 8, kFactor * 2});
550  }
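 // For illustration (derived from the computation above): an f16 op with
 // shape m16n8k16 has kFactor = 8, so A needs (16/8) * (16/8) = 4 and B needs
 // (8/8) * (16/8) = 2 vector<2xf16> registers, and the allowed shapes for f16
 // are 16x8x8 and 16x8x16.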
551 
552  // In the M=8 case, there is only 1 possible case per data type.
553  if (mmaShape[0] == 8) {
554  if (*getMultiplicandAPtxType() == MMATypes::f16) {
555  expectedA.emplace_back(2, f16x2Ty);
556  expectedB.emplace_back(2, f16x2Ty);
557  expectedResult.push_back(f16x2x4StructTy);
558  expectedResult.push_back(f32x8StructTy);
559  expectedC.emplace_back(4, f16x2Ty);
560  expectedC.emplace_back(8, f32Ty);
561  allowedShapes.push_back({8, 8, 4});
562  }
563  if (*getMultiplicandAPtxType() == MMATypes::f64) {
564  Type f64Ty = Float64Type::get(context);
565  expectedA.emplace_back(1, f64Ty);
566  expectedB.emplace_back(1, f64Ty);
567  expectedC.emplace_back(2, f64Ty);
568  expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
569  context, SmallVector<Type>(2, f64Ty)));
570  allowedShapes.push_back({8, 8, 4});
571  }
572  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
573  expectedA.push_back({i32Ty});
574  expectedB.push_back({i32Ty});
575  expectedC.push_back({i32Ty, i32Ty});
576  expectedResult.push_back(s32x2StructTy);
577  if (isInt4PtxType(getMultiplicandAPtxType().value()))
578  allowedShapes.push_back({8, 8, 32});
579  if (isInt8PtxType(getMultiplicandAPtxType().value()))
580  allowedShapes.push_back({8, 8, 16});
581  if (getMultiplicandAPtxType().value() == MMATypes::b1)
582  allowedShapes.push_back({8, 8, 128});
583  }
584  }
585 
586  std::string errorMessage;
587  llvm::raw_string_ostream errorStream(errorMessage);
588 
589  // Check that we matched an existing shape/dtype combination.
590  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
591  !llvm::is_contained(allowedShapes, mmaShape)) {
592  errorStream << "unimplemented variant for MMA shape <";
593  llvm::interleaveComma(mmaShape, errorStream);
594  errorStream << ">";
595  return emitOpError(errorMessage);
596  }
597 
598  // Verify the operand types for segments of A, B, and C operands.
599  std::array<StringRef, 3> operandNames{"A", "B", "C"};
600  for (const auto &iter : llvm::enumerate(
601  SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
602  auto spec = this->getODSOperandIndexAndLength(iter.index());
603  SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
604  operand_type_begin() + spec.first +
605  spec.second);
606  bool match = llvm::is_contained(iter.value(), operandTySeg);
607 
608  if (!match) {
609  errorStream << "Could not match types for the "
610  << operandNames[iter.index()]
611  << " operands; expected one of ";
612  for (const auto &x : iter.value()) {
613  errorStream << x.size() << "x" << x[0] << " ";
614  }
615  errorStream << "but got ";
616  llvm::interleaveComma(operandTySeg, errorStream);
617  return emitOpError(errorMessage);
618  }
619  }
620 
621  // Check the result type
622  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
623  return expectedResultType == getResult().getType();
624  })) {
625  errorStream
626  << "Could not match allowed types for the result; expected one of ";
627  llvm::interleaveComma(expectedResult, errorStream);
628  errorStream << " but got " << getResult().getType();
629  return emitOpError(errorMessage);
630  }
631 
632  // Ensure that binary MMA variants have a b1 MMA operation defined.
633  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
634  return emitOpError("op requires " + getB1OpAttrName().strref() +
635  " attribute");
636  }
637 
638  // Ensure int4/int8 MMA variants specify the accum overflow behavior
639  // attribute.
640  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
641  isInt8PtxType(*getMultiplicandAPtxType())) {
642  if (!getIntOverflowBehavior())
643  return emitOpError("op requires " +
644  getIntOverflowBehaviorAttrName().strref() +
645  " attribute");
646  }
647 
648  return success();
649 }
650 
651 LogicalResult ShflOp::verify() {
652  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
653  return success();
654  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
655  auto elementType = (type && type.getBody().size() == 2)
656  ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
657  : nullptr;
658  if (!elementType || elementType.getWidth() != 1)
659  return emitError("expected return type to be a two-element struct with "
660  "i1 as the second element");
661  return success();
662 }
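// For illustration: with the `return_value_and_is_valid` attribute set, a
// valid result type is a two-element struct such as !llvm.struct<(i32, i1)>,
// where the trailing i1 is the predicate reporting whether the source lane
// was valid.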
663 
664 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
665  NVVM::MMAFrag frag, int nRow,
666  int nCol,
667  MLIRContext *context) {
668  unsigned numberElements = 0;
669  Type elementType;
670  OpBuilder builder(context);
671  Type f16x2 = VectorType::get(2, builder.getF16Type());
672  if (type == NVVM::MMATypes::f16) {
673  elementType = f16x2;
674  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
675  numberElements = 8;
676  else
677  numberElements = 4;
678  } else if (type == NVVM::MMATypes::f32) {
679  elementType = builder.getF32Type();
680  numberElements = 8;
681  } else if (type == NVVM::MMATypes::tf32) {
682  elementType = builder.getI32Type();
683  numberElements = 4;
684  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
685  elementType = builder.getI32Type();
686  int parallelSize = 0;
687  if (frag == NVVM::MMAFrag::a)
688  parallelSize = nRow;
689  if (frag == NVVM::MMAFrag::b)
690  parallelSize = nCol;
691 
692  // m == 16 && n == 16 && k == 16
693  if (parallelSize == 16)
694  numberElements = 2;
695  // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
696  else if (parallelSize == 8)
697  numberElements = 1;
698  else if (parallelSize == 32)
699  numberElements = 4;
700  } else if (type == NVVM::MMATypes::s32) {
701  elementType = builder.getI32Type();
702  numberElements = 8;
703  }
704  assert(numberElements != 0 && elementType != nullptr);
705  return std::make_pair(elementType, numberElements);
706 }
707 
708 static std::pair<mlir::Type, unsigned>
709 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
710  int k, MLIRContext *context) {
711  int nRow, nCol;
712  if (frag == NVVM::MMAFrag::a) {
713  nRow = m;
714  nCol = k;
715  } else if (frag == NVVM::MMAFrag::b) {
716  nRow = k;
717  nCol = n;
718  } else {
719  nRow = m;
720  nCol = n;
721  }
722  assert(nRow && nCol);
723  return inferMMAType(type, frag, nRow, nCol, context);
724 }
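// For illustration (derived from the two helpers above): a 16x16x16 f16 WMMA
// uses 8 vector<2xf16> registers for the a and b fragments and 4 for the c
// fragment, while an f32 accumulator fragment uses 8 f32 values.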
725 
726 LogicalResult NVVM::WMMALoadOp::verify() {
727  unsigned addressSpace =
728  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
729  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
730  addressSpace != NVVM::kSharedMemorySpace)
731  return emitOpError("expected source pointer in memory "
732  "space 0, 1, 3");
733 
734  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
735  getEltype(), getFrag()) == 0)
736  return emitOpError() << "invalid attribute combination";
737  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
738  getEltype(), getFrag(), getM(), getN(), getK(), getContext());
739  Type dstType = LLVM::LLVMStructType::getLiteral(
740  getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
741  if (getType() != dstType)
742  return emitOpError("expected destination type is a structure of ")
743  << typeInfo.second << " elements of type " << typeInfo.first;
744  return success();
745 }
746 
747 LogicalResult NVVM::WMMAStoreOp::verify() {
748  unsigned addressSpace =
749  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
750  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
751  addressSpace != NVVM::kSharedMemorySpace)
752  return emitOpError("expected operands to be a source pointer in memory "
753  "space 0, 1, 3");
754 
755  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
756  getEltype()) == 0)
757  return emitOpError() << "invalid attribute combination";
758  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
759  getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
760  if (getArgs().size() != typeInfo.second)
761  return emitOpError() << "expected " << typeInfo.second << " data operands";
762  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
763  return operands.getType() != typeInfo.first;
764  }))
765  return emitOpError() << "expected data operands of type " << typeInfo.first;
766  return success();
767 }
768 
769 LogicalResult NVVM::WMMAMmaOp::verify() {
770  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
771  getLayoutB(), getEltypeA(),
772  getEltypeB()) == 0)
773  return emitOpError() << "invalid attribute combination";
774  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
775  getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
776  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
777  getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
778  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
779  getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
780  SmallVector<Type, 32> arguments;
781  arguments.append(typeInfoA.second, typeInfoA.first);
782  arguments.append(typeInfoB.second, typeInfoB.first);
783  arguments.append(typeInfoC.second, typeInfoC.first);
784  unsigned numArgs = arguments.size();
785  if (getArgs().size() != numArgs)
786  return emitOpError() << "expected " << numArgs << " arguments";
787  for (unsigned i = 0; i < numArgs; i++) {
788  if (getArgs()[i].getType() != arguments[i])
789  return emitOpError() << "expected argument " << i << " to be of type "
790  << arguments[i];
791  }
792  Type dstType = LLVM::LLVMStructType::getLiteral(
793  getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
794  if (getType() != dstType)
795  return emitOpError("expected destination type is a structure of ")
796  << typeInfoC.second << " elements of type " << typeInfoC.first;
797  return success();
798 }
799 
800 LogicalResult NVVM::LdMatrixOp::verify() {
801  unsigned addressSpace =
802  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
803  if (addressSpace != NVVM::kSharedMemorySpace)
804  return emitOpError("expected source pointer in memory space 3");
805 
806  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
807  return emitOpError("expected num attribute to be 1, 2 or 4");
808 
809  Type i32 = IntegerType::get(getContext(), 32);
810  if (getNum() == 1 && getType() != i32)
811  return emitOpError("expected destination type is i32");
812  if (getNum() == 2 || getNum() == 4) {
813  Type dstType = LLVM::LLVMStructType::getLiteral(
814  getContext(), SmallVector<Type>(getNum(), i32));
815  if (getType() != dstType)
816  return emitOpError("expected destination type is a structure of ")
817  << getNum() << " elements of type i32";
818  }
819  return success();
820 }
821 
822 LogicalResult NVVM::StMatrixOp::verify() {
823  unsigned addressSpace =
824  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
825  if (addressSpace != NVVM::kSharedMemorySpace)
826  return emitOpError("expected source pointer in memory space 3");
827 
828  int numMatrix = getSources().size();
829  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
830  return emitOpError("expected num attribute to be 1, 2 or 4");
831 
832  return success();
833 }
834 
835 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
836  if (typeA == NVVM::WGMMATypes::tf32)
837  return 8;
838  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
839  return 16;
840  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
841  return 32;
842  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
843  return 32;
844  if (typeA == NVVM::WGMMATypes::b1)
845  return 256;
846  return failure();
847 }
848 
849 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
850  NVVM::WGMMATypes typeA,
851  NVVM::WGMMATypes typeB) {
852  switch (typeA) {
853  case NVVM::WGMMATypes::f16:
854  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
855  typeB == NVVM::WGMMATypes::f16)
856  return success();
857  break;
858  case NVVM::WGMMATypes::tf32:
859  if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
860  return success();
861  break;
862  case NVVM::WGMMATypes::u8:
863  case NVVM::WGMMATypes::s8:
864  if (typeD == NVVM::WGMMATypes::s32 &&
865  (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
866  return success();
867  break;
868  case NVVM::WGMMATypes::b1:
869  if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
870  return success();
871  break;
872  case NVVM::WGMMATypes::bf16:
873  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
874  typeB == NVVM::WGMMATypes::bf16)
875  return success();
876  break;
877  case NVVM::WGMMATypes::e4m3:
878  case NVVM::WGMMATypes::e5m2:
879  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
880  (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
881  return success();
882  break;
883  case WGMMATypes::f32:
884  case WGMMATypes::s32:
885  llvm_unreachable("unsupported input types");
886  break;
887  }
888  return failure();
889 }
890 
891 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
892  SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
893  72, 80, 88, 96, 104, 112, 120, 128,
894  136, 144, 152, 160, 168, 176, 184, 192,
895  200, 208, 216, 224, 232, 240, 248, 256};
896  SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
897  80, 96, 112, 128, 144, 160,
898  176, 192, 208, 224, 240, 256};
899  switch (typeA) {
900  case WGMMATypes::f16:
901  case WGMMATypes::tf32:
902  case WGMMATypes::bf16:
903  case WGMMATypes::e4m3:
904  case WGMMATypes::e5m2:
905  if (llvm::is_contained(allowedN, sizeN))
906  return success();
907  break;
908  case WGMMATypes::u8:
909  case WGMMATypes::s8:
910  case WGMMATypes::b1:
911  if (llvm::is_contained(allowedNshort, sizeN))
912  return success();
913  break;
914  case WGMMATypes::f32:
915  case WGMMATypes::s32:
916  llvm_unreachable("unsupported input types");
917  break;
918  }
919  return failure();
920 }
921 
922 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
923  Value outValue = getResults();
924  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
925  if (!stype)
926  return emitOpError() << "expected results to be struct";
927  int outputSize = stype.getBody().size();
928  WGMMATypes typeD = getTypeD();
929  WGMMATypes typeA = getTypeA();
930  WGMMATypes typeB = getTypeB();
931 
932  for (Type t : stype.getBody()) {
933  if (t != stype.getBody().front())
934  return emitOpError()
935  << "all elements in struct must be same type but there is " << t;
936  }
937 
938  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
939  typeD != WGMMATypes::s32) {
940  return emitOpError() << "does not support the given output type "
941  << NVVM::stringifyWGMMATypes(typeD);
942  }
943  if (typeD == WGMMATypes::s32 &&
944  (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
945  return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
946  }
947 
948  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
949  return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
950  << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
951  << NVVM::stringifyWGMMATypes(typeB)
952  << ", it is not supported.";
953  }
954 
955  // Check M
956  if (getShape().getM() != 64)
957  return emitOpError() << "shape 'm' must be 64";
958 
959  // Check K
960  FailureOr<int> allowedK = getAllowedSizeK(typeA);
961  if (failed(allowedK) || allowedK.value() != getShape().getK())
962  return emitOpError() << "shape 'k' must be " << allowedK.value()
963  << " for input type "
964  << NVVM::stringifyWGMMATypes(typeA);
965 
966  // Check N
967  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
968  return emitOpError() << "has input type "
969  << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
970  << getShape().getN() << ", it is not supported.";
971  }
972 
973  // Check transpose (only available for f16/bf16)
974  // Matrix A should be stored in row-major and B in column-major layout.
975  // Only f16/bf16 matrices can be stored in either column-major or row-major
976  // by setting the transpose values (imm-trans-a, imm-trans-b) in PTX code.
977  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
978  (getLayoutA() == mlir::NVVM::MMALayout::col ||
979  getLayoutB() == mlir::NVVM::MMALayout::row)) {
980  return emitOpError()
981  << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
982  << " and layout_b = " << stringifyMMALayout(getLayoutB())
983  << " for input types " << stringifyWGMMATypes(typeA) << " and "
984  << stringifyWGMMATypes(typeB)
985  << " requires transpose. However, this is only supported for: "
986  << stringifyMMATypes(MMATypes::f16) << " and "
987  << stringifyMMATypes(MMATypes::bf16);
988  }
989 
990  // Check result registers
991  int expectedOutput = 0;
992  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
993  expectedOutput = getShape().getN() / 2;
994  if (typeD == WGMMATypes::f16)
995  expectedOutput = getShape().getN() / 4;
996  if (outputSize != expectedOutput) {
997  return emitOpError() << "results " << expectedOutput
998  << ", however output struct has " << outputSize
999  << " elements";
1000  }
1001  // Check satfinite (only available for s32 accumulator)
1002  if (typeD != WGMMATypes::s32 &&
1003  getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1004  NVVM::MMAIntOverflow::satfinite) {
1005  return emitOpError()
1006  << " `satfinite` can be only used with s32 accumulator, however "
1007  "the current accumulator is "
1008  << NVVM::stringifyWGMMATypes(typeD);
1009  }
1010 
1011  return success();
1012 }
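// For illustration (derived from the register check above): an f32 or s32
// accumulator with n = 64 must be returned as a struct of 64 / 2 = 32
// identical elements, while an f16 accumulator with the same n needs only 16.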
1013 
1014 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1015 
1016  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1017  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1018 
1019  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1020 
1021  int expectedOutputRegisters = 0;
1022  if (getTypeD() == WGMMATypes::f16)
1023  expectedOutputRegisters = getShape().getN() / 4;
1024  else
1025  expectedOutputRegisters = getShape().getN() / 2;
1026 
1027  std::string ptx;
1028  llvm::raw_string_ostream ss(ptx);
1029 
1030  ss << "{\n"
1031  ".reg .pred p;\n"
1032  "setp.ne.b32 p, $"
1033  << ((expectedOutputRegisters * 2) + 2)
1034  << ", 0;\n"
1035  "wgmma.mma_async.sync.aligned.m"
1036  << m << "n" << n << "k" << k << "." << outputTypeName << "."
1037  << stringifyWGMMATypes(getTypeA()) << "."
1038  << stringifyWGMMATypes(getTypeB());
1039  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1040  NVVM::MMAIntOverflow::satfinite)
1041  ss << ".satfinite";
1042  ss << " {";
1043  int regCnt = 0;
1044  for (; regCnt < expectedOutputRegisters; ++regCnt) {
1045  ss << "$" << regCnt;
1046  if (regCnt != expectedOutputRegisters - 1)
1047  ss << ", ";
1048  }
1049 
1050  ss << "},";
1051  // Need to map read/write registers correctly.
1052  regCnt = (regCnt * 2);
1053  ss << " $" << (regCnt) << ","
1054  << " $" << (regCnt + 1) << ","
1055  << " p";
1056  if (getTypeD() != WGMMATypes::s32) {
1057  ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1058  }
1059  // Don't add transpose parameters unless needed.
1060  if (isF16) {
1061  ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1062  }
1063  ss << ";\n"
1064  << "}\n";
1065  return ptx;
1066 }
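// For illustration, a trace of the printer above for a hypothetical
// m64n8k16 f32 += f16 * f16 op without satfinite (the $N register numbers
// follow the operand order produced by getAsmValues below):
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $10, 0;
//   wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 {$0, $1, $2, $3}, $8, $9, p, $11, $12, $13, $14;
//   }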
1067 
1068 void NVVM::WgmmaMmaAsyncOp::getAsmValues(
1069  RewriterBase &rewriter,
1070  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1071  &asmValues) {
1072  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1073  if (getResults())
1074  asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1075  if (getInouts())
1076  asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1077  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1078  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1079  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1080  mlir::NVVM::PTXRegisterMod::Read});
1081  if (getTypeD() != WGMMATypes::s32) {
1082  asmValues.push_back(
1083  {makeConstantI32(rewriter,
1084  getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1085  mlir::NVVM::PTXRegisterMod::Read});
1086  asmValues.push_back(
1087  {makeConstantI32(rewriter,
1088  getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1089  mlir::NVVM::PTXRegisterMod::Read});
1090  }
1091  if (isF16) {
1092  asmValues.push_back(
1093  {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1094  mlir::NVVM::PTXRegisterMod::Read});
1095  asmValues.push_back(
1096  {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1097  mlir::NVVM::PTXRegisterMod::Read});
1098  }
1099 }
1100 LogicalResult NVVM::FenceProxyOp::verify() {
1101  if (getKind() == NVVM::ProxyKind::TENSORMAP)
1102  return emitOpError() << "tensormap proxy is not a supported proxy kind";
1103  if (getKind() == NVVM::ProxyKind::GENERIC)
1104  return emitOpError() << "generic proxy not a supported proxy kind";
1105  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1106  return emitOpError() << "async_shared fence requires space attribute";
1107  }
1108  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1109  return emitOpError() << "only async_shared fence can have space attribute";
1110  }
1111  return success();
1112 }
1113 
1114 LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1115  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1116  return emitOpError("uni-directional proxies only support generic for "
1117  "from_proxy attribute");
1118 
1119  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1120  return emitOpError("uni-directional proxies only support tensormap "
1121  "for to_proxy attribute");
1122 
1123  return success();
1124 }
1125 
1126 LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1127  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1128  return emitOpError("uni-directional proxies only support generic for "
1129  "from_proxy attribute");
1130 
1131  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1132  return emitOpError("uni-directional proxies only support tensormap "
1133  "for to_proxy attribute");
1134 
1135  return success();
1136 }
1137 
1138 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1139  if (getRegCount() % 8)
1140  return emitOpError("new register size must be multiple of 8");
1141  if (getRegCount() < 24 || getRegCount() > 256)
1142  return emitOpError("new register size must be in between 24 to 256");
1143  return success();
1144 }
1145 
1146 LogicalResult NVVM::BarrierOp::verify() {
1147  if (getNumberOfThreads() && !getBarrierId())
1148  return emitOpError(
1149  "barrier id is missing, it should be set between 0 to 15");
1150  return success();
1151 }
1152 
1153 LogicalResult NVVM::Tcgen05CpOp::verify() {
1154  auto mc = getMulticast();
1155 
1156  using SH = Tcgen05CpShape;
1157  using MC = Tcgen05CpMulticast;
1158  switch (getShape()) {
1159  case SH::SHAPE_128x256b:
1160  case SH::SHAPE_128x128b:
1161  case SH::SHAPE_4x256b:
1162  if (mc != MC::NONE)
1163  return emitError("Invalid multicast type for tcgen05.cp Op");
1164  break;
1165  case SH::SHAPE_64x128b:
1166  if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1167  return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1168  "warpx2_02_13 for tcgen05.cp Op");
1169  break;
1170  case SH::SHAPE_32x128b:
1171  if (mc != MC::WARPX4)
1172  return emitError(
1173  "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1174  break;
1175  }
1176  return success();
1177 }
1178 
1179 LogicalResult NVVM::MatchSyncOp::verify() {
1180  if (getKind() == NVVM::MatchSyncKind::all) {
1181  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1182  if (!type || type.getBody().size() != 2 ||
1183  !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1184  return emitOpError("match.sync 'all' returns a two element struct with "
1185  "first element as i32 and second element as i1");
1186  }
1187  } else {
1188  if (!getType().isInteger(32)) {
1189  return emitOpError("match.sync 'any' returns an i32");
1190  }
1191  }
1192  return success();
1193 }
1194 
1195 LogicalResult NVVM::VoteSyncOp::verify() {
1196  if (getKind() == NVVM::VoteSyncKind::ballot) {
1197  if (!getType().isInteger(32)) {
1198  return emitOpError("vote.sync 'ballot' returns an i32");
1199  }
1200  } else {
1201  if (!getType().isInteger(1)) {
1202  return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1203  }
1204  }
1205  return success();
1206 }
1207 
1208 LogicalResult NVVM::PrefetchOp::verify() {
1209  using MemSpace = NVVM::NVVMMemorySpace;
1210  using CacheLevel = NVVM::PrefetchCacheLevel;
1211 
1212  unsigned addressSpace =
1213  llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1214  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1215 
1216  if (getUniform()) {
1217  if (getCacheLevel() != CacheLevel::L1)
1218  return emitOpError("unsupported cache level, the only supported uniform "
1219  "cache level is L1");
1220 
1221  if (addressSpace != MemSpace::kGenericMemorySpace)
1222  return emitOpError(
1223  "prefetch to uniform cache requires a generic pointer");
1224  }
1225 
1226  if (evictPriority) {
1227  if (getCacheLevel() != CacheLevel::L2)
1228  return emitOpError(
1229  "cache eviction priority supported only for cache level L2");
1230 
1231  if (addressSpace != MemSpace::kGlobalMemorySpace)
1232  return emitOpError("cache eviction priority requires a global pointer");
1233 
1234  if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1235  *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1236  return emitOpError(
1237  "unsupported cache eviction priority, only evict_last and "
1238  "evict_normal are supported");
1239  }
1240 
1241  return success();
1242 }
1243 
1244 /// Packs the given `field` into the `result`.
1245 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
1246 static llvm::Value *
1247 packValInto64Bits(llvm::IRBuilderBase &builder,
1248  llvm::Value *result, // the `result` (unset bits are zero)
1249  llvm::Value *field, // `field` to pack into `result`
1250  unsigned sizeInBits, // Size of `field` in bits
1251  unsigned start) { // Starting bit within `result`
1252  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1253 
1254  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1255  if (mask != 0xffffffffu)
1256  field = builder.CreateAnd(field, builder.getInt32(mask));
1257 
1258  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1259  field = builder.CreateShl(field, start);
1260 
1261  return builder.CreateOr(result, field);
1262 }
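// For illustration (derived from the helper above): packing a 3-bit field
// holding the value 0b101 at start = 49 masks the field to its low 3 bits,
// shifts it left to form (0b101 << 49), and ORs it into `result`, leaving all
// other bits unchanged.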
1263 
1264 void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1265  LLVM::ModuleTranslation &mt,
1266  llvm::IRBuilderBase &builder) {
1267  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1268  llvm::Value *smemDesc = builder.getInt64(0);
1269 
1270  smemDesc = packValInto64Bits(builder, smemDesc,
1271  mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1272  smemDesc = packValInto64Bits(
1273  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1274  smemDesc = packValInto64Bits(
1275  builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1276 
1277  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1278  smemDesc = packValInto64Bits(builder, smemDesc,
1279  mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1280  smemDesc = packValInto64Bits(
1281  builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1282  smemDesc = packValInto64Bits(builder, smemDesc,
1283  mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1284 
1285  mt.mapValue(thisOp.getRes()) = smemDesc;
1286 }
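// For illustration, the descriptor layout implied by the calls above (bit
// positions and widths as passed to packValInto64Bits): bits [0..13] start
// address, [16..29] leading-dim offset, [32..45] stride-dim offset, [46..48]
// a constant 1, [49..51] base offset, bit 52 leading-dim mode, and [61..63]
// swizzle mode.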
1287 
1288 //===----------------------------------------------------------------------===//
1289 // getIntrinsicID/getIntrinsicIDAndArgs methods
1290 //===----------------------------------------------------------------------===//
1291 
1292 #define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1293  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1294 
1295 #define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1296  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
1297 
1298 llvm::Intrinsic::ID
1299 CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1300  llvm::SmallVector<llvm::Value *> &args) {
1301  llvm::Intrinsic::ID id;
1302 
1303  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1304  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1305  switch (cpAsyncOp.getSize()) {
1306  case 4:
1307  id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1308  break;
1309  case 8:
1310  id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1311  break;
1312  case 16:
1313  id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1314  ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1315  : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1316  break;
1317  default:
1318  llvm_unreachable("Invalid copy size in CpAsyncOp.");
1319  }
1320 
1321  // Fill the Intrinsic Args
1322  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1323  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1324  if (hasCpSize)
1325  args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1326 
1327  return id;
1328 }
1329 
1330 mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1331  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1332  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1333  llvm::SmallVector<llvm::Value *> args;
1334  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1335 
1336  // Fill the Intrinsic Args
1337  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1338  args.push_back(mt.lookupValue(thisOp.getSize()));
1339 
1340  mlir::Value cacheHint = thisOp.getL2CacheHint();
1341  const bool hasCacheHint = static_cast<bool>(cacheHint);
1342  llvm::Value *i64Unused =
1343  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1344  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1345  args.push_back(builder.getInt1(hasCacheHint));
1346 
1347  return {id, std::move(args)};
1348 }
1349 
1350 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1351  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1352  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
1353  llvm::SmallVector<llvm::Value *> args;
1354  llvm::Intrinsic::ID id =
1355  llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1356 
1357  // Fill the Intrinsic Args
1358  args.push_back(mt.lookupValue(thisOp.getDstMem()));
1359  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1360  args.push_back(mt.lookupValue(thisOp.getSize()));
1361 
1362  mlir::Value cacheHint = thisOp.getL2CacheHint();
1363  const bool hasCacheHint = static_cast<bool>(cacheHint);
1364  llvm::Value *i64Unused =
1365  llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1366  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1367  args.push_back(builder.getInt1(hasCacheHint));
1368 
1369  // Choose the bytemask variant
1370  if (mlir::Value byteMask = thisOp.getByteMask()) {
1371  args.push_back(mt.lookupValue(byteMask));
1372  id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1373  }
1374 
1375  return {id, std::move(args)};
1376 }
1377 
1378 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
1379  bool isIm2Col) {
1380  switch (tensorDims) {
1381  case 1:
1382  return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
1383  case 2:
1384  return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
1385  case 3:
1386  return isIm2Col
1387  ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
1388  : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
1389  case 4:
1390  return isIm2Col
1391  ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
1392  : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
1393  case 5:
1394  return isIm2Col
1395  ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
1396  : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
1397  default:
1398  llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
1399  }
1400 }
1401 
1402 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
1403  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1404 
1405 #define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \
1406  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1407  : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1408 
1409 #define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1410  [&]() -> auto { \
1411  switch (dims) { \
1412  case 1: \
1413  return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1414  case 2: \
1415  return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1416  case 3: \
1417  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1418  case 4: \
1419  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1420  case 5: \
1421  return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1422  default: \
1423  llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1424  } \
1425  }()
1426 
1427 llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
1428  int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
1429  using RedTy = NVVM::TMAReduxKind;
1430  switch (kind) {
1431  case RedTy::ADD:
1432  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1433  case RedTy::MIN:
1434  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1435  case RedTy::MAX:
1436  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1437  case RedTy::INC:
1438  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1439  case RedTy::DEC:
1440  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1441  case RedTy::AND:
1442  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1443  case RedTy::OR:
1444  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1445  case RedTy::XOR:
1446  return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1447  }
1448  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
1449 }
1450 
1451 #define _none
1452 
1453 #define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1454  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
1455  : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
1456 
1457 #define GET_CVT_F2TF32_ID(rnd, relu, sf) \
1458  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1459  : CVT_F2TF32_ID_IMPL(rnd, relu, )
1460 
1461 llvm::Intrinsic::ID
1462 ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1463  NVVM::SaturationMode sat, bool hasRelu) {
1464  using RndMode = NVVM::FPRoundingMode;
1465  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1466  switch (rnd) {
1467  case RndMode::RN:
1468  return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
1469  case RndMode::RZ:
1470  return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
1471  case RndMode::RNA:
1472  return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
1473  default:
1474  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
1475  }
1476 }
1477 
1478 #define GET_F32x2_TO_F6x2_ID(type, has_relu) \
1479  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1480  : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1481 
1482 llvm::Intrinsic::ID
1483 ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
1484  switch (type) {
1485  case NVVM::ConvertFP6Type::E2M3:
1486  return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
1487  case NVVM::ConvertFP6Type::E3M2:
1488  return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1489  }
1490  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
1491 }
1492 
1493 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
1494  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1495  : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1496 
1497 #define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
1498  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1499  : llvm::Intrinsic::nvvm_ff_to_##type##_rn
1500 
1501 llvm::Intrinsic::ID
1502 ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
1503  NVVM::FPRoundingMode rnd,
1504  NVVM::SaturationMode sat, bool hasRelu) {
1505  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1506  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1507  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1508 
1509  switch (type) {
1510  case NVVM::ConvertFP8Type::E4M3:
1511  return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1512  case NVVM::ConvertFP8Type::E5M2:
1513  return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1514  case NVVM::ConvertFP8Type::UE8M0:
1515  if (hasRoundingModeRZ)
1516  return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1517  else if (hasRoundingModeRP)
1518  return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1519  }
1520  llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
1521 }
1522 
1523 #define GET_F16x2_TO_F8X2_ID(type, has_relu) \
1524  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
1525  : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
1526 
1527 llvm::Intrinsic::ID
1528 ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
1529  switch (type) {
1530  case NVVM::ConvertFP8Type::E4M3:
1531  return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1532  case NVVM::ConvertFP8Type::E5M2:
1533  return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1534  default:
1535  llvm_unreachable("Invalid ConvertFP8Type for ConvertF16x2ToF8x2Op");
1536  }
1537 }
1538 
1539 #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
1540  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
1541  : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
1542 
1543 llvm::Intrinsic::ID
1544 ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1545  NVVM::SaturationMode sat) {
1546  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1547  switch (rnd) {
1548  case NVVM::FPRoundingMode::RZ:
1549  return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
1550  case NVVM::FPRoundingMode::RP:
1551  return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
1552  default:
1553  llvm_unreachable("Invalid rounding mode for ConvertBF16x2ToF8x2Op");
1554  }
1555 }
1556 
1557 llvm::Intrinsic::ID
1558 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
1559  LLVM::ModuleTranslation &mt,
1560  llvm::SmallVector<llvm::Value *> &args) {
1561  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1562  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1563  .getAddressSpace();
1564  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
1565  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1566 
1567  llvm::Intrinsic::ID id;
1568  if (isShared) {
1569  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
1570  : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
1571  } else {
1572  id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
1573  : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
1574  }
1575 
1576  // Fill the Intrinsic Args
1577  args.push_back(mt.lookupValue(curOp.getAddr()));
1578  args.push_back(mt.lookupValue(curOp.getNCols()));
1579 
1580  return id;
1581 }
1582 
1583 llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
1584  Operation &op, LLVM::ModuleTranslation &mt,
1585  llvm::SmallVector<llvm::Value *> &args) {
1586  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
1587  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
1588  ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
1589  : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
1590 
1591  // Fill the Intrinsic Args
1592  args.push_back(mt.lookupValue(curOp.getTaddr()));
1593  args.push_back(mt.lookupValue(curOp.getNCols()));
1594 
1595  return id;
1596 }
1597 
1598 #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
1599  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
1600  : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
1601 
1602 #define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
1603  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
1604  : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
1605 
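// Example expansion (for illustration only): with cta_group = cg1,
// GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast) resolves to one of
//   nvvm_tcgen05_commit_cg1, nvvm_tcgen05_commit_shared_cg1,
//   nvvm_tcgen05_commit_mc_cg1, nvvm_tcgen05_commit_mc_shared_cg1,
// depending on whether the address is in shared memory and whether a
// multicast mask is present.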
1606 llvm::Intrinsic::ID
1607 Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
1608  LLVM::ModuleTranslation &mt,
1609  llvm::SmallVector<llvm::Value *> &args) {
1610  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1611  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
1612  .getAddressSpace();
1613  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
1614  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
1615  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1616 
1617  llvm::Intrinsic::ID id =
1618  is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
1619  : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
1620 
1621  // Fill the Intrinsic Args
1622  args.push_back(mt.lookupValue(curOp.getAddr()));
1623  if (hasMulticast)
1624  args.push_back(mt.lookupValue(curOp.getMulticastMask()));
1625 
1626  return id;
1627 }
1628 
1629 #define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
1630  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
1631 
1632 #define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
1633  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
1634  : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
1635 
1636 #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
1637  [&]() -> auto { \
1638  if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
1639  return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
1640  if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
1641  return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
1642  return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
1643  }()
1644 
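// Note: GET_TCGEN05_CP_ID wraps the source-format dispatch in an
// immediately-invoked lambda so it can appear directly in return expressions
// while still branching on the runtime src_fmt value. Example expansion (for
// illustration only): GET_TCGEN05_CP_ID(_128x256b, srcFmt, true) with
// srcFmt == B6x16_P32 yields nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2.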
1645 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
1646  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
1647  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1648  auto srcFmt = curOp.getSrcFormat();
1649  auto mc = curOp.getMulticast();
1650 
1651  switch (curOp.getShape()) {
1652  case Tcgen05CpShape::SHAPE_128x256b:
1653  return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
1654  case Tcgen05CpShape::SHAPE_128x128b:
1655  return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
1656  case Tcgen05CpShape::SHAPE_4x256b:
1657  return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
1658  case Tcgen05CpShape::SHAPE_32x128b:
1659  return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
1660  case Tcgen05CpShape::SHAPE_64x128b:
1661  return (mc == Tcgen05CpMulticast::WARPX2_01_23)
1662  ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
1663  : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
1664  }
1665  llvm_unreachable("Invalid shape in tcgen05 cp Op");
1666 }
1667 
1668 // Returns whether the given vector length is valid for the given shape; this
1669 // models the table in the tcgen05.{ld, st} Op description.
1670 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
1671  unsigned vecLen) {
1672  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1673  return vecLen >= 2;
1674  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1675  return vecLen >= 4;
1676  return true;
1677 }
1678 
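// Worked examples of the checks above (for illustration only):
//   isValidVectorLength(SHAPE_16X128B, 1)  -> false (requires vecLen >= 2)
//   isValidVectorLength(SHAPE_16X256B, 2)  -> false (requires vecLen >= 4)
//   isValidVectorLength(SHAPE_16X32BX2, 1) -> true  (no minimum enforced here)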
1679 LogicalResult Tcgen05LdOp::verify() {
1680  LogicalResult result = success();
1681  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1682  result = emitError("shape 16x32bx2 requires offset argument");
1683 
1684  auto resTy = getRes().getType();
1685  unsigned resLen = isa<VectorType>(resTy)
1686  ? llvm::cast<VectorType>(resTy).getNumElements()
1687  : 1;
1688  if (!isValidVectorLength(getShape(), resLen))
1689  result = emitError(llvm::formatv("invalid result type length {0} for shape "
1690  "{1} in tcgen05.ld Op",
1691  resLen, stringifyEnum(getShape())));
1692 
1693  return result;
1694 }
1695 
1696 LogicalResult Tcgen05StOp::verify() {
1697  LogicalResult result = success();
1698  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1699  result = emitError("shape 16x32bx2 requires offset argument");
1700 
1701  auto valTy = getVal().getType();
1702  unsigned valLen = isa<VectorType>(valTy)
1703  ? llvm::cast<VectorType>(valTy).getNumElements()
1704  : 1;
1705  if (!isValidVectorLength(getShape(), valLen))
1706  result = emitError(llvm::formatv("invalid input length {0} for shape "
1707  "{1} in tcgen05.st Op",
1708  valLen, stringifyEnum(getShape())));
1709 
1710  return result;
1711 }
1712 
1713 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
1714 /// have ConstantRangeAttr.
1715 static void nvvmInferResultRanges(Operation *op, Value result,
1716  ArrayRef<::mlir::ConstantIntRanges> argRanges,
1717  SetIntRangeFn setResultRanges) {
1718  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
1719  setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
1720  rangeAttr.getLower(), rangeAttr.getUpper()});
1721  }
1722 }
1723 
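// For illustration: NVVM register-read ops (e.g. nvvm.read.ptx.sreg.tid.x) may
// carry an LLVM::ConstantRangeAttr named "range"; when present, the hook above
// reports the attribute's lower/upper bounds as both the unsigned and signed
// range of the result for integer range analysis.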
1724 static llvm::Value *getAsPackedI32(llvm::Value *arg,
1725  llvm::IRBuilderBase &builder) {
1726  return builder.CreateBitCast(arg,
1727  llvm::Type::getInt32Ty(builder.getContext()));
1728 }
1729 
1730 NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
1731  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1732  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
1733 
1734  llvm::SmallVector<llvm::Value *> args;
1735  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1736  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1737  args.push_back(mt.lookupValue(curOp.getC()));
1738 
1739  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1740  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1741  unsigned type = (isASigned << 1) | isBSigned;
1742  const llvm::Intrinsic::ID ids[] = {
1743  llvm::Intrinsic::nvvm_idp4a_u_u,
1744  llvm::Intrinsic::nvvm_idp4a_u_s,
1745  llvm::Intrinsic::nvvm_idp4a_s_u,
1746  llvm::Intrinsic::nvvm_idp4a_s_s,
1747  };
1748  return {ids[type], args};
1749 }
1750 
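// The signedness of the two operands forms a 2-bit index into the ids table
// above: (isASigned << 1) | isBSigned, i.e.
//   0b00 -> *_u_u, 0b01 -> *_u_s, 0b10 -> *_s_u, 0b11 -> *_s_s.
// DotAccumulate2WayOp below reuses the same scheme.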
1751 NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
1752  Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1753  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
1754 
1755  llvm::SmallVector<llvm::Value *> args;
1756  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
1757  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
1758  args.push_back(builder.getInt1(curOp.getBHi()));
1759  args.push_back(mt.lookupValue(curOp.getC()));
1760 
1761  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1762  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1763  unsigned type = (isASigned << 1) | isBSigned;
1764  const llvm::Intrinsic::ID ids[] = {
1765  llvm::Intrinsic::nvvm_idp2a_u_u,
1766  llvm::Intrinsic::nvvm_idp2a_u_s,
1767  llvm::Intrinsic::nvvm_idp2a_s_u,
1768  llvm::Intrinsic::nvvm_idp2a_s_s,
1769  };
1770  return {ids[type], args};
1771 }
1772 
1773 llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
1774  using MemSpace = NVVM::NVVMMemorySpace;
1775  using CacheLevel = NVVM::PrefetchCacheLevel;
1776 
1777  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
1778  std::optional<NVVM::CacheEvictionPriority> evictPriority =
1779  op.getEvictPriority();
1780  unsigned addressSpace =
1781  llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
1782  .getAddressSpace();
1783 
1784  if (op.getUniform() && cacheLevel == CacheLevel::L1)
1785  return llvm::Intrinsic::nvvm_prefetchu_L1;
1786 
1787  if (evictPriority && cacheLevel == CacheLevel::L2) {
1788  switch (*evictPriority) {
1789  case NVVM::CacheEvictionPriority::EvictLast:
1790  return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1791  case NVVM::CacheEvictionPriority::EvictNormal:
1792  return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
1793  default:
1794  llvm_unreachable("Invalid cache eviction priority");
1795  }
1796  }
1797 
1798  switch (addressSpace) {
1799  case MemSpace::kGenericMemorySpace:
1800  return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
1801  : llvm::Intrinsic::nvvm_prefetch_L2;
1802  case MemSpace::kGlobalMemorySpace:
1803  return cacheLevel == CacheLevel::L1
1804  ? llvm::Intrinsic::nvvm_prefetch_global_L1
1805  : llvm::Intrinsic::nvvm_prefetch_global_L2;
1806  case MemSpace::kLocalMemorySpace:
1807  return cacheLevel == CacheLevel::L1
1808  ? llvm::Intrinsic::nvvm_prefetch_local_L1
1809  : llvm::Intrinsic::nvvm_prefetch_local_L2;
1810  default:
1811  llvm_unreachable("Invalid pointer address space");
1812  }
1813 }
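// Summary of the dispatch above (for illustration only): L1 prefetches with the
// `uniform` flag use nvvm_prefetchu_L1; L2 prefetches with an eviction priority
// use the nvvm_prefetch_global_L2_evict_* intrinsics; otherwise the intrinsic is
// chosen from the pointer's address space (generic/global/local) and the
// requested cache level (L1/L2).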
1814 
1815 //===----------------------------------------------------------------------===//
1816 // NVVMDialect initialization, type parsing, and registration.
1817 //===----------------------------------------------------------------------===//
1818 
1819 // TODO: This should be the llvm.nvvm dialect once this is supported.
1820 void NVVMDialect::initialize() {
1821  addOperations<
1822 #define GET_OP_LIST
1823 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1824  >();
1825  addAttributes<
1826 #define GET_ATTRDEF_LIST
1827 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1828  >();
1829 
1830  // Support unknown operations because not all NVVM operations are
1831  // registered.
1832  allowUnknownOperations();
1833  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1834  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1835 }
1836 
1837 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1838  NamedAttribute attr) {
1839  StringAttr attrName = attr.getName();
1840  // Kernel function attribute should be attached to functions.
1841  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1842  if (!isa<LLVM::LLVMFuncOp>(op)) {
1843  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1844  << "' attribute attached to unexpected op";
1845  }
1846  }
1847  // If maxntid / reqntid / cluster_dim is present, it must be an integer
1848  // array with at most 3 elements.
1849  if (attrName == NVVMDialect::getMaxntidAttrName() ||
1850  attrName == NVVMDialect::getReqntidAttrName() ||
1851  attrName == NVVMDialect::getClusterDimAttrName()) {
1852  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1853  if (!values || values.empty() || values.size() > 3)
1854  return op->emitError()
1855  << "'" << attrName
1856  << "' attribute must be integer array with maximum 3 index";
1857  }
1858  // If minctasm / maxnreg / cluster_max_blocks is present, it must be an
1859  // integer attribute.
1860  if (attrName == NVVMDialect::getMinctasmAttrName() ||
1861  attrName == NVVMDialect::getMaxnregAttrName() ||
1862  attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
1863  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1864  return op->emitError()
1865  << "'" << attrName << "' attribute must be integer constant";
1866  }
1867 
1868  return success();
1869 }
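// For illustration (exact IR syntax abbreviated): a kernel carrying these
// attributes would satisfy the checks above, e.g.
//   llvm.func @kern() attributes {nvvm.kernel,
//                                 nvvm.maxntid = array<i32: 128, 1, 1>,
//                                 nvvm.minctasm = 2 : i32} { ... }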
1870 
1871 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1872  unsigned regionIndex,
1873  unsigned argIndex,
1874  NamedAttribute argAttr) {
1875  auto funcOp = dyn_cast<FunctionOpInterface>(op);
1876  if (!funcOp)
1877  return success();
1878 
1879  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1880  StringAttr attrName = argAttr.getName();
1881  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1882  if (!isKernel) {
1883  return op->emitError()
1884  << "'" << attrName
1885  << "' attribute must be present only on kernel arguments";
1886  }
1887  if (!isa<UnitAttr>(argAttr.getValue()))
1888  return op->emitError() << "'" << attrName << "' must be a unit attribute";
1889  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1890  return op->emitError()
1891  << "'" << attrName
1892  << "' attribute requires the argument to also have attribute '"
1893  << LLVM::LLVMDialect::getByValAttrName() << "'";
1894  }
1895  }
1896 
1897  return success();
1898 }
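// For illustration (exact IR syntax abbreviated): nvvm.grid_constant is only
// accepted on byval arguments of kernel functions, e.g.
//   llvm.func @kern(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
//       attributes {nvvm.kernel} { ... }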
1899 
1900 //===----------------------------------------------------------------------===//
1901 // NVVM target attribute.
1902 //===----------------------------------------------------------------------===//
1903 LogicalResult
1904 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1905  int optLevel, StringRef triple, StringRef chip,
1906  StringRef features, DictionaryAttr flags,
1907  ArrayAttr files, bool verifyTarget) {
1908  if (optLevel < 0 || optLevel > 3) {
1909  emitError() << "The optimization level must be a number between 0 and 3.";
1910  return failure();
1911  }
1912  if (triple.empty()) {
1913  emitError() << "The target triple cannot be empty.";
1914  return failure();
1915  }
1916  if (chip.empty()) {
1917  emitError() << "The target chip cannot be empty.";
1918  return failure();
1919  }
1920  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1921  return mlir::isa_and_nonnull<StringAttr>(attr);
1922  })) {
1923  emitError() << "All the elements in the `link` array must be strings.";
1924  return failure();
1925  }
1926  return success();
1927 }
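// For illustration (flags and link files omitted): a target attribute that
// passes the checks above looks like
//   #nvvm.target<O = 2, triple = "nvptx64-nvidia-cuda", chip = "sm_90">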
1928 
1929 LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
1930  if (!getVerifyTarget())
1931  return success();
1932 
1933  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
1934  if (!gpuModuleOp) {
1935  return emitError(gpuModule->getLoc(),
1936  "NVVM target attribute must be attached to a GPU module");
1937  }
1938 
1939  const NVVMCheckSMVersion targetSMVersion =
1940  NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
1941  if (!targetSMVersion.isMinimumSMVersion()) {
1942  return emitError(gpuModule->getLoc(),
1943  "Minimum NVVM target SM version is sm_20");
1944  }
1945 
1946  gpuModuleOp->walk([&](Operation *op) {
1947  if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
1948  const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
1949  if (!requirement.isCompatibleWith(targetSMVersion)) {
1950  op->emitOpError() << "is not supported on " << getChip();
1951  return WalkResult::interrupt();
1952  }
1953  }
1954  return WalkResult::advance();
1955  });
1956 
1957  return success();
1958 }
1959 
1960 #define GET_OP_CLASSES
1961 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1962 
1963 #define GET_ATTRDEF_CLASSES
1964 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"