31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/FormatVariadic.h"
36 #include "llvm/Support/NVPTXAddrSpace.h"
37 #include "llvm/Support/raw_ostream.h"
45 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
46 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
                                                     size_t numIm2ColOffsets,
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");
        "to use im2col mode, the tensor has to be at least 3-dimensional");
  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
        loc, "im2col offsets must be 2 less than number of coordinates");
  TMAStoreMode mode = getMode();
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
      return emitError("Scatter4 mode expects 5 coordinates");
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16 bytes copy.");
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");

                              getMode(), getLoc());
  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) {
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
          "Predicate is supported only for Tile and Im2col modes.");

    NVVMMemorySpace expectedAS =
        isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
    unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
    if (AS != expectedAS)
                 ? "Shared::cta destination requires address-space 3."
                 : "Shared::cluster destination requires address-space 7.");

    if (getMulticastMask())
      return emitError("Multicast is not supported with shared::cta mode.");
      return emitError("CTAGroup is not supported with shared::cta mode.");

                             getMode(), getLoc());
  TMAStoreMode mode = getMode();
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");

  using RndMode = NVVM::FPRoundingMode;
    return emitError("Relu not supported with rna rounding mode.");
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");

  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
    return emitOpError("Only ")
           << " types are supported for conversions from f32x2 to f6x2.";
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;
  bool hasRelu = getRelu();

      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
            if (!isRoundingModeRN) {
              return emitOpError("Only RN rounding mode is supported for "
                                 "conversions from f32x2 to ")
              return emitOpError("Only SATFINITE saturation mode is supported "
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
            if (!(isRoundingModeRZ || isRoundingModeRP)) {
              return emitOpError("Only RZ and RP rounding modes are supported for "
                                 "conversions from f32x2 to ")
            return emitOpError("relu not supported for conversions to ")
        return emitOpError("Only ")
               "supported for conversions from f32x2 to f8x2";
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
    return emitOpError("Only ")
           << " types are supported for conversions from f16x2 to f8x2.";

  using RndMode = NVVM::FPRoundingMode;
  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
           << " type is supported for conversions from "
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ")
           << " type is supported for conversions from f32x2 to f4x2.";

  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
    return emitOpError("Only ")
           << " types are supported for conversions from f8x2 to f16x2.";

  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
    return emitOpError("Only ")
           << " type is supported for conversions from f8x2 to bf16x2.";

  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
    return emitOpError("Only ")
           << " types are supported for conversions from f6x2 to f16x2.";

  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
    return emitOpError("Only ")
           << " type is supported for conversions from f4x2 to f16x2.";

  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";
  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";
  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";
  return llvm::success();
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
      return NVVM::MMATypes::s32;
  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);

  return (type == MMATypes::u4 || type == MMATypes::s4);

  return (type == MMATypes::u8 || type == MMATypes::s8);

         type == MMATypes::s32;
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags) {
    printMmaOperand(frag);
                   std::optional<MMAIntOverflow> intOverflow,
                   std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                   std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  if (multiplicandPtxTypes) {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))

  if (multiplicandLayouts) {
  if (intOverflow.has_value())
  if (b1Op.has_value())

      MmaOp::getOperandSegmentSizeAttr(),
       static_cast<int32_t>(operandB.size()),
       static_cast<int32_t>(operandC.size())}));
  struct OperandFragment {
    std::optional<MMATypes> elemtype;

  std::array<OperandFragment, 4> frags;

  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {

  if (parseMmaOperand("A", frags[0]).failed())
  if (parseMmaOperand("B", frags[1]).failed())
  if (parseMmaOperand("C", frags[2]).failed())

  if (operandTypes.size() != 3)
        "expected one type for each operand segment but got " +
        Twine(operandTypes.size()) + " types");
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
  frags[3].elemtype = inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
          "attribute " + names[idx] +
          " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())

  if (!namedAttributes.empty())
       static_cast<int32_t>(frags[0].regs.size()),
       static_cast<int32_t>(frags[1].regs.size()),
       static_cast<int32_t>(frags[2].regs.size()),
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});

  if (resultPtxType() != accumPtxType())
    return emitOpError("ctype does not match dtype");
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
        allowedShapes.push_back({8, 8, 32});
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    return emitOpError(errorMessage);
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                                   operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +

      (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
       getMultiplicandAPtxType() == MMATypes::f16);
    if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
      return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
                         "layoutB = #nvvm.mma_layout<col> for shape <")
             << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
             << "> with element types "
             << stringifyEnum(*getMultiplicandAPtxType()) << " and "
             << stringifyEnum(*getMultiplicandBPtxType())
             << ". Only m8n8k4 with f16 supports other layouts.";
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
                                              NVVM::MMAFrag frag, int nRow,
  unsigned numberElements = 0;
  if (type == NVVM::MMATypes::f16) {
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::f32) {
  } else if (type == NVVM::MMATypes::f64) {
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::tf32) {
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
    if (frag == NVVM::MMAFrag::b)
    if (parallelSize == 16)
    else if (parallelSize == 8)
    else if (parallelSize == 32)
  } else if (type == NVVM::MMATypes::s32) {
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);

static std::pair<mlir::Type, unsigned>
  if (frag == NVVM::MMAFrag::a) {
  } else if (frag == NVVM::MMAFrag::b) {
  assert(nRow && nCol);
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "
  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
      return emitOpError("expected destination type to be f64");
  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "
  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
    return emitOpError() << "invalid attribute combination";
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
    return emitOpError() << "invalid attribute combination";
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");

  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    return emitOpError("expected shape to be 8x8 or 16x8");
  if (typeA == NVVM::WGMMATypes::tf32)
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
  if (typeA == NVVM::WGMMATypes::b1)

                                        NVVM::WGMMATypes typeA,
                                        NVVM::WGMMATypes typeB) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");

                                    72, 80, 88, 96, 104, 112, 120, 128,
                                    136, 144, 152, 160, 168, 176, 184, 192,
                                    200, 208, 216, 224, 232, 240, 248, 256};
                                         80, 96, 112, 128, 144, 160,
                                         176, 192, 208, 224, 240, 256};
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";

    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";

    return emitOpError() << "shape 'm' must be 64";

    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);

  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize

  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
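// Builds the inline-PTX string for wgmma.mma_async: the destination registers
// are printed first as $N placeholders, followed by the remaining operands in
// the order produced by getAsmValues() below.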
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
    expectedOutputRegisters = getShape().getN() / 2;

  llvm::raw_string_ostream ss(ptx);
     << ((expectedOutputRegisters * 2) + 2)
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";

  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");
  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");
  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");

  if (getNumberOfThreads() && !getBarrierId())
        "barrier id is missing, it should be set between 0 to 15");

  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
      return emitError("Invalid multicast type for tcgen05.cp Op");
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");

  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
          "prefetch tensormap requires a generic or constant pointer");
    if (evictPriority) {
          "prefetch tensormap does not support eviction priority");
    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
          "in_param_space can only be specified for a generic pointer");
  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
      if (*cacheLevel != CacheLevel::L1) {
            "unsupported cache level, the only supported uniform "
            "cache level is L1");
      if (addressSpace != MemSpace::Generic) {
            "prefetch to uniform cache requires a generic pointer");
    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
            "cache eviction priority supported only for cache level L2");
      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");
      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    return emitOpError("predicate supported only on prefetch tensormap");
        "requires specification of either cache level or tensormap");
  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
      return emitOpError("is_canceled query type returns an i1");
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
static llvm::Value *
                                 llvm::Value *result,
                                 unsigned sizeInBits,
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);

void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

      mt.lookupValue(thisOp.getStartAddr()), 14, 0);
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
      mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
      mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);

  mt.mapValue(thisOp.getRes()) = smemDesc;
std::string NVVM::MBarrierInitOp::getPtx() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  return (addressSpace == NVVMMemorySpace::Shared)
             ? std::string("mbarrier.init.shared.b64 [%0], %1;")
             : std::string("mbarrier.init.b64 [%0], %1;");

    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierInitOp>(op);
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
          ? llvm::Intrinsic::nvvm_mbarrier_init_shared
          : llvm::Intrinsic::nvvm_mbarrier_init;
  args.push_back(mt.lookupValue(thisOp.getAddr()));
  args.push_back(mt.lookupValue(thisOp.getCount()));

  return {id, std::move(args)};

    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
          ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
          : llvm::Intrinsic::nvvm_mbarrier_inval;

  return {id, {mt.lookupValue(thisOp.getAddr())}};
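// The cp.async intrinsic is selected by cache modifier (ca/cg), copy size
// (4, 8 or 16 bytes) and whether an explicit cp-size operand is passed; the
// two macros below assemble the corresponding intrinsic name.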
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
    llvm_unreachable("Invalid copy size in CpAsyncOp.");

  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};

    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value multicastMask = thisOp.getMulticastMask();
  const bool hasMulticastMask = static_cast<bool>(multicastMask);
  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);

  args.push_back(builder.getInt1(hasMulticastMask));
  args.push_back(builder.getInt1(hasCacheHint));

      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;

  return {id, std::move(args)};
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;

  return {id, std::move(args)};
bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
  for (auto val : getOperands())

CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
  const bool isCTAOnly = thisOp.getIsCTAOnly();

  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
    args.push_back(mt.lookupValue(v));
    args.push_back(mt.lookupValue(v));

  const bool hasMC = static_cast<bool>(mcMask);
  llvm::Value *i16Zero =
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Zero =
      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;

    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasMC));
    args.push_back(builder.getInt1(hasCacheHint));
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasCacheHint));
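  // Intrinsic lookup tables indexed as [mode][dim]: one row per TMALoadMode,
  // one column per coordinate count (0..5). IDTableCTA holds the shared::cta
  // variants; notIntrinsic marks unsupported combinations.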
  constexpr size_t numDims = 5;
  constexpr size_t numModes = 5;
  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
  using TableTy = std::array<rowTy, numModes>;
  static constexpr TableTy IDTable{
      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};

  static constexpr TableTy IDTableCTA{
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};

      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");

  return {id, std::move(args)};
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));
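  // Prefetch intrinsics indexed as [mode][dim]; NI (not_intrinsic) marks
  // coordinate counts that are invalid for a given TMA load mode.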
  const unsigned NI = llvm::Intrinsic::not_intrinsic;
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));
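  // Shared::cta-to-global intrinsics indexed as [mode][dim], again with NI
  // marking unsupported (mode, dimension) combinations.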
  const unsigned NI = llvm::Intrinsic::not_intrinsic;
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};

  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  if (id == llvm::Intrinsic::not_intrinsic)
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");

  return {id, std::move(args)};
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();

  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));
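  // Three-level intrinsic table indexed as [redKind][mode][dim]: eight
  // reduction kinds (add/min/max/inc/dec/and/or/xor), two layouts
  // (tile/im2col) and up to five coordinates select the reduce intrinsic.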
  constexpr unsigned numRedKinds = 8;
  constexpr unsigned numLayouts = 2;
  constexpr unsigned maxDim = 5;
  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
  using layoutTable = std::array<row, numLayouts>;
  using fullTable = std::array<layoutTable, numRedKinds>;
  static constexpr fullTable IDTable{
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};

  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
                "TMAReduxKinds must match number of rows in IDTable");

  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();

  assert(redKind < IDTable.size() &&
         "Invalid redKind for CpAsyncBulkTensorReduceOp");
  assert(mode < IDTable[redKind].size() &&
         "Invalid mode for CpAsyncBulkTensorReduceOp");
  assert(dim < IDTable[redKind][mode].size() &&
         "Invalid dim for CpAsyncBulkTensorReduceOp");
         "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");

  return {intrinsicID, std::move(args)};
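// The f32-to-tf32 conversion intrinsic is chosen by rounding mode, with the
// optional `relu` and `satfinite` suffixes appended by the macro pair below.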
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");

ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  args.push_back(mt.lookupValue(op.getA()));
  args.push_back(mt.lookupValue(op.getB()));

  bool hasRelu = op.getRelu();
      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;

  return {intId, std::move(args)};

#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
        return llvm::Intrinsic::not_intrinsic;
2219 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
2220 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
2221 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
2223 #define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
2224 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
2225 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
2228 ConvertF32x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy, NVVM::FPRoundingMode rnd,
2229 NVVM::SaturationMode sat,
bool hasRelu) {
2230 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2231 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
2232 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
2235 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2238 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2241 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2242 if (hasRoundingModeRZ)
2244 else if (hasRoundingModeRP)
2247 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
2250 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
2251 return llvm::Intrinsic::not_intrinsic;
2255 #define GET_F16x2_TO_F8X2_ID(type, has_relu) \
2256 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
2257 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
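// ConvertBF16x2ToF8x2Op only targets the ue8m0x2 format; the rounding mode
// (RZ or RP) and the satfinite flag pick the exact intrinsic.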
llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<Type, llvm::Intrinsic::ID>(
          getElementTypeOrSelf(curOp.getSrc().getType()))
          .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
          })
          .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
          })
          .Default([](Type) {
            llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<Type, llvm::Intrinsic::ID>(
          getElementTypeOrSelf(curOp.getSrc().getType()))
          .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
          })
          .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
          })
          .Default([](Type) {
            llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
  bool hasRelu = curOp.getRelu();

      .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
        return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
                       : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
      })
      .Default([](Type) {
        llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });

  llvm::Value *extendedI16 =
      builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
                         llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {extendedI16}};
}
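// tcgen05.alloc / tcgen05.dealloc: the intrinsic is chosen by the CTA group
// (cg1 vs cg2) and, for alloc, by whether the destination address lives in
// shared memory.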
NVVM::IDArgPair
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared)
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  else
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
NVVM::IDArgPair
Tcgen05DeallocOp::getIntrinsicIDAndArgs(Operation &op,
                                        LLVM::ModuleTranslation &mt,
                                        llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
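// tcgen05.commit: the macros below compose the intrinsic name from the CTA
// group, an optional multicast (_mc) suffix, and whether the mbarrier address
// is in shared memory.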
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
NVVM::IDArgPair
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return {id, std::move(args)};
}
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
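// tcgen05.cp selects its intrinsic from the copy shape, an optional warp
// multicast variant, the source format (b6x16_p32 / b4x16_p64), and the CTA
// group.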
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
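// Range inference and verification for the special-register ops: a `range`
// attribute, when present, seeds the integer range analysis and must satisfy
// LLVM's ConstantRange constructor rules.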
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
static LogicalResult
verifyConstantRangeAttr(Operation *op,
                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
  if (!rangeAttr)
    return success();

  const llvm::APInt &lower = rangeAttr->getLower();
  const llvm::APInt &upper = rangeAttr->getUpper();

  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
    unsigned bitWidth = lower.getBitWidth();
    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
    return op->emitOpError(
               "invalid range attribute: Lower == Upper, but they aren't min (")
           << minVal.getZExtValue() << ") or max (" << maxVal.getZExtValue()
           << ") value! This is an invalid constant range.";
  }

  return success();
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;

  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;

  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
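// Prefetch lowering: tensormap prefetches and uniform L1 prefetches have
// dedicated intrinsics; otherwise the intrinsic is picked from the cache
// level, an optional L2 eviction priority, and the pointer's address space.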
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}
NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::Value *addr = mt.lookupValue(op.getAddr());
  llvm::SmallVector<llvm::Value *> args{addr};

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
bool NVVM::InlinePtxOp::getAsmValues(
  for (auto arg : getReadWriteArgs())
  for (auto arg : getResults())
  for (auto arg : getReadOnlyArgs())
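// clusterlaunchcontrol.try_cancel writes its asynchronous response through a
// shared-memory address and mbarrier; the query op then decodes that response
// (is_canceled, or one of the first-CTA-id components).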
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;
  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }

  return {intrinsicID, args};
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
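// Dialect-level attribute verification: kernel, launch-bound, and
// blocks-are-clusters attributes are checked here, and grid_constant is
// validated on kernel arguments in verifyRegionArgAttribute below.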
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();

  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }

  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }

  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }

  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}
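// ptr-dialect memory-space interface: loads and stores defer to the common
// LLVM validity check, while the atomic, address-space-cast, and ptr/int-cast
// hooks are still unimplemented placeholders.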
bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
  assert(false && "unimplemented, see TODO in the source.");

bool NVVMMemorySpaceAttr::isValidPtrIntCast(
  assert(false && "unimplemented, see TODO in the source.");
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  WalkResult walkResult = gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      if (!reqOp.getRequiredMinSMVersion().isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"