MLIR  16.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Operation.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace NVVM;
34 
35 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
36 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Printing/parsing for NVVM ops
40 //===----------------------------------------------------------------------===//
41 
43  p << " " << op->getOperands();
44  if (op->getNumResults() > 0)
45  p << " : " << op->getResultTypes();
46 }
47 
48 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
49 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
50  MLIRContext *context = parser.getContext();
51  auto int32Ty = IntegerType::get(context, 32);
52  auto int1Ty = IntegerType::get(context, 1);
53 
55  Type type;
56  return failure(parser.parseOperandList(ops) ||
57  parser.parseOptionalAttrDict(result.attributes) ||
58  parser.parseColonType(type) ||
59  parser.addTypeToList(type, result.types) ||
60  parser.resolveOperands(ops, {int32Ty, int1Ty},
61  parser.getNameLoc(), result.operands));
62 }
63 
65 
67  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
68  return emitError("expected byte size to be either 4, 8 or 16.");
69  if (getBypassL1() && getSize() != 16)
70  return emitError("bypass l1 is only support for 16 bytes copy.");
71  return success();
72 }
73 
74 // Given the element type of an operand and whether or not it is an accumulator,
75 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
76 // operand's element type.
77 Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
78  bool isAccumulator) {
79  auto half2Type =
80  LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
81  if (operandElType.isF64())
82  return NVVM::MMATypes::f64;
83  if (operandElType.isF16() || operandElType == half2Type)
84  return NVVM::MMATypes::f16;
85  if (operandElType.isF32() && isAccumulator)
86  return NVVM::MMATypes::f32;
87  if (operandElType.isF32() && !isAccumulator)
88  return NVVM::MMATypes::tf32;
89  if (operandElType.isa<IntegerType>()) {
90  if (isAccumulator)
91  return NVVM::MMATypes::s32;
92  return llvm::None;
93  }
94 
95  if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
96  if (structType.getBody().empty())
97  return llvm::None;
98  return inferOperandMMAType(structType.getBody()[0], isAccumulator);
99  }
100 
101  return llvm::None;
102 }
103 
104 static bool isInt4PtxType(MMATypes type) {
105  return (type == MMATypes::u4 || type == MMATypes::s4);
106 }
107 
108 static bool isInt8PtxType(MMATypes type) {
109  return (type == MMATypes::u8 || type == MMATypes::s8);
110 }
111 
112 static bool isIntegerPtxType(MMATypes type) {
113  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
114  type == MMATypes::s32;
115 }
116 
117 MMATypes MmaOp::accumPtxType() {
118  Optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
119  getODSOperands(2).getTypes().front(), /*isAccum=*/true);
120  assert(val.has_value() && "accumulator PTX type should always be inferrable");
121  return val.value();
122 }
123 
124 MMATypes MmaOp::resultPtxType() {
126  inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
127  assert(val.has_value() && "result PTX type should always be inferrable");
128  return val.value();
129 }
130 
131 void MmaOp::print(OpAsmPrinter &p) {
132  SmallVector<Type, 4> regTypes;
133  struct OperandFragment {
134  StringRef operandName;
135  StringRef ptxTypeAttr;
137  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
138  : operandName(name), ptxTypeAttr(ptxTypeName) {}
139  };
140 
141  std::array<OperandFragment, 3> frags{
142  OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
143  OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
144  OperandFragment("C", "")};
145  SmallVector<StringRef, 4> ignoreAttrNames{
146  mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
147 
148  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
149  auto &frag = frags[fragIdx];
150  auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
151  for (auto operandIdx = varOperandSpec.first;
152  operandIdx < varOperandSpec.first + varOperandSpec.second;
153  operandIdx++) {
154  frag.regs.push_back(this->getOperand(operandIdx));
155  if (operandIdx == 0) {
156  regTypes.push_back(this->getOperand(operandIdx).getType());
157  }
158  }
159  Optional<MMATypes> inferredType =
160  inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
161  if (inferredType)
162  ignoreAttrNames.push_back(frag.ptxTypeAttr);
163  }
164 
165  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
166  p << " " << frag.operandName;
167  p << "[";
168  p.printOperands(frag.regs);
169  p << "] ";
170  };
171 
172  for (const auto &frag : frags) {
173  printMmaOperand(frag);
174  }
175 
176  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
177 
178  // Print the types of the operands and result.
179  p << " : "
180  << "(";
181  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
182  frags[1].regs[0].getType(),
183  frags[2].regs[0].getType()},
184  p);
185  p << ")";
186  p.printArrowTypeList(TypeRange{this->getRes().getType()});
187 }
188 
189 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
190  ValueRange operandA, ValueRange operandB, ValueRange operandC,
192  Optional<MMAIntOverflow> intOverflow,
193  Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
194  Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
195 
196  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
197  MLIRContext *ctx = builder.getContext();
198  result.addAttribute(
199  "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
200 
201  result.addOperands(operandA);
202  result.addOperands(operandB);
203  result.addOperands(operandC);
204 
205  if (multiplicandPtxTypes) {
206  result.addAttribute("multiplicandAPtxType",
207  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
208  result.addAttribute("multiplicandBPtxType",
209  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
210  } else {
211  if (auto res = inferOperandMMAType(operandA[0].getType(), false))
212  result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
213  if (auto res = inferOperandMMAType(operandB[0].getType(), false))
214  result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
215  }
216 
217  if (multiplicandLayouts) {
218  result.addAttribute("layoutA",
219  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
220  result.addAttribute("layoutB",
221  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
222  } else {
223  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
224  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
225  }
226 
227  if (intOverflow.has_value())
228  result.addAttribute("intOverflowBehavior",
229  MMAIntOverflowAttr::get(ctx, *intOverflow));
230  if (b1Op.has_value())
231  result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
232 
233  result.addTypes(resultType);
234  result.addAttribute(
235  MmaOp::getOperandSegmentSizeAttr(),
236  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
237  static_cast<int32_t>(operandB.size()),
238  static_cast<int32_t>(operandC.size())}));
239 }
240 
241 // <operation> :=
242 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
243 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
244 // `->` type($res)
245 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
246  struct OperandFragment {
247  Optional<MMATypes> elemtype;
249  SmallVector<Type> regTypes;
250  };
251 
252  Builder &builder = parser.getBuilder();
253  std::array<OperandFragment, 4> frags;
254 
255  NamedAttrList namedAttributes;
256 
257  // A helper to parse the operand segments.
258  auto parseMmaOperand = [&](StringRef operandName,
259  OperandFragment &frag) -> LogicalResult {
260  if (parser.parseKeyword(operandName).failed())
261  return failure();
262  if (parser
264  .failed())
265  return failure();
266  return success();
267  };
268 
269  // Parse the operand segments.
270  if (parseMmaOperand("A", frags[0]).failed())
271  return failure();
272  if (parseMmaOperand("B", frags[1]).failed())
273  return failure();
274  if (parseMmaOperand("C", frags[2]).failed())
275  return failure();
276 
277  if (parser.parseOptionalAttrDict(namedAttributes).failed())
278  return failure();
279 
280  // Parse the type specification and resolve operands.
281  SmallVector<Type, 3> operandTypes;
282  if (failed(parser.parseColon()))
283  return failure();
284  if (failed(parser.parseLParen()))
285  return failure();
286  if (failed(parser.parseTypeList(operandTypes)))
287  return failure();
288  if (failed(parser.parseRParen()))
289  if (operandTypes.size() != 3)
290  return parser.emitError(
291  parser.getNameLoc(),
292  "expected one type for each operand segment but got " +
293  Twine(operandTypes.size()) + " types");
294  for (const auto &iter : llvm::enumerate(operandTypes)) {
295  auto &frag = frags[iter.index()];
296  frag.regTypes.resize(frag.regs.size(), iter.value());
297  if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
298  parser.getNameLoc(), result.operands)))
299  return failure();
300  frag.elemtype =
301  inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
302  }
303 
304  Type resultType;
305  if (parser.parseArrow() || parser.parseType(resultType))
306  return failure();
307  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
308 
309  std::array<StringRef, 2> names{"multiplicandAPtxType",
310  "multiplicandBPtxType"};
311  for (unsigned idx = 0; idx < names.size(); idx++) {
312  const auto &frag = frags[idx];
313  Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
314  if (!frag.elemtype.has_value() && !attr.has_value()) {
315  return parser.emitError(
316  parser.getNameLoc(),
317  "attribute " + names[idx] +
318  " is not provided explicitly and cannot be inferred");
319  }
320  if (!attr.has_value())
321  result.addAttribute(
322  names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
323  }
324 
325  result.addTypes(resultType);
326  if (!namedAttributes.empty())
327  result.addAttributes(namedAttributes);
328  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
329  builder.getDenseI32ArrayAttr({
330  static_cast<int32_t>(frags[0].regs.size()),
331  static_cast<int32_t>(frags[1].regs.size()),
332  static_cast<int32_t>(frags[2].regs.size()),
333  }));
334  return success();
335 }
336 
338  MLIRContext *context = getContext();
339  auto f16Ty = Float16Type::get(context);
340  auto i32Ty = IntegerType::get(context, 32);
341  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
342  auto f32Ty = Float32Type::get(context);
343  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
344  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
345 
346  auto s32x4StructTy =
347  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
348  auto f32x8StructTy =
350  auto f16x2x2StructTy =
351  LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
352  auto f32x4StructTy =
353  LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
354  auto s32x2StructTy =
355  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
356 
357  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
358  getShapeAttr().getK()};
359 
360  // These variables define the set of allowed data types for matrices A, B, C,
361  // and result.
362  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
363  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
364  AllowedShapes allowedShapes;
365  AllowedTypes expectedA;
366  AllowedTypes expectedB;
367  AllowedTypes expectedC;
368  SmallVector<Type> expectedResult;
369 
370  // When M = 16, we just need to calculate the number of 8xk tiles, where
371  // k is a factor that depends on the data type.
372  if (mmaShape[0] == 16) {
373  int64_t kFactor;
374  Type multiplicandFragType;
375  switch (*getMultiplicandAPtxType()) {
376  case MMATypes::tf32:
377  kFactor = 4;
378  multiplicandFragType = i32Ty;
379  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
380  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
381  break;
382  case MMATypes::f16:
383  case MMATypes::bf16:
384  kFactor = 8;
385  multiplicandFragType = f16x2Ty;
386  expectedResult.push_back(f16x2x2StructTy);
387  expectedResult.push_back(f32x4StructTy);
388  break;
389  case MMATypes::s4:
390  case MMATypes::u4:
391  kFactor = 32;
392  break;
393  case MMATypes::b1:
394  kFactor = 128;
395  break;
396  case MMATypes::s8:
397  case MMATypes::u8:
398  kFactor = 16;
399  break;
400  default:
401  return emitError("invalid shape or multiplicand type: " +
402  stringifyEnum(getMultiplicandAPtxType().value()));
403  }
404 
405  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
406  expectedResult.push_back(s32x4StructTy);
407  expectedC.emplace_back(4, i32Ty);
408  multiplicandFragType = i32Ty;
409  } else {
410  expectedC.emplace_back(2, f16x2Ty);
411  expectedC.emplace_back(4, f32Ty);
412  }
413 
414  int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
415  int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
416  expectedA.emplace_back(unitA, multiplicandFragType);
417  expectedB.emplace_back(unitB, multiplicandFragType);
418  allowedShapes.push_back({16, 8, kFactor});
419  allowedShapes.push_back({16, 8, kFactor * 2});
420  }
421 
422  // In the M=8 case, there is only 1 possible case per data type.
423  if (mmaShape[0] == 8) {
424  if (*getMultiplicandAPtxType() == MMATypes::f16) {
425  expectedA.emplace_back(2, f16x2Ty);
426  expectedB.emplace_back(2, f16x2Ty);
427  expectedResult.push_back(f16x2x4StructTy);
428  expectedResult.push_back(f32x8StructTy);
429  expectedC.emplace_back(4, f16x2Ty);
430  expectedC.emplace_back(8, f32Ty);
431  allowedShapes.push_back({8, 8, 4});
432  }
433  if (*getMultiplicandAPtxType() == MMATypes::f64) {
434  Type f64Ty = Float64Type::get(context);
435  expectedA.emplace_back(1, f64Ty);
436  expectedB.emplace_back(1, f64Ty);
437  expectedC.emplace_back(2, f64Ty);
438  // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
439  expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
440  context, SmallVector<Type>(2, f64Ty)));
441  allowedShapes.push_back({8, 8, 4});
442  }
443  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
444  expectedA.push_back({i32Ty});
445  expectedB.push_back({i32Ty});
446  expectedC.push_back({i32Ty, i32Ty});
447  expectedResult.push_back(s32x2StructTy);
448  if (isInt4PtxType(getMultiplicandAPtxType().value()))
449  allowedShapes.push_back({8, 8, 32});
450  if (isInt8PtxType(getMultiplicandAPtxType().value()))
451  allowedShapes.push_back({8, 8, 16});
452  if (getMultiplicandAPtxType().value() == MMATypes::b1)
453  allowedShapes.push_back({8, 8, 128});
454  }
455  }
456 
457  std::string errorMessage;
458  llvm::raw_string_ostream errorStream(errorMessage);
459 
460  // Check that we matched an existing shape/dtype combination.
461  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
462  !llvm::any_of(allowedShapes,
463  [&](const auto &allowed) { return allowed == mmaShape; })) {
464  errorStream << "unimplemented variant for MMA shape <";
465  llvm::interleaveComma(mmaShape, errorStream);
466  errorStream << ">";
467  return emitOpError(errorMessage);
468  }
469 
470  // Verify the operand types for segments of A, B, and C operands.
471  std::array<StringRef, 3> operandNames{"A", "B", "C"};
472  for (const auto &iter : llvm::enumerate(
473  SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
474  auto spec = this->getODSOperandIndexAndLength(iter.index());
475  SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
476  operand_type_begin() + spec.first +
477  spec.second);
478  bool match =
479  llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
480  return typeSet == operandTySeg;
481  });
482 
483  if (!match) {
484  errorStream << "Could not match types for the "
485  << operandNames[iter.index()]
486  << " operands; expected one of ";
487  for (const auto &x : iter.value()) {
488  errorStream << x.size() << "x" << x[0] << " ";
489  }
490  errorStream << "but got ";
491  llvm::interleaveComma(operandTySeg, errorStream);
492  return emitOpError(errorStream.str());
493  }
494  }
495 
496  // Check the result type
497  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
498  return expectedResultType == getResult().getType();
499  })) {
500  errorStream
501  << "Could not match allowed types for the result; expected one of ";
502  llvm::interleaveComma(expectedResult, errorStream);
503  errorStream << " but got " << getResult().getType();
504  return emitOpError(errorStream.str());
505  }
506 
507  // Ensure that binary MMA variants have a b1 MMA operation defined.
508  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
509  return emitOpError("op requires " + getB1OpAttrName().strref() +
510  " attribute");
511  }
512 
513  // Ensure int4/int8 MMA variants specify the accum overflow behavior
514  // attribute.
515  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
516  isInt8PtxType(*getMultiplicandAPtxType())) {
517  if (!getIntOverflowBehavior())
518  return emitOpError("op requires " +
519  getIntOverflowBehaviorAttrName().strref() +
520  " attribute");
521  }
522 
523  return success();
524 }
525 
527  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
528  return success();
529  auto type = getType().dyn_cast<LLVM::LLVMStructType>();
530  auto elementType = (type && type.getBody().size() == 2)
531  ? type.getBody()[1].dyn_cast<IntegerType>()
532  : nullptr;
533  if (!elementType || elementType.getWidth() != 1)
534  return emitError("expected return type to be a two-element struct with "
535  "i1 as the second element");
536  return success();
537 }
538 
539 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
540  NVVM::MMAFrag frag,
541  MLIRContext *context) {
542  unsigned numberElements = 0;
543  Type elementType;
544  OpBuilder builder(context);
545  Type f16x2 = VectorType::get(2, builder.getF16Type());
546  if (type == NVVM::MMATypes::f16) {
547  elementType = f16x2;
548  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
549  numberElements = 8;
550  else
551  numberElements = 4;
552  } else if (type == NVVM::MMATypes::f32) {
553  elementType = builder.getF32Type();
554  numberElements = 8;
555  } else if (type == NVVM::MMATypes::tf32) {
556  elementType = builder.getI32Type();
557  numberElements = 4;
558  }
559  assert(numberElements != 0 && elementType != nullptr);
560  return std::make_pair(elementType, numberElements);
561 }
562 
564  unsigned addressSpace =
565  getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
566  if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
567  return emitOpError("expected source pointer in memory "
568  "space 0, 1, 3");
569 
570  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
571  getEltype(), getFrag()) == 0)
572  return emitOpError() << "invalid attribute combination";
573  std::pair<Type, unsigned> typeInfo =
574  inferMMAType(getEltype(), getFrag(), getContext());
576  getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
577  if (getType() != dstType)
578  return emitOpError("expected destination type is a structure of ")
579  << typeInfo.second << " elements of type " << typeInfo.first;
580  return success();
581 }
582 
584  unsigned addressSpace =
585  getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
586  if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
587  return emitOpError("expected operands to be a source pointer in memory "
588  "space 0, 1, 3");
589 
590  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
591  getEltype()) == 0)
592  return emitOpError() << "invalid attribute combination";
593  std::pair<Type, unsigned> typeInfo =
594  inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
595  if (getArgs().size() != typeInfo.second)
596  return emitOpError() << "expected " << typeInfo.second << " data operands";
597  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
598  return operands.getType() != typeInfo.first;
599  }))
600  return emitOpError() << "expected data operands of type " << typeInfo.first;
601  return success();
602 }
603 
605  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
606  getLayoutB(), getEltypeA(),
607  getEltypeB()) == 0)
608  return emitOpError() << "invalid attribute combination";
609  std::pair<Type, unsigned> typeInfoA =
610  inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
611  std::pair<Type, unsigned> typeInfoB =
612  inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
613  std::pair<Type, unsigned> typeInfoC =
614  inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
615  SmallVector<Type, 32> arguments;
616  arguments.append(typeInfoA.second, typeInfoA.first);
617  arguments.append(typeInfoB.second, typeInfoB.first);
618  arguments.append(typeInfoC.second, typeInfoC.first);
619  unsigned numArgs = arguments.size();
620  if (getArgs().size() != numArgs)
621  return emitOpError() << "expected " << numArgs << " arguments";
622  for (unsigned i = 0; i < numArgs; i++) {
623  if (getArgs()[i].getType() != arguments[i])
624  return emitOpError() << "expected argument " << i << " to be of type "
625  << arguments[i];
626  }
628  getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
629  if (getType() != dstType)
630  return emitOpError("expected destination type is a structure of ")
631  << typeInfoC.second << " elements of type " << typeInfoC.first;
632  return success();
633 }
634 
636  unsigned addressSpace =
637  getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
638  if (addressSpace != 3)
639  return emitOpError("expected source pointer in memory space 3");
640 
641  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
642  return emitOpError("expected num attribute to be 1, 2 or 4");
643 
644  Type i32 = IntegerType::get(getContext(), 32);
645  if (getNum() == 1 && getType() != i32)
646  return emitOpError("expected destination type is i32");
647  if (getNum() == 2 || getNum() == 4) {
649  getContext(), SmallVector<Type>(getNum(), i32));
650  if (getType() != dstType)
651  return emitOpError("expected destination type is a structure of ")
652  << getNum() << " elements of type i32";
653  }
654  return success();
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // NVVMDialect initialization, type parsing, and registration.
659 //===----------------------------------------------------------------------===//
660 
661 // TODO: This should be the llvm.nvvm dialect once this is supported.
// Registers the dialect's contents: the operations listed in the GET_OP_LIST
// section of NVVMOps.cpp.inc and the attributes listed in the
// GET_ATTRDEF_LIST section of NVVMOpsAttributes.cpp.inc. Unknown operations
// are then allowed.
662 void NVVMDialect::initialize() {
663  addOperations<
664 #define GET_OP_LIST
665 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
666  >();
667  addAttributes<
668 #define GET_ATTRDEF_LIST
669 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
670  >();
671 
672  // Support unknown operations because not all NVVM operations are
673  // registered.
674  allowUnknownOperations();
675 }
676 
677 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
678  NamedAttribute attr) {
679  // Kernel function attribute should be attached to functions.
680  if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
681  if (!isa<LLVM::LLVMFuncOp>(op)) {
682  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
683  << "' attribute attached to unexpected op";
684  }
685  }
686  return success();
687 }
688 
689 #define GET_OP_CLASSES
690 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
691 
692 #define GET_ATTRDEF_CLASSES
693 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
Include the generated interface declarations.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
bool isF32() const
Definition: Types.cpp:23
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
Definition: Builders.h:54
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:940
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:429
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
FloatType getF16Type()
Definition: Builders.cpp:38
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
FloatType getF32Type()
Definition: Builders.cpp:40
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
void printArrowTypeList(TypeRange &&types)
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
static constexpr const bool value
SmallVector< Value, 4 > operands
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:149
virtual ParseResult parseColon()=0
Parse a : token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
void addOperands(ValueRange newOperands)
U dyn_cast() const
Definition: Types.h:270
bool isF16() const
Definition: Types.cpp:22
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isInt4PtxType(MMATypes type)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:32
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
virtual ParseResult parseRParen()=0
Parse a ) token.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void addTypes(ArrayRef< Type > newTypes)
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This represents an operation in an abstracted form, suitable for use with the builder APIs...
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:466
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:90
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:135
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseArrow()=0
Parse a '->' token.
NamedAttrList attributes
bool isF64() const
Definition: Types.cpp:24
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:118
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:283
static bool isIntegerPtxType(MMATypes type)
virtual ParseResult parseType(Type &result)=0
Parse a type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
LLVM dialect pointer type.
Definition: LLVMTypes.h:194
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
static ParseResult parseOperandList(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &args, SmallVectorImpl< Type > &argTypes, OperationState &result)
Definition: OpenACC.cpp:56
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
static bool isInt8PtxType(MMATypes type)
bool isa() const
Definition: Types.h:254
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
Square brackets supporting zero or more ops, or nothing.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op)
Definition: NVVMDialect.cpp:42
Optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, None otherwise.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
result_type_range getResultTypes()
Definition: Operation.h:345
IntegerType getI32Type()
Definition: Builders.cpp:54
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics...
SmallVector< Type, 4 > types
Types of the results of this operation.