#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");

  if (isIm2Col) {
    // im2col mode needs at least a 3-d tensor.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // When im2col offsets are given, there must be exactly (dims - 2) of them.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("a maximum of 5 coordinates/dimensions is supported");
  return success();
}
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}
LogicalResult CvtFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
  }
  return success();
}
LogicalResult CvtFloatToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  switch (getType()) {
  case CVTFP8Type::E4M3:
  case CVTFP8Type::E5M2:
    if (!isRoundingModeRN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (!isSatFinite)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  case CVTFP8Type::UE8M0:
    if (!(isRoundingModeRZ || isRoundingModeRP))
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (hasRelu)
      return emitOpError(
          "relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  return success();
}
LogicalResult CvtF16x2ToF8x2Op::verify() {
  if (getType() == CVTFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");
  return success();
}
LogicalResult CvtBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (getType() != CVTFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  auto rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  return success();
}
LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ")
           << getInitVal();
  return success();
}
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType =
          llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
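// Inference summary (follows directly from the branches above):
//   f64 -> f64; f16 or vector<2xf16> -> f16; f32 -> f32 as accumulator but
//   tf32 as multiplicand; integer types -> s32 only as accumulator; struct
//   types recurse on their first member.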
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }
  // ... (attribute dictionary and trailing type signature printed here)
}
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    // ... (explicitly provided PTX types attached as attributes)
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    // ... (layout attributes attached here)
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
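// Hypothetical builder call (a sketch; `b`, `loc`, `resStructTy` and the
// register ranges are assumed to exist at the call site):
//   b.create<NVVM::MmaOp>(loc, resStructTy, aRegs, bRegs, cRegs,
//                         /*shape=*/ArrayRef<int64_t>{16, 8, 16},
//                         /*b1Op=*/std::nullopt, /*intOverflow=*/std::nullopt,
//                         /*multiplicandPtxTypes=*/std::nullopt,
//                         /*multiplicandLayouts=*/std::nullopt);
// Leaving the PTX types as std::nullopt exercises the inference fallback
// shown above.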
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse one bracketed operand segment.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    // ... (keyword and operand-list parsing elided)
    return success();
  };

  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  SmallVector<Type> operandTypes;
  Type resultType;
  // ... (attribute dictionary, '->', and type list parsing elided)

  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() >= 2);
  }

  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto i32Ty = IntegerType::get(context, 32);
  auto f16Ty = Float16Type::get(context);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f64Ty = Float64Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B,
  // C, and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;
  // For the M = 16 variants, the fragment sizes follow from an 8xk tiling,
  // where the k-factor depends on the multiplicand type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    // ... (integer and sub-byte cases elided in this excerpt)
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }
  // The M = 8 variants.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, {f64Ty, f64Ty}));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }
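  // Worked example of the fragment-size arithmetic above (illustrative): for
  // m16n8k16 with f16 multiplicands, kFactor = 8, so A takes
  // (16/8) * (16/8) = 4 vector<2xf16> registers and B takes (8/8) * (16/8) = 2.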
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }
  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accumulator overflow behavior.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
      numberElements = 2;
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");
  return success();
}
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }
  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }
  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }
  // Check transpose; it is only available for f16 and bf16 inputs.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }
  // Check the number of result-struct elements.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // satfinite is only available with an s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
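// Worked example (illustrative): an m64n64k16 f32 += f16 * f16 wgmma needs a
// result struct with N/2 = 32 f32 members; with an f16 accumulator two halves
// pack per register, so the struct shrinks to N/4 = 16 members.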
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";

  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";

  // Handle the operand registers (results are tied as inputs as well).
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // ... (result, inout, and descriptor operands pushed first)
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}
LogicalResult Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
LogicalResult MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}
LogicalResult VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
    }
  }
  return success();
}
llvm::Value *
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                        llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
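// Illustrative LLVM IR effect (a sketch; %a is a placeholder name):
//   %packed = bitcast <4 x i8> %a to i32
// The dp4a intrinsics selected below consume one such packed 32-bit value
// per operand.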
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

void CpAsyncOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::Intrinsic::ID &id,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
}
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  // ... (MIN, MAX, INC, DEC, AND, OR, XOR follow the same pattern)
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID CvtFloatToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
                                                     bool hasRelu) {
  switch (type) {
  case NVVM::CVTFP6Type::E2M3:
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  case NVVM::CVTFP6Type::E3M2:
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  }
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID CvtFloatToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
                                                     NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  switch (type) {
  case NVVM::CVTFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::CVTFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::CVTFP8Type::UE8M0:
    if (hasRoundingModeRZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    if (hasRoundingModeRP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
    break;
  }
  llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
                                                     bool hasRelu) {
  switch (type) {
  case NVVM::CVTFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::CVTFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
  }
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                  NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
void Tcgen05AllocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::Intrinsic::ID &id,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVM::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));
}
void Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::Intrinsic::ID &id,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
           ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
           : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));
}
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

void Tcgen05CommitOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::Intrinsic::ID &id,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVM::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  id = is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                  : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));
}
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  unsigned isValidLen = llvm::isPowerOf2_32(vecLen) && vecLen <= 128;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    isValidLen &= vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    isValidLen &= vecLen >= 4;
  return isValidLen;
}
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
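// Illustrative cases for the two verifiers above (assuming the vector-length
// rule sketched in isValidVectorLength): any 16x32bx2 access must supply the
// extra offset operand, and a scalar (length-1) result is rejected for the
// wide 16x128b/16x256b shapes.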
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
llvm::Intrinsic::ID
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
                                    NVVM::DotAccumulate4WayType b_type) {
  bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
  bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
  unsigned type = (is_a_siext << 1) | is_b_siext;
  switch (type) {
  case 0:
    return llvm::Intrinsic::nvvm_idp4a_u_u;
  case 1:
    return llvm::Intrinsic::nvvm_idp4a_u_s;
  case 2:
    return llvm::Intrinsic::nvvm_idp4a_s_u;
  case 3:
    return llvm::Intrinsic::nvvm_idp4a_s_s;
  default:
    llvm_unreachable("Invalid DP4a type");
  }
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, it must be an array with at
  // most three dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError()
             << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"