//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}
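
// Illustrative only (not taken from a test file): with the grammar above, a
// use of this op in textual IR would look roughly like
//   %ballot = nvvm.vote.ballot.sync %mask, %pred : i32
// where %mask is an i32 and %pred an i1.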

// This verifier is shared across:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  if (numIm2ColOffsets) {
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    if (tensorDims != (numIm2ColOffsets + 2))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}
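
// For example (derived from the checks above): a 5-D im2col copy must provide
// exactly 3 im2col offsets, and im2col mode is rejected for 1-D and 2-D
// tensors.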

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16 bytes copy.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccum=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
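
// A hypothetical usage sketch (operand names and shape are illustrative, not
// taken from a test): a caller building an f16 16x8x16 MMA and relying on the
// inferred PTX types and the default row/col layouts could write
//   MmaOp::build(builder, state, resStructTy, {a0, a1, a2, a3}, {b0, b1},
//                {c0, c1}, /*shape=*/{16, 8, 16}, /*b1Op=*/std::nullopt,
//                /*intOverflow=*/std::nullopt,
//                /*multiplicandPtxTypes=*/std::nullopt,
//                /*multiplicandLayouts=*/std::nullopt);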

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//   `->` type($res)
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype =
        inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
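
// Purely illustrative (register names and attribute literals are invented,
// not taken from a test case): an op accepted by this parser would look
// roughly like
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//            {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//             shape = #nvvm.shape<m = 16, n = 8, k = 8>}
//            : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//              -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// with the A/B PTX types inferred from the first register of each segment.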

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
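
// For example, following the branches above: (f16, MMAFrag::a) yields 8
// elements of vector<2xf16>, while (f32, MMAFrag::c) yields 8 elements of f32.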

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
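
// For example, per the tables above, n = 40 is a valid tile size for f16/bf16
// inputs (long list) but not for s8/u8/b1 inputs (short list).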

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored in row-major and B in column-major layout.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
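
// For intuition only (illustrative, not captured from real output): for an
// f32 accumulator with f16 inputs and shape m64n64k16, the code above builds
// a string of roughly this form:
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $66, 0;
//   wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16
//     {$0, $1, ..., $31}, $64, $65, p, $67, $68, $69, $70;
//   }
// where the 32 output registers come from n/2 and the trailing operands are
// the two descriptors, the scale-d predicate, scale-a/scale-b, and the two
// transpose immediates.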

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  return success();
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid or reqntid is present, it must be an integer array with at
  // most three elements.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm or maxnreg is present, it must be an integer attribute.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"