//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

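// For example (illustrative, mirroring the grammar comment above):
//   %0 = nvvm.vote.ballot.sync %mask, %pred : i32
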
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects between 1 and 5 coordinate dimensions");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc,
          "im2col offsets must be two fewer than the number of coordinates");
  }
  return success();
}
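
// For example, a 5-D im2col transfer must either provide no im2col offsets at
// all or exactly 5 - 2 = 3 of them to pass the checks above.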

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("expects at most 5 coordinates");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("only CG and CA cache modifiers are supported");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies");
  return success();
}
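
// For example, a copy that satisfies the checks above (illustrative):
//   nvvm.cp.async.shared.global %dst, %src, 16, cache = cg
//       : !llvm.ptr<3>, !llvm.ptr<1>
// whereas a `cg` copy of 4 or 8 bytes would be rejected.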

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}

LogicalResult CvtFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("relu is not supported with the rna rounding mode");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "only {rn,rz,rna} rounding modes are supported for CvtFloatToTF32Op");
  }
  return success();
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type = LLVM::getFixedVectorType(
      Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
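
// Note the asymmetry encoded above: an f32 element type maps to tf32 when it
// appears as a multiplicand but stays f32 as an accumulator, and an integer
// element type is only meaningful (s32) in accumulator position.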

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : (";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
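
// A typical call (operand names purely illustrative) building an m16n8k16 f16
// MMA, leaving the PTX types and layouts to be inferred and defaulted:
//   MmaOp::build(builder, state, resStructTy,
//                /*operandA=*/{a0, a1, a2, a3}, /*operandB=*/{b0, b1},
//                /*operandC=*/{c0, c1}, /*shape=*/{16, 8, 16},
//                /*b1Op=*/std::nullopt, /*intOverflow=*/std::nullopt,
//                /*multiplicandPtxTypes=*/std::nullopt,
//                /*multiplicandLayouts=*/std::nullopt);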

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//   `->` type($res)
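// For example (illustrative, m16n8k16 with f16 multiplicands and accumulator):
//   %d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//        {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//         shape = #nvvm.shape<m = 16, n = 8, k = 16>}
//        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//        -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>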
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }
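
  // Worked example of the arithmetic above: for f16, kFactor = 8, so shape
  // m16n8k16 yields unitA = (16/8) * (16/8) = 4 and unitB = (8/8) * (16/8) = 2
  // multiplicand registers of type vector<2xf16>.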

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
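
// With the "return_value_and_is_valid" attribute set, a valid result type is,
// for example, !llvm.struct<(i32, i1)>: the shuffled value plus an i1
// predicate indicating validity.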

std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
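
// For example, under the rules above an f16 `a` or `b` fragment is carried in
// 8 x vector<2xf16> registers, while an f16 accumulator fragment uses
// 4 x vector<2xf16>.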

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, or 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer in memory "
                       "space 0, 1, or 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected result to be a struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in the struct must be the same type, but found "
             << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "with s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << " is not a supported combination";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " with n = "
                         << getShape().getN() << ", which is not supported";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored in row-major and B in column-major layout.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " require transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " result elements, however the output struct has "
                         << outputSize << " elements";
  }
  // Check satfinite (only available for the s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << "`satfinite` can only be used with an s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << regCnt << ", $" << (regCnt + 1) << ", p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
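
// For an (illustrative) m64n8k16 f32.f16.f16 op, the code above produces:
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $10, 0;
//   wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16
//       {$0, $1, $2, $3}, $8, $9, p, $11, $12, $13, $14;
//   }
// where $8/$9 are the matrix descriptors and $10 is the scale-d input.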

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing; it should be a value between 0 and 15");
  return success();
}

#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return id;
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}

llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  auto id = is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                       : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return id;
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays with
  // at most 3 dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array with at most 3 elements";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be an integer constant";
  }

  return success();
}
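
// For example, a function carrying well-formed NVVM attributes (illustrative):
//   llvm.func @kernel() attributes {nvvm.kernel,
//       nvvm.maxntid = array<i32: 128, 1, 1>, nvvm.minctasm = 2 : i32} {
//     llvm.return
//   }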

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}
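
// For example, a grid_constant kernel argument must also be byval
// (illustrative):
//   llvm.func @kernel(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
//       attributes {nvvm.kernel} {
//     llvm.return
//   }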

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "the optimization level must be a number between 0 and 3";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "the target triple cannot be empty";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "the target chip cannot be empty";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "all the elements in the `link` array must be strings";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"