MLIR  19.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
18 
22 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/Types.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/AsmParser/Parser.h"
35 #include "llvm/IR/Attributes.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/SourceMgr.h"
40 #include "llvm/Support/raw_ostream.h"
41 #include <cassert>
42 #include <optional>
43 #include <string>
44 
45 using namespace mlir;
46 using namespace NVVM;
47 
48 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
49 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Printing/parsing for NVVM ops
53 //===----------------------------------------------------------------------===//
54 
56  p << " " << op->getOperands();
57  if (op->getNumResults() > 0)
58  p << " : " << op->getResultTypes();
59 }
60 
61 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
63  MLIRContext *context = parser.getContext();
64  auto int32Ty = IntegerType::get(context, 32);
65  auto int1Ty = IntegerType::get(context, 1);
66 
68  Type type;
69  return failure(parser.parseOperandList(ops) ||
70  parser.parseOptionalAttrDict(result.attributes) ||
71  parser.parseColonType(type) ||
72  parser.addTypeToList(type, result.types) ||
73  parser.resolveOperands(ops, {int32Ty, int1Ty},
74  parser.getNameLoc(), result.operands));
75 }
76 
78 
80  if (getCoordinates().empty() || getCoordinates().size() > 5)
81  return emitError("expects coordinates between 1 to 5 dimension");
82 
83  // Check for im2col mode
84  if (!getIm2colOffsets().empty()) {
85  if (getCoordinates().size() < 3)
86  return emitError(
87  "to use im2col mode, the tensor has to be at least 3-dimensional");
88  if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
89  return emitError(
90  "im2col offsets must be 2 less than number of coordinates");
91  }
92  return success();
93 }
94 
96  if (getCoordinates().size() > 5)
97  return emitError("Maximum 5 coordinates and dimension is supported.");
98  return success();
99 }
100 
102  if (getModifier() != LoadCacheModifierKind::CG &&
103  getModifier() != LoadCacheModifierKind::CA)
104  return emitError("Only CG and CA cache modifiers are supported.");
105  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
106  return emitError("expected byte size to be either 4, 8 or 16.");
107  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
108  return emitError("CG cache modifier is only support for 16 bytes copy.");
109  return success();
110 }
111 
112 // Given the element type of an operand and whether or not it is an accumulator,
113 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
114 // operand's element type.
115 std::optional<mlir::NVVM::MMATypes>
116 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
117  auto half2Type =
119  if (operandElType.isF64())
120  return NVVM::MMATypes::f64;
121  if (operandElType.isF16() || operandElType == half2Type)
122  return NVVM::MMATypes::f16;
123  if (operandElType.isF32() && isAccumulator)
124  return NVVM::MMATypes::f32;
125  if (operandElType.isF32() && !isAccumulator)
126  return NVVM::MMATypes::tf32;
127  if (llvm::isa<IntegerType>(operandElType)) {
128  if (isAccumulator)
129  return NVVM::MMATypes::s32;
130  return std::nullopt;
131  }
132 
133  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
134  if (structType.getBody().empty())
135  return std::nullopt;
136  return inferOperandMMAType(structType.getBody()[0], isAccumulator);
137  }
138 
139  return std::nullopt;
140 }
141 
142 static bool isInt4PtxType(MMATypes type) {
143  return (type == MMATypes::u4 || type == MMATypes::s4);
144 }
145 
146 static bool isInt8PtxType(MMATypes type) {
147  return (type == MMATypes::u8 || type == MMATypes::s8);
148 }
149 
150 static bool isIntegerPtxType(MMATypes type) {
151  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
152  type == MMATypes::s32;
153 }
154 
155 MMATypes MmaOp::accumPtxType() {
156  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
157  getODSOperands(2).getTypes().front(), /*isAccum=*/true);
158  assert(val.has_value() && "accumulator PTX type should always be inferrable");
159  return val.value();
160 }
161 
162 MMATypes MmaOp::resultPtxType() {
163  std::optional<mlir::NVVM::MMATypes> val =
164  inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
165  assert(val.has_value() && "result PTX type should always be inferrable");
166  return val.value();
167 }
168 
169 void MmaOp::print(OpAsmPrinter &p) {
170  SmallVector<Type, 4> regTypes;
171  struct OperandFragment {
172  StringRef operandName;
173  StringRef ptxTypeAttr;
175  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
176  : operandName(name), ptxTypeAttr(ptxTypeName) {}
177  };
178 
179  std::array<OperandFragment, 3> frags{
180  OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
181  OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
182  OperandFragment("C", "")};
183  SmallVector<StringRef, 4> ignoreAttrNames{
184  mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
185 
186  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
187  auto &frag = frags[fragIdx];
188  auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
189  for (auto operandIdx = varOperandSpec.first;
190  operandIdx < varOperandSpec.first + varOperandSpec.second;
191  operandIdx++) {
192  frag.regs.push_back(this->getOperand(operandIdx));
193  if (operandIdx == 0) {
194  regTypes.push_back(this->getOperand(operandIdx).getType());
195  }
196  }
197  std::optional<MMATypes> inferredType =
198  inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
199  if (inferredType)
200  ignoreAttrNames.push_back(frag.ptxTypeAttr);
201  }
202 
203  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
204  p << " " << frag.operandName;
205  p << "[";
206  p.printOperands(frag.regs);
207  p << "] ";
208  };
209 
210  for (const auto &frag : frags) {
211  printMmaOperand(frag);
212  }
213 
214  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
215 
216  // Print the types of the operands and result.
217  p << " : " << "(";
218  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
219  frags[1].regs[0].getType(),
220  frags[2].regs[0].getType()},
221  p);
222  p << ")";
223  p.printArrowTypeList(TypeRange{this->getRes().getType()});
224 }
225 
226 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
227  ValueRange operandA, ValueRange operandB, ValueRange operandC,
228  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
229  std::optional<MMAIntOverflow> intOverflow,
230  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
231  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
232 
233  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
234  MLIRContext *ctx = builder.getContext();
235  result.addAttribute(
236  "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
237 
238  result.addOperands(operandA);
239  result.addOperands(operandB);
240  result.addOperands(operandC);
241 
242  if (multiplicandPtxTypes) {
243  result.addAttribute("multiplicandAPtxType",
244  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
245  result.addAttribute("multiplicandBPtxType",
246  MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
247  } else {
248  if (auto res = inferOperandMMAType(operandA[0].getType(), false))
249  result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
250  if (auto res = inferOperandMMAType(operandB[0].getType(), false))
251  result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
252  }
253 
254  if (multiplicandLayouts) {
255  result.addAttribute("layoutA",
256  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
257  result.addAttribute("layoutB",
258  MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
259  } else {
260  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
261  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
262  }
263 
264  if (intOverflow.has_value())
265  result.addAttribute("intOverflowBehavior",
266  MMAIntOverflowAttr::get(ctx, *intOverflow));
267  if (b1Op.has_value())
268  result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
269 
270  result.addTypes(resultType);
271  result.addAttribute(
272  MmaOp::getOperandSegmentSizeAttr(),
273  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
274  static_cast<int32_t>(operandB.size()),
275  static_cast<int32_t>(operandC.size())}));
276 }
277 
278 // <operation> :=
279 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
280 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
281 // `->` type($res)
283  struct OperandFragment {
284  std::optional<MMATypes> elemtype;
286  SmallVector<Type> regTypes;
287  };
288 
289  Builder &builder = parser.getBuilder();
290  std::array<OperandFragment, 4> frags;
291 
292  NamedAttrList namedAttributes;
293 
294  // A helper to parse the operand segments.
295  auto parseMmaOperand = [&](StringRef operandName,
296  OperandFragment &frag) -> LogicalResult {
297  if (parser.parseKeyword(operandName).failed())
298  return failure();
299  if (parser
300  .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
301  .failed())
302  return failure();
303  return success();
304  };
305 
306  // Parse the operand segments.
307  if (parseMmaOperand("A", frags[0]).failed())
308  return failure();
309  if (parseMmaOperand("B", frags[1]).failed())
310  return failure();
311  if (parseMmaOperand("C", frags[2]).failed())
312  return failure();
313 
314  if (parser.parseOptionalAttrDict(namedAttributes).failed())
315  return failure();
316 
317  // Parse the type specification and resolve operands.
318  SmallVector<Type, 3> operandTypes;
319  if (failed(parser.parseColon()))
320  return failure();
321  if (failed(parser.parseLParen()))
322  return failure();
323  if (failed(parser.parseTypeList(operandTypes)))
324  return failure();
325  if (failed(parser.parseRParen()))
326  if (operandTypes.size() != 3)
327  return parser.emitError(
328  parser.getNameLoc(),
329  "expected one type for each operand segment but got " +
330  Twine(operandTypes.size()) + " types");
331  for (const auto &iter : llvm::enumerate(operandTypes)) {
332  auto &frag = frags[iter.index()];
333  frag.regTypes.resize(frag.regs.size(), iter.value());
334  if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
335  parser.getNameLoc(), result.operands)))
336  return failure();
337  frag.elemtype =
338  inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
339  }
340 
341  Type resultType;
342  if (parser.parseArrow() || parser.parseType(resultType))
343  return failure();
344  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
345 
346  std::array<StringRef, 2> names{"multiplicandAPtxType",
347  "multiplicandBPtxType"};
348  for (unsigned idx = 0; idx < names.size(); idx++) {
349  const auto &frag = frags[idx];
350  std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
351  if (!frag.elemtype.has_value() && !attr.has_value()) {
352  return parser.emitError(
353  parser.getNameLoc(),
354  "attribute " + names[idx] +
355  " is not provided explicitly and cannot be inferred");
356  }
357  if (!attr.has_value())
358  result.addAttribute(
359  names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
360  }
361 
362  result.addTypes(resultType);
363  if (!namedAttributes.empty())
364  result.addAttributes(namedAttributes);
365  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
366  builder.getDenseI32ArrayAttr({
367  static_cast<int32_t>(frags[0].regs.size()),
368  static_cast<int32_t>(frags[1].regs.size()),
369  static_cast<int32_t>(frags[2].regs.size()),
370  }));
371  return success();
372 }
373 
375  MLIRContext *context = getContext();
376  auto f16Ty = Float16Type::get(context);
377  auto i32Ty = IntegerType::get(context, 32);
378  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
379  auto f32Ty = Float32Type::get(context);
380  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
381  context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
382 
383  auto s32x4StructTy =
384  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
385  auto f32x8StructTy =
387  auto f16x2x2StructTy =
388  LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
389  auto f32x4StructTy =
390  LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
391  auto s32x2StructTy =
392  LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
393 
394  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
395  getShapeAttr().getK()};
396 
397  // These variables define the set of allowed data types for matrices A, B, C,
398  // and result.
399  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
400  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
401  AllowedShapes allowedShapes;
402  AllowedTypes expectedA;
403  AllowedTypes expectedB;
404  AllowedTypes expectedC;
405  SmallVector<Type> expectedResult;
406 
407  // When M = 16, we just need to calculate the number of 8xk tiles, where
408  // k is a factor that depends on the data type.
409  if (mmaShape[0] == 16) {
410  int64_t kFactor;
411  Type multiplicandFragType;
412  switch (*getMultiplicandAPtxType()) {
413  case MMATypes::tf32:
414  kFactor = 4;
415  multiplicandFragType = i32Ty;
416  expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
417  context, {f32Ty, f32Ty, f32Ty, f32Ty}));
418  break;
419  case MMATypes::f16:
420  case MMATypes::bf16:
421  kFactor = 8;
422  multiplicandFragType = f16x2Ty;
423  expectedResult.push_back(f16x2x2StructTy);
424  expectedResult.push_back(f32x4StructTy);
425  break;
426  case MMATypes::s4:
427  case MMATypes::u4:
428  kFactor = 32;
429  break;
430  case MMATypes::b1:
431  kFactor = 128;
432  break;
433  case MMATypes::s8:
434  case MMATypes::u8:
435  kFactor = 16;
436  break;
437  default:
438  return emitError("invalid shape or multiplicand type: " +
439  stringifyEnum(getMultiplicandAPtxType().value()));
440  }
441 
442  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
443  expectedResult.push_back(s32x4StructTy);
444  expectedC.emplace_back(4, i32Ty);
445  multiplicandFragType = i32Ty;
446  } else {
447  expectedC.emplace_back(2, f16x2Ty);
448  expectedC.emplace_back(4, f32Ty);
449  }
450 
451  int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
452  int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
453  expectedA.emplace_back(unitA, multiplicandFragType);
454  expectedB.emplace_back(unitB, multiplicandFragType);
455  allowedShapes.push_back({16, 8, kFactor});
456  allowedShapes.push_back({16, 8, kFactor * 2});
457  }
458 
459  // In the M=8 case, there is only 1 possible case per data type.
460  if (mmaShape[0] == 8) {
461  if (*getMultiplicandAPtxType() == MMATypes::f16) {
462  expectedA.emplace_back(2, f16x2Ty);
463  expectedB.emplace_back(2, f16x2Ty);
464  expectedResult.push_back(f16x2x4StructTy);
465  expectedResult.push_back(f32x8StructTy);
466  expectedC.emplace_back(4, f16x2Ty);
467  expectedC.emplace_back(8, f32Ty);
468  allowedShapes.push_back({8, 8, 4});
469  }
470  if (*getMultiplicandAPtxType() == MMATypes::f64) {
471  Type f64Ty = Float64Type::get(context);
472  expectedA.emplace_back(1, f64Ty);
473  expectedB.emplace_back(1, f64Ty);
474  expectedC.emplace_back(2, f64Ty);
475  // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
476  expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
477  context, SmallVector<Type>(2, f64Ty)));
478  allowedShapes.push_back({8, 8, 4});
479  }
480  if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
481  expectedA.push_back({i32Ty});
482  expectedB.push_back({i32Ty});
483  expectedC.push_back({i32Ty, i32Ty});
484  expectedResult.push_back(s32x2StructTy);
485  if (isInt4PtxType(getMultiplicandAPtxType().value()))
486  allowedShapes.push_back({8, 8, 32});
487  if (isInt8PtxType(getMultiplicandAPtxType().value()))
488  allowedShapes.push_back({8, 8, 16});
489  if (getMultiplicandAPtxType().value() == MMATypes::b1)
490  allowedShapes.push_back({8, 8, 128});
491  }
492  }
493 
494  std::string errorMessage;
495  llvm::raw_string_ostream errorStream(errorMessage);
496 
497  // Check that we matched an existing shape/dtype combination.
498  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
499  !llvm::is_contained(allowedShapes, mmaShape)) {
500  errorStream << "unimplemented variant for MMA shape <";
501  llvm::interleaveComma(mmaShape, errorStream);
502  errorStream << ">";
503  return emitOpError(errorMessage);
504  }
505 
506  // Verify the operand types for segments of A, B, and C operands.
507  std::array<StringRef, 3> operandNames{"A", "B", "C"};
508  for (const auto &iter : llvm::enumerate(
509  SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
510  auto spec = this->getODSOperandIndexAndLength(iter.index());
511  SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
512  operand_type_begin() + spec.first +
513  spec.second);
514  bool match = llvm::is_contained(iter.value(), operandTySeg);
515 
516  if (!match) {
517  errorStream << "Could not match types for the "
518  << operandNames[iter.index()]
519  << " operands; expected one of ";
520  for (const auto &x : iter.value()) {
521  errorStream << x.size() << "x" << x[0] << " ";
522  }
523  errorStream << "but got ";
524  llvm::interleaveComma(operandTySeg, errorStream);
525  return emitOpError(errorStream.str());
526  }
527  }
528 
529  // Check the result type
530  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
531  return expectedResultType == getResult().getType();
532  })) {
533  errorStream
534  << "Could not match allowed types for the result; expected one of ";
535  llvm::interleaveComma(expectedResult, errorStream);
536  errorStream << " but got " << getResult().getType();
537  return emitOpError(errorStream.str());
538  }
539 
540  // Ensure that binary MMA variants have a b1 MMA operation defined.
541  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
542  return emitOpError("op requires " + getB1OpAttrName().strref() +
543  " attribute");
544  }
545 
546  // Ensure int4/int8 MMA variants specify the accum overflow behavior
547  // attribute.
548  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
549  isInt8PtxType(*getMultiplicandAPtxType())) {
550  if (!getIntOverflowBehavior())
551  return emitOpError("op requires " +
552  getIntOverflowBehaviorAttrName().strref() +
553  " attribute");
554  }
555 
556  return success();
557 }
558 
560  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
561  return success();
562  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
563  auto elementType = (type && type.getBody().size() == 2)
564  ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
565  : nullptr;
566  if (!elementType || elementType.getWidth() != 1)
567  return emitError("expected return type to be a two-element struct with "
568  "i1 as the second element");
569  return success();
570 }
571 
572 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
573  NVVM::MMAFrag frag, int nRow,
574  int nCol,
575  MLIRContext *context) {
576  unsigned numberElements = 0;
577  Type elementType;
578  OpBuilder builder(context);
579  Type f16x2 = VectorType::get(2, builder.getF16Type());
580  if (type == NVVM::MMATypes::f16) {
581  elementType = f16x2;
582  if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
583  numberElements = 8;
584  else
585  numberElements = 4;
586  } else if (type == NVVM::MMATypes::f32) {
587  elementType = builder.getF32Type();
588  numberElements = 8;
589  } else if (type == NVVM::MMATypes::tf32) {
590  elementType = builder.getI32Type();
591  numberElements = 4;
592  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
593  elementType = builder.getI32Type();
594  int parallelSize = 0;
595  if (frag == NVVM::MMAFrag::a)
596  parallelSize = nRow;
597  if (frag == NVVM::MMAFrag::b)
598  parallelSize = nCol;
599 
600  // m == 16 && n == 16 && k == 16
601  if (parallelSize == 16)
602  numberElements = 2;
603  // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
604  else if (parallelSize == 8)
605  numberElements = 1;
606  else if (parallelSize == 32)
607  numberElements = 4;
608  } else if (type == NVVM::MMATypes::s32) {
609  elementType = builder.getI32Type();
610  numberElements = 8;
611  }
612  assert(numberElements != 0 && elementType != nullptr);
613  return std::make_pair(elementType, numberElements);
614 }
615 
616 static std::pair<mlir::Type, unsigned>
617 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
618  int k, MLIRContext *context) {
619  int nRow, nCol;
620  if (frag == NVVM::MMAFrag::a) {
621  nRow = m;
622  nCol = k;
623  } else if (frag == NVVM::MMAFrag::b) {
624  nRow = k;
625  nCol = n;
626  } else {
627  nRow = m;
628  nCol = n;
629  }
630  assert(nRow && nCol);
631  return inferMMAType(type, frag, nRow, nCol, context);
632 }
633 
635  unsigned addressSpace =
636  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
637  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
638  addressSpace != NVVM::kSharedMemorySpace)
639  return emitOpError("expected source pointer in memory "
640  "space 0, 1, 3");
641 
642  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
643  getEltype(), getFrag()) == 0)
644  return emitOpError() << "invalid attribute combination";
645  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
646  getEltype(), getFrag(), getM(), getN(), getK(), getContext());
648  getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
649  if (getType() != dstType)
650  return emitOpError("expected destination type is a structure of ")
651  << typeInfo.second << " elements of type " << typeInfo.first;
652  return success();
653 }
654 
656  unsigned addressSpace =
657  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
658  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
659  addressSpace != NVVM::kSharedMemorySpace)
660  return emitOpError("expected operands to be a source pointer in memory "
661  "space 0, 1, 3");
662 
663  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
664  getEltype()) == 0)
665  return emitOpError() << "invalid attribute combination";
666  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
667  getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
668  if (getArgs().size() != typeInfo.second)
669  return emitOpError() << "expected " << typeInfo.second << " data operands";
670  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
671  return operands.getType() != typeInfo.first;
672  }))
673  return emitOpError() << "expected data operands of type " << typeInfo.first;
674  return success();
675 }
676 
678  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
679  getLayoutB(), getEltypeA(),
680  getEltypeB()) == 0)
681  return emitOpError() << "invalid attribute combination";
682  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
683  getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
684  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
685  getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
686  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
687  getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
688  SmallVector<Type, 32> arguments;
689  arguments.append(typeInfoA.second, typeInfoA.first);
690  arguments.append(typeInfoB.second, typeInfoB.first);
691  arguments.append(typeInfoC.second, typeInfoC.first);
692  unsigned numArgs = arguments.size();
693  if (getArgs().size() != numArgs)
694  return emitOpError() << "expected " << numArgs << " arguments";
695  for (unsigned i = 0; i < numArgs; i++) {
696  if (getArgs()[i].getType() != arguments[i])
697  return emitOpError() << "expected argument " << i << " to be of type "
698  << arguments[i];
699  }
701  getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
702  if (getType() != dstType)
703  return emitOpError("expected destination type is a structure of ")
704  << typeInfoC.second << " elements of type " << typeInfoC.first;
705  return success();
706 }
707 
709  unsigned addressSpace =
710  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
711  if (addressSpace != NVVM::kSharedMemorySpace)
712  return emitOpError("expected source pointer in memory space 3");
713 
714  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
715  return emitOpError("expected num attribute to be 1, 2 or 4");
716 
717  Type i32 = IntegerType::get(getContext(), 32);
718  if (getNum() == 1 && getType() != i32)
719  return emitOpError("expected destination type is i32");
720  if (getNum() == 2 || getNum() == 4) {
722  getContext(), SmallVector<Type>(getNum(), i32));
723  if (getType() != dstType)
724  return emitOpError("expected destination type is a structure of ")
725  << getNum() << " elements of type i32";
726  }
727  return success();
728 }
729 
731  unsigned addressSpace =
732  llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
733  if (addressSpace != NVVM::kSharedMemorySpace)
734  return emitOpError("expected source pointer in memory space 3");
735 
736  int numMatrix = getSources().size();
737  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
738  return emitOpError("expected num attribute to be 1, 2 or 4");
739 
740  return success();
741 }
742 
743 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
744  if (typeA == NVVM::WGMMATypes::tf32)
745  return 8;
746  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
747  return 16;
748  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
749  return 32;
750  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
751  return 32;
752  if (typeA == NVVM::WGMMATypes::b1)
753  return 256;
754  return failure();
755 }
756 
757 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
758  NVVM::WGMMATypes typeA,
759  NVVM::WGMMATypes typeB) {
760  switch (typeA) {
761  case NVVM::WGMMATypes::f16:
762  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
763  typeB == NVVM::WGMMATypes::f16)
764  return success();
765  break;
766  case NVVM::WGMMATypes::tf32:
767  if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
768  return success();
769  break;
770  case NVVM::WGMMATypes::u8:
771  case NVVM::WGMMATypes::s8:
772  if (typeD == NVVM::WGMMATypes::s32 &&
773  (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
774  return success();
775  break;
776  case NVVM::WGMMATypes::b1:
777  if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
778  return success();
779  break;
780  case NVVM::WGMMATypes::bf16:
781  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
782  typeB == NVVM::WGMMATypes::bf16)
783  return success();
784  break;
785  case NVVM::WGMMATypes::e4m3:
786  case NVVM::WGMMATypes::e5m2:
787  if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
788  (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
789  return success();
790  break;
791  case WGMMATypes::f32:
792  case WGMMATypes::s32:
793  llvm_unreachable("unsupported input types");
794  break;
795  }
796  return failure();
797 }
798 
799 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
800  SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
801  72, 80, 88, 96, 104, 112, 120, 128,
802  136, 144, 152, 160, 168, 176, 184, 192,
803  200, 208, 216, 224, 232, 240, 248, 256};
804  SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
805  80, 96, 112, 128, 144, 160,
806  176, 192, 208, 224, 240, 256};
807  switch (typeA) {
808  case WGMMATypes::f16:
809  case WGMMATypes::tf32:
810  case WGMMATypes::bf16:
811  case WGMMATypes::e4m3:
812  case WGMMATypes::e5m2:
813  if (llvm::is_contained(allowedN, sizeN))
814  return success();
815  break;
816  case WGMMATypes::u8:
817  case WGMMATypes::s8:
818  case WGMMATypes::b1:
819  if (llvm::is_contained(allowedNshort, sizeN))
820  return success();
821  break;
822  case WGMMATypes::f32:
823  case WGMMATypes::s32:
824  llvm_unreachable("unsupported input types");
825  break;
826  }
827  return failure();
828 }
829 
831  Value outValue = getResults();
832  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
833  if (!stype)
834  return emitOpError() << "expected results to be struct";
835  int outputSize = stype.getBody().size();
836  WGMMATypes typeD = getTypeD();
837  WGMMATypes typeA = getTypeA();
838  WGMMATypes typeB = getTypeB();
839 
840  for (Type t : stype.getBody()) {
841  if (t != stype.getBody().front())
842  return emitOpError()
843  << "all elements in struct must be same type but there is " << t;
844  }
845 
846  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
847  typeD != WGMMATypes::s32) {
848  return emitOpError() << "does not support the given output type "
849  << NVVM::stringifyWGMMATypes(typeD);
850  }
851  if (typeD == WGMMATypes::s32 &&
852  (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
853  return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
854  }
855 
856  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
857  return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
858  << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
859  << NVVM::stringifyWGMMATypes(typeB)
860  << ", it is not supported.";
861  }
862 
863  // Check M
864  if (getShape().getM() != 64)
865  return emitOpError() << "shape 'm' must be 64";
866 
867  // Check K
868  FailureOr<int> allowedK = getAllowedSizeK(typeA);
869  if (failed(allowedK) || allowedK.value() != getShape().getK())
870  return emitOpError() << "shape 'k' must be " << allowedK.value()
871  << " for input type "
872  << NVVM::stringifyWGMMATypes(typeA);
873 
874  // Check N
875  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
876  return emitOpError() << "has input type "
877  << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
878  << getShape().getN() << ", it is not supported.";
879  }
880 
881  // Check transpose (only available for f16/bf16)
882  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
883  (getLayoutA() == mlir::NVVM::MMALayout::col ||
884  getLayoutB() == mlir::NVVM::MMALayout::col)) {
885  return emitOpError()
886  << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
887  << " and layout_b = " << stringifyMMALayout(getLayoutB())
888  << " for input types " << stringifyWGMMATypes(typeA) << " and "
889  << stringifyWGMMATypes(typeB)
890  << " requires transpose. However, this is only supported for: "
891  << stringifyMMATypes(MMATypes::f16) << " and "
892  << stringifyMMATypes(MMATypes::bf16);
893  }
894 
895  // Check result registers
896  int expectedOutput = 0;
897  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
898  expectedOutput = getShape().getN() / 2;
899  if (typeD == WGMMATypes::f16)
900  expectedOutput = getShape().getN() / 4;
901  if (outputSize != expectedOutput) {
902  return emitOpError() << "results " << expectedOutput
903  << ", however output struct has " << outputSize
904  << " elements";
905  }
906  // Check satfinite (only available for s32 accumulator)
907  if (typeD != WGMMATypes::s32 &&
908  getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
909  NVVM::MMAIntOverflow::satfinite) {
910  return emitOpError()
911  << " `satfinite` can be only used with s32 accumulator, however "
912  "the current accumulator is "
913  << NVVM::stringifyWGMMATypes(typeD);
914  }
915 
916  return success();
917 }
918 
919 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
920 
921  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
922  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
923 
924  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
925 
926  int expectedOutputRegisters = 0;
927  if (getTypeD() == WGMMATypes::f16)
928  expectedOutputRegisters = getShape().getN() / 4;
929  else
930  expectedOutputRegisters = getShape().getN() / 2;
931 
932  std::string ptx;
933  llvm::raw_string_ostream ss(ptx);
934 
935  ss << "{\n"
936  ".reg .pred p;\n"
937  "setp.ne.b32 p, $"
938  << ((expectedOutputRegisters * 2) + 2)
939  << ", 0;\n"
940  "wgmma.mma_async.sync.aligned.m"
941  << m << "n" << n << "k" << k << "." << outputTypeName << "."
942  << stringifyWGMMATypes(getTypeA()) << "."
943  << stringifyWGMMATypes(getTypeB());
944  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
945  NVVM::MMAIntOverflow::satfinite)
946  ss << ".satfinite";
947  ss << " {";
948  int regCnt = 0;
949  for (; regCnt < expectedOutputRegisters; ++regCnt) {
950  ss << "$" << regCnt;
951  if (regCnt != expectedOutputRegisters - 1)
952  ss << ", ";
953  }
954 
955  ss << "},";
956  // Need to map read/write registers correctly.
957  regCnt = (regCnt * 2);
958  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
959  if (getTypeD() != WGMMATypes::s32) {
960  ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
961  }
962  // Don't add transpose parameters unless needed.
963  if (isF16) {
964  ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
965  }
966  ss << ";\n"
967  << "}\n";
968  ss.flush();
969  return ptx;
970 }
971 
/// Collects the operands of the inline PTX produced by getPtx(), in operand
/// order. The push order below fixes the `$N` indices that getPtx()
/// hard-codes, so reordering any push_back here requires updating getPtx():
///   1. results (Write) and inouts (ReadWrite), each only when present;
///   2. descriptor A and descriptor B (Read);
///   3. the scale-d constant (getPtx tests it with `setp.ne ..., 0`);
///   4. scale-a / scale-b constants, unless the accumulator D is s32;
///   5. layout constants for A and B, only for f16/bf16 A inputs.
972 void NVVM::WgmmaMmaAsyncOp::getAsmValues(
973  RewriterBase &rewriter,
974  llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
975  &asmValues) {
  // Both f16 and bf16 inputs take the layout (transpose) immediates.
976  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // Accumulator operands come first.
977  if (getResults())
978  asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
979  if (getInouts())
980  asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  // Matrix descriptors for A and B.
981  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
982  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  // Scale-d, materialized as an i32 constant.
983  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
985  if (getTypeD() != WGMMATypes::s32) {
  // Scale-a / scale-b immediates: -1 when `neg`, 1 otherwise. They are
  // omitted for s32 accumulators, matching getPtx().
986  asmValues.push_back(
987  {makeConstantI32(rewriter,
988  getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
990  asmValues.push_back(
991  {makeConstantI32(rewriter,
992  getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
994  }
995  if (isF16) {
  // Layout immediates: A's layout value is passed as-is, while B's is
  // inverted (1 - value).
  // NOTE(review): presumably the PTX imm-trans-b sense is opposite to the
  // MMALayout encoding for B -- confirm.
996  asmValues.push_back(
997  {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
999  asmValues.push_back(
1000  {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1002  }
1003 }
// The `space` attribute is required for async_shared proxy fences and is
// rejected for every other proxy kind.
1005  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1006  return emitOpError() << "async_shared fence requires space attribute";
1007  }
1008  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1009  return emitOpError() << "only async_shared fence can have space attribute";
1010  }
1011  return success();
1012 }
1013 
// The requested register count must be a multiple of 8 within the
// inclusive range [24, 256].
1015  if (getRegCount() % 8)
1016  return emitOpError("new register size must be multiple of 8");
1017  if (getRegCount() < 24 || getRegCount() > 256)
1018  return emitOpError("new register size must be in between 24 to 256");
1019  return success();
1020 }
1021 
// A thread count is only accepted together with an explicit barrier id.
// NOTE(review): the 0-15 range quoted in the message is not checked here.
1023  if (getNumberOfThreads() && !getBarrierId())
1024  return emitOpError(
1025  "barrier id is missing, it should be set between 0 to 15");
1026  return success();
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // NVVMDialect initialization, type parsing, and registration.
1031 //===----------------------------------------------------------------------===//
1032 
1033 // TODO: This should be the llvm.nvvm dialect once this is supported.
/// Registers the NVVM dialect's ops and attributes from the TableGen-generated
/// lists, allows unregistered ops in this dialect, and declares promised
/// interfaces: ConvertToLLVMPatternInterface on the dialect and
/// gpu::TargetAttrInterface on NVVMTargetAttr. Their implementations are not
/// in this function (presumably attached by separate extensions -- confirm).
1034 void NVVMDialect::initialize() {
1035  addOperations<
1036 #define GET_OP_LIST
1037 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1038  >();
1039  addAttributes<
1040 #define GET_ATTRDEF_LIST
1041 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1042  >();
1043 
1044  // Support unknown operations because not all NVVM operations are
1045  // registered.
1046  allowUnknownOperations();
1047  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1048  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1049 }
1050 
1051 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1052  NamedAttribute attr) {
1053  StringAttr attrName = attr.getName();
1054  // Kernel function attribute should be attached to functions.
1055  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1056  if (!isa<LLVM::LLVMFuncOp>(op)) {
1057  return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1058  << "' attribute attached to unexpected op";
1059  }
1060  }
1061  // If maxntid and reqntid exist, it must be an array with max 3 dim
1062  if (attrName == NVVMDialect::getMaxntidAttrName() ||
1063  attrName == NVVMDialect::getReqntidAttrName()) {
1064  auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1065  if (!values || values.empty() || values.size() > 3)
1066  return op->emitError()
1067  << "'" << attrName
1068  << "' attribute must be integer array with maximum 3 index";
1069  }
1070  // If minctasm and maxnreg exist, it must be an integer attribute
1071  if (attrName == NVVMDialect::getMinctasmAttrName() ||
1072  attrName == NVVMDialect::getMaxnregAttrName()) {
1073  if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1074  return op->emitError()
1075  << "'" << attrName << "' attribute must be integer constant";
1076  }
1077 
1078  return success();
1079 }
1080 
1081 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1082  unsigned regionIndex,
1083  unsigned argIndex,
1084  NamedAttribute argAttr) {
1085  auto funcOp = dyn_cast<FunctionOpInterface>(op);
1086  if (!funcOp)
1087  return success();
1088 
1089  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1090  StringAttr attrName = argAttr.getName();
1091  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1092  if (!isKernel) {
1093  return op->emitError()
1094  << "'" << attrName
1095  << "' attribute must be present only on kernel arguments";
1096  }
1097  if (!isa<UnitAttr>(argAttr.getValue()))
1098  return op->emitError() << "'" << attrName << "' must be a unit attribute";
1099  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1100  return op->emitError()
1101  << "'" << attrName
1102  << "' attribute requires the argument to also have attribute '"
1103  << LLVM::LLVMDialect::getByValAttrName() << "'";
1104  }
1105  }
1106 
1107  return success();
1108 }
1109 
1110 //===----------------------------------------------------------------------===//
1111 // NVVM target attribute.
1112 //===----------------------------------------------------------------------===//
1115  int optLevel, StringRef triple, StringRef chip,
1116  StringRef features, DictionaryAttr flags,
1117  ArrayAttr files) {
  // Validates the target attribute parameters: the optimization level must be
  // in [0, 3], and the triple and chip must be non-empty. `features` and
  // `flags` are accepted without checks here.
1118  if (optLevel < 0 || optLevel > 3) {
1119  emitError() << "The optimization level must be a number between 0 and 3.";
1120  return failure();
1121  }
1122  if (triple.empty()) {
1123  emitError() << "The target triple cannot be empty.";
1124  return failure();
1125  }
1126  if (chip.empty()) {
1127  emitError() << "The target chip cannot be empty.";
1128  return failure();
1129  }
  // `files` is optional. When present, every element must be a non-null
  // StringAttr. Diagnostics refer to it as the `link` array (presumably its
  // name in the attribute syntax -- confirm).
1130  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1131  return attr && mlir::isa<StringAttr>(attr);
1132  })) {
1133  emitError() << "All the elements in the `link` array must be strings.";
1134  return failure();
1135  }
1136  return success();
1137 }
1138 
1139 #define GET_OP_CLASSES
1140 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1141 
1142 #define GET_ATTRDEF_CLASSES
1143 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op)
Definition: NVVMDialect.cpp:55
FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static bool isInt8PtxType(MMATypes type)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:76
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
FloatType getF32Type()
Definition: Builders.cpp:63
IntegerType getI32Type()
Definition: Builders.cpp:83
FloatType getF16Type()
Definition: Builders.cpp:59
MLIRContext * getContext() const
Definition: Builders.h:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:216
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isF32() const
Definition: Types.cpp:51
bool isF16() const
Definition: Types.cpp:49
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:955
@ Write
Write register with '=' modifier.
@ ReadWrite
Read and write register with '+' modifier.
@ Read
Read register with no modifier.
@ kGlobalMemorySpace
Global memory space identifier.
Definition: NVVMDialect.h:36
@ kSharedMemorySpace
Shared memory space identifier.
Definition: NVVMDialect.h:38
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
uint64_t getN(LevelType lt)
Definition: Enums.h:438
uint64_t getM(LevelType lt)
Definition: Enums.h:439
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.