#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimensions");

  // For im2col mode, there are two constraints:
  if (isIm2Col) {
    // 1. The tensor must be at least 3-dimensional.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When im2col offsets are present, they must be (tensorDims - 2) in
    // number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("a maximum of 5 coordinates and dimensions is supported.");
  return success();
}
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         /*numIm2ColOffsets=*/0, getLoc());
}
LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
  }
  return success();
}
LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  switch (getType()) {
  case ConvertFP8Type::E4M3:
  case ConvertFP8Type::E5M2:
    if (!isRoundingModeRN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (!isSatFinite)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  case ConvertFP8Type::UE8M0:
    if (!(isRoundingModeRZ || isRoundingModeRP))
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (hasRelu)
      return emitOpError(
          "relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  return success();
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (getType() == ConvertFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");
  return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (getType() != ConvertFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  auto rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  return success();
}
LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ")
           << getInitVal();
  return success();
}
/// Infer the MMA PTX type from an operand's element type.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
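// For example: an f32 fragment infers tf32 as a multiplicand but f32 as an
// accumulator, and a struct of f16x2 registers infers f16 from its first
// element.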
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() &&
         "accumulator PTX type should always be inferrable");
  return val.value();
}
MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : (";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(),
                                       /*isAccumulator=*/false))
      result.addAttribute("multiplicandAPtxType",
                          MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(),
                                       /*isAccumulator=*/false))
      result.addAttribute("multiplicandBPtxType",
                          MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse one bracketed operand segment, e.g. `A[%a0, %a1]`.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs,
                              OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type list `: (typeA, typeB, typeC) -> resultType`, resolve the
  // operands, and infer the PTX element type of each segment.
  SmallVector<Type> operandTypes;
  Type resultType;
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(operandTypes) || parser.parseRParen() ||
      parser.parseArrow() || parser.parseType(resultType))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() == 2);
  }

  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes.getAttrs());
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
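// An illustrative sketch of the textual form this parser accepts (operand
// counts and the exact attribute syntax depend on the chosen shape; treat
// the concrete values below as an assumption, not a verified test case):
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//        {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
//        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>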
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f64Ty = Float64Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B,
  // C, and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M = 8 case, there is only one possible combination per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(
          LLVM::LLVMStructType::getLiteral(context, {f64Ty, f64Ty}));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accumulator overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
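// Worked example for the M = 16 arithmetic above: for an m16n8k16 f16 MMA,
// kFactor is 8, so unitA = (16/8) * (16/8) = 4 and unitB = (8/8) * (16/8) = 2,
// i.e. A is carried in 4 vector<2xf16> registers and B in 2.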
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag,
                                                   int nRow, int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type "
                         << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose. Only f16 and bf16 matrices may be stored in either
  // row-major or column-major layout; all other input types require the
  // canonical layouts (A row-major, B column-major).
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check the output register count.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite, which is only available for the s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Map the read/write registers correctly: the output registers are both
  // read and written, so the plain read operands start at twice the count.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}
LogicalResult Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}
LogicalResult NVVM::VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' return an i1");
    }
  }
  return success();
}
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      getEvictPriority();

  if (getUniform()) {
    if (getCacheLevel() != CacheLevel::L1)
      return emitOpError("unsupported cache level, the only supported uniform "
                         "cache level is L1");

    if (addressSpace != MemSpace::kGenericMemorySpace)
      return emitOpError(
          "prefetch to uniform cache requires a generic pointer");
  }

  if (evictPriority) {
    if (getCacheLevel() != CacheLevel::L2)
      return emitOpError(
          "cache eviction priority supported only for cache level L2");

    if (addressSpace != MemSpace::kGlobalMemorySpace)
      return emitOpError("cache eviction priority requires a global pointer");

    if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
        *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
      return emitOpError(
          "unsupported cache eviction priority, only evict_last and "
          "evict_normal are supported");
  }

  return success();
}
/// Packs the given `field` into `result`. The `result` is 64 bits wide and
/// each `field` can be 32 bits or narrower.
static llvm::Value *
packValInto64Bits(llvm::IRBuilderBase &builder,
                  llvm::Value *result, // the accumulated 64-bit value
                  llvm::Value *field,  // the field to pack into `result`
                  unsigned sizeInBits, // size of `field` in bits
                  unsigned start) {    // starting bit within `result`
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
}
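// A worked example of the packing arithmetic: with sizeInBits = 14 and
// start = 16, a field value of 0x2abc is masked with (1u << 14) - 1 = 0x3fff
// (a no-op for values that already fit), zero-extended to i64, shifted into
// bits [16, 30), and OR-ed into the accumulated descriptor.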
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getStartAddr()), 14, 0);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);

  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return id;
}
mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Switch to the bytemask variant when a byte mask operand is present.
  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
  }

  return {id, std::move(args)};
}
llvm::Intrinsic::ID
CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP6Type::E2M3:
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  case NVVM::ConvertFP6Type::E3M2:
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  }
  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
                                     NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::ConvertFP8Type::UE8M0:
    if (hasRoundingModeRZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    else if (hasRoundingModeRP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
    break;
  }
  llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
  }
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
mlir::NVVM::IDArgPair
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the intrinsic args.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair
Tcgen05DeallocOp::getIntrinsicIDAndArgs(Operation &op,
                                        LLVM::ModuleTranslation &mt,
                                        llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the intrinsic args.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

mlir::NVVM::IDArgPair
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the intrinsic args.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return {id, std::move(args)};
}
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  unsigned validVecLen = llvm::isPowerOf2_32(vecLen);
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    validVecLen &= vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    validVecLen &= vecLen >= 4;
  return validVecLen;
}
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have a ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}
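// The two signedness bits form the index into `ids`: for example a signed A
// with an unsigned B gives type = (1 << 1) | 0 = 2, which selects
// llvm::Intrinsic::nvvm_idp4a_s_u.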
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  if (op.getUniform() && cacheLevel == CacheLevel::L1)
    return llvm::Intrinsic::nvvm_prefetchu_L1;

  if (evictPriority && cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
    case NVVM::CacheEvictionPriority::EvictNormal:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (addressSpace) {
  case MemSpace::kGenericMemorySpace:
    return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
                                        : llvm::Intrinsic::nvvm_prefetch_L2;
  case MemSpace::kGlobalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_global_L1
               : llvm::Intrinsic::nvvm_prefetch_global_L2;
  case MemSpace::kLocalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_local_L1
               : llvm::Intrinsic::nvvm_prefetch_local_L2;
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel function attribute may only be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
  // with at most three dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array with at most 3 elements";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError()
             << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"