#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  if (isIm2Col) {
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  return success();
}
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}
LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
  }
  return success();
}
LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  switch (getType()) {
  case ConvertFP8Type::E4M3:
  case ConvertFP8Type::E5M2:
    if (!isRoundingModeRN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (!isSatFinite)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  case ConvertFP8Type::UE8M0:
    if (!(isRoundingModeRZ || isRoundingModeRP))
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (hasRelu)
      return emitOpError(
          "relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  return success();
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (getType() == ConvertFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");
  return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (getType() != ConvertFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  auto rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  return success();
}
LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ")
           << getInitVal();
  return success();
}
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
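// Example (illustrative): the inference above is position-sensitive for f32
// operands: as an accumulator f32 maps to MMATypes::f32, while as a
// multiplicand it maps to MMATypes::tf32, mirroring the PTX mma variants.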
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }
  // ... (attribute dictionary and type printing elided in this listing)
}
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    // Use the explicitly provided multiplicand PTX types.
    // ...
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(),
                                       /*isAccumulator=*/false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(),
                                       /*isAccumulator=*/false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    // Use the explicitly provided layouts.
    // ...
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;
  NamedAttrList namedAttributes;

  // A helper to parse one "<name>[$reg, ...]" operand segment.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the trailing type specification and resolve the operands.
  SmallVector<Type, 3> operandTypes;
  Type resultType;
  // ...
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() >= 2);
  }

  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B,
  // C, and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    // ... (the s4/u4, b1, and s8/u8 cases set kFactor to 32, 128, and 16
    // respectively)
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accumulator overflow behavior.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
      numberElements = 2;
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
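// Example (illustrative): inferMMAType(MMATypes::f16, MMAFrag::a, 16, 16, ctx)
// returns {vector<2xf16>, 8}, i.e. the A fragment of a 16x16x16 f16 WMMA is
// eight f16x2 registers, while the c/d fragments use four.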
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  int m = getShape().getM(), n = getShape().getN();
  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
    }
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    }
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8 or 16x8");
  }
  return success();
}
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose: it is only available when A/B are f16 or bf16.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }

  // Check satfinite, which is only available for the s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;

  // Print the registers for the output.
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";

  // Map the read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Only add the transpose parameters when the input types allow it.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // ... (result, inout, and descriptor operands are appended first)
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  return success();
}
LogicalResult Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
LogicalResult MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}
LogicalResult VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
    }
  }
  return success();
}
LogicalResult PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();

  if (getUniform()) {
    if (getCacheLevel() != CacheLevel::L1)
      return emitOpError("unsupported cache level, the only supported uniform "
                         "cache level is L1");

    if (addressSpace != MemSpace::kGenericMemorySpace)
      return emitOpError(
          "prefetch to uniform cache requires a generic pointer");
  }

  if (evictPriority) {
    if (getCacheLevel() != CacheLevel::L2)
      return emitOpError(
          "cache eviction priority supported only for cache level L2");

    if (addressSpace != MemSpace::kGlobalMemorySpace)
      return emitOpError("cache eviction priority requires a global pointer");

    if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
        *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
      return emitOpError(
          "unsupported cache eviction priority, only evict_last and "
          "evict_normal are supported");
  }

  return success();
}
/// Packs the given `field` into `result`. `result` is 64 bits wide and the
/// `field` occupies `sizeInBits` bits starting at bit position `start`.
static llvm::Value *
packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result,
                  llvm::Value *field, unsigned sizeInBits, unsigned start) {
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
}
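// Example (illustrative): packing a 14-bit field at bit 16 computes
//   result |= (i64)(field & 0x3fff) << 16;
// which is exactly how the shared-memory descriptor fields below are laid out.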
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  // ...
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
  // ...
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
  // ...

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

mlir::NVVM::IDArgPair
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::IRBuilderBase &builder) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return {id, std::move(args)};
}
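// Example (illustrative): GET_CP_ASYNC_ID(cg, 16, /*has_cpsize=*/true)
// expands to llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16_s, the
// variant that consumes the extra cp-size operand pushed above.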
mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Switch to the bytemask variant when a byte mask is present.
  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
  }

  return {id, std::move(args)};
}
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  // ... (the MIN, MAX, INC, DEC, AND, OR, and XOR kinds follow the same
  // pattern)
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP6Type::E2M3:
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  case NVVM::ConvertFP6Type::E3M2:
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  }
  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
                                     NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::ConvertFP8Type::UE8M0:
    if (hasRoundingModeRZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    else if (hasRoundingModeRP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
  }
  llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
  }
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
mlir::NVVM::IDArgPair
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

mlir::NVVM::IDArgPair
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return {id, std::move(args)};
}
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  // The 16x128b and 16x256b shapes require wider minimum vector lengths.
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return vecLen >= 4;
  // ...
}
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have a ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}
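// Example (illustrative): the 2-bit index packs the signedness of both
// operands, so an unsigned `a` with a signed `b` yields type == 0b01 and
// selects llvm::Intrinsic::nvvm_idp4a_u_s.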
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  if (op.getUniform() && cacheLevel == CacheLevel::L1)
    return llvm::Intrinsic::nvvm_prefetchu_L1;

  if (evictPriority && cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
    case NVVM::CacheEvictionPriority::EvictNormal:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (addressSpace) {
  case MemSpace::kGenericMemorySpace:
    return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
                                        : llvm::Intrinsic::nvvm_prefetch_L2;
  case MemSpace::kGlobalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_global_L1
               : llvm::Intrinsic::nvvm_prefetch_global_L2;
  case MemSpace::kLocalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_local_L1
               : llvm::Intrinsic::nvvm_prefetch_local_L2;
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel function attribute may only be attached to LLVM functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
  // with at most three dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError()
             << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"